Skip to content

Commit ab84f46

Browse files
committed
continue keras core integration
1 parent 49511aa commit ab84f46

File tree

22 files changed

+1469
-1435
lines changed

22 files changed

+1469
-1435
lines changed
File renamed without changes.

docs/source/data.ipynb

Lines changed: 28 additions & 14 deletions
Large diffs are not rendered by default.

docs/source/layers.ipynb

Lines changed: 1 addition & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -48,112 +48,6 @@
4848
"## Implementaion details"
4949
]
5050
},
51-
{
52-
"cell_type": "markdown",
53-
"id": "fc229b02",
54-
"metadata": {},
55-
"source": [
56-
"Most tensorflow methods already support ragged tensors, which can be looked up here: https://www.tensorflow.org/api_docs/python/tf/ragged. \n",
57-
"\n",
58-
"For using keras layers, most layers in `kgcnn` inherit from `kgcnn.layers.base.GraphBaseLayer` which adds some utility methods and arguments, such as the `ragged_validate` parameter that is used for ragged tensor creation."
59-
]
60-
},
61-
{
62-
"cell_type": "code",
63-
"execution_count": 1,
64-
"id": "49ec83af",
65-
"metadata": {},
66-
"outputs": [
67-
{
68-
"name": "stdout",
69-
"output_type": "stream",
70-
"text": [
71-
"True\n"
72-
]
73-
}
74-
],
75-
"source": [
76-
"import tensorflow as tf\n",
77-
"from kgcnn.layers.base import GraphBaseLayer\n",
78-
"\n",
79-
"class NewLayer(GraphBaseLayer):\n",
80-
" \n",
81-
" def __init__(self, **kwargs):\n",
82-
" super().__init__(**kwargs)\n",
83-
" \n",
84-
" def call(self, inputs, **kwargs):\n",
85-
" # Do something in call.\n",
86-
" return inputs\n",
87-
"\n",
88-
"new_layer = NewLayer(ragged_validate=False)\n",
89-
"print(new_layer._supports_ragged_inputs)"
90-
]
91-
},
92-
{
93-
"cell_type": "markdown",
94-
"id": "5c664a24",
95-
"metadata": {},
96-
"source": [
97-
"> **WARNING**: Since ragged tensors can result in quite a performance loss due to shape checks of ragged dimensions on runtime, it is recommended to directly work with the values tensor/information (if possible) and to use `ragged_validate` to `False` in production (for using parallization this might be different). An example is the `tf.ragged.map_flat_values` method. Due to this reason, there are `kgcnn.layers.modules.LazyAdd` and `kgcnn.layers.modules.LazyConcatenate` etc. layers.\n",
98-
"\n",
99-
"Utility methods of `GraphBaseLayer` to work with values directly are `assert_ragged_input_rank` and `map_values` . Note that this can speed up models and is essentially equal to a disjoint graph tensor representation. \n",
100-
"However, with ragged tensors there is also the possibility to try `tf.vectorized_map` or `tf.map_fn` if values can not be accessed.\n",
101-
"\n",
102-
"Here is an example of how to use `assert_ragged_input_rank` and `map_values` . With `assert_ragged_input_rank` it can be ensured that a ragged tensor is given in `call` by casting padded+mask or normal tensor (for example in case of equal sized graphs) to a ragged version, in order to accesss e.g. `inputs.values` etc.\n",
103-
"With `map_values` a function can be directly applied to the values tensor or a list of value tensors. Axis argument can refer to the ragged tensor shape but the `map_values` is restricted to ragged rank of one. Fallback for `map_values` is applying the function directly to its input."
104-
]
105-
},
106-
{
107-
"cell_type": "code",
108-
"execution_count": 2,
109-
"id": "2a02b24c",
110-
"metadata": {},
111-
"outputs": [],
112-
"source": [
113-
"class NewLayer(GraphBaseLayer):\n",
114-
" \n",
115-
" def __init__(self, **kwargs):\n",
116-
" super().__init__(**kwargs)\n",
117-
" \n",
118-
" def call_v1(self, inputs, **kwargs):\n",
119-
" inputs = self.assert_ragged_input_rank(inputs, ragged_rank=1)\n",
120-
" return tf.RaggedTensor.from_row_splits(\n",
121-
" tf.exp(inputs.values), inputs.row_splits, validate=self.ragged_validate)\n",
122-
"\n",
123-
" def call_v2(self, inputs, **kwargs):\n",
124-
" # Possible kwargs for function can be added. Can have axis argument (special case).\n",
125-
" return self.map_values(tf.exp, inputs)\n",
126-
" \n",
127-
"new_layer = NewLayer()"
128-
]
129-
},
130-
{
131-
"cell_type": "code",
132-
"execution_count": 3,
133-
"id": "e4699f10",
134-
"metadata": {},
135-
"outputs": [
136-
{
137-
"name": "stdout",
138-
"output_type": "stream",
139-
"text": [
140-
"x (v1): (2, 3, 4) (2, None, 4)\n",
141-
"x (v2): (2, 3, 4) (2, 3, 4)\n",
142-
"x_ragged (v1): (2, None, 4) (2, None, 4)\n",
143-
"x_ragged (v2): (2, None, 4) (2, None, 4)\n"
144-
]
145-
}
146-
],
147-
"source": [
148-
"x = tf.ones((2, 3, 4))\n",
149-
"x_ragged = tf.ragged.constant([[[1.0, 1.0, 1.0, 1.0]],[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]], ragged_rank=1)\n",
150-
"\n",
151-
"print(\"x (v1):\", x.shape, new_layer.call_v1(x).shape)\n",
152-
"print(\"x (v2):\", x.shape, new_layer.call_v2(x).shape)\n",
153-
"print(\"x_ragged (v1):\", x_ragged.shape, new_layer.call_v1(x_ragged).shape)\n",
154-
"print(\"x_ragged (v2):\", x_ragged.shape, new_layer.call_v2(x_ragged).shape)"
155-
]
156-
},
15751
{
15852
"cell_type": "markdown",
15953
"id": "212298b7",
@@ -187,7 +81,7 @@
18781
"name": "python",
18882
"nbconvert_exporter": "python",
18983
"pygments_lexer": "ipython3",
190-
"version": "3.9.7"
84+
"version": "3.10.13"
19185
}
19286
},
19387
"nbformat": 4,

0 commit comments

Comments
 (0)