Skip to content

Commit eab9358

Browse files
committed
update for keras 3.0
1 parent 721d187 commit eab9358

File tree

7 files changed

+445
-13
lines changed

7 files changed

+445
-13
lines changed

kgcnn/layers/message.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,21 @@ class MessagePassingBase(ks.layers.Layer):
1515
The original message passing scheme was proposed by `NMPNN <http://arxiv.org/abs/1704.01212>`__ .
1616
"""
1717

18-
def __init__(self, pooling_method: str = "scatter_sum", **kwargs):
18+
def __init__(self,
19+
pooling_method: str = "scatter_sum",
20+
use_id_tensors: int = None,
21+
**kwargs):
1922
r"""Initialize :obj:`MessagePassingBase` layer.
2023
2124
Args:
2225
pooling_method (str): Aggregation method of edges. Default is "sum".
26+
use_id_tensors (int): Whether :obj:`call` receives graph ID information, which is passed onto message and
27+
aggregation function. Specifies the number of additional tensors.
2328
"""
2429
super(MessagePassingBase, self).__init__(**kwargs)
2530
self.pooling_method = pooling_method
2631
self.lay_gather = GatherNodes(concat_axis=None)
32+
self.use_id_tensors = use_id_tensors
2733
self.lay_pool_default = AggregateLocalEdges(pooling_method=self.pooling_method)
2834

2935
def build(self, input_shape):
@@ -90,27 +96,34 @@ def call(self, inputs, **kwargs):
9096
Returns:
9197
Tensor: Updated node embeddings of shape ([N], F)
9298
"""
99+
if self.use_id_tensors is not None:
100+
ids = inputs[-int(self.use_id_tensors):]
101+
inputs = inputs[:-int(self.use_id_tensors)]
102+
else:
103+
ids = []
104+
93105
if len(inputs) == 2:
94-
nodes, edge_index = inputs
106+
nodes, edge_index = inputs[:2]
95107
edges = None
96108
else:
97-
nodes, edges, edge_index = inputs
109+
nodes, edges, edge_index = inputs[:3]
98110

99111
n_in, n_out = self.lay_gather([nodes, edge_index], **kwargs)
100112

101113
if edges is None:
102-
msg = self.message_function([n_in, n_out], **kwargs)
114+
msg = self.message_function([n_in, n_out] + ids, **kwargs)
103115
else:
104-
msg = self.message_function([n_in, n_out, edges], **kwargs)
116+
msg = self.message_function([n_in, n_out, edges] + ids, **kwargs)
105117

106118
pool_n = self.aggregate_message([nodes, msg, edge_index], **kwargs)
107-
n_new = self.update_nodes([nodes, pool_n], **kwargs)
119+
120+
n_new = self.update_nodes([nodes, pool_n] + ids, **kwargs)
108121
return n_new
109122

110123
def get_config(self):
111124
"""Update config."""
112125
config = super(MessagePassingBase, self).get_config()
113-
config.update({"pooling_method": self.pooling_method})
126+
config.update({"pooling_method": self.pooling_method, "use_id_tensors": self.use_id_tensors})
114127
return config
115128

116129

@@ -125,7 +138,6 @@ class MatMulMessages(ks.layers.Layer):
125138
.. math::
126139
127140
x_i' = \mathbf{A_i} \; x_i
128-
129141
"""
130142

131143
def __init__(self, **kwargs):

kgcnn/layers/norm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,8 @@ def call(self, inputs, **kwargs):
223223
inputs (list): `[values, graph_id, reference]` .
224224
225225
- values (Tensor): Tensor to normalize of shape `(None, F, ...)` .
226-
- graph_size (Tensor, optional): Size of each graph for nodes in disjoint batch of shape `(batch, )` .
226+
- graph_id (Tensor): Tensor of graph IDs of shape `(None, )` .
227+
- reference (Tensor, optional): Graph reference of disjoint batch of shape `(batch, )` .
227228
228229
Returns:
229230
Tensor: Normalized tensor of identical shape (None, F, ...)

kgcnn/literature/CGCNN/__init__.py

Whitespace-only changes.

kgcnn/literature/CGCNN/_layers.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import keras_core as ks
2+
from kgcnn.layers.message import MessagePassingBase
3+
from kgcnn.layers.norm import GraphBatchNormalization
4+
from keras_core.layers import Activation, Multiply, Concatenate, Add, Dense
5+
6+
7+
class CGCNNLayer(MessagePassingBase):
    r"""Message Passing Layer used in the Crystal Graph Convolutional Neural Network:

    `CGCNN <https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301>`__ .

    Based on the original code in pytorch (<https://github.com/txie-93/cgcnn>).

    Args:
        units (int): Units for Dense layer.
        activation_s (str): Tensorflow activation applied before gating the message.
        activation_out (str): Tensorflow activation applied the very end of the layer (after gating).
        batch_normalization (bool): Whether to use batch normalization (:obj:`GraphBatchNormalization`) or not.
        use_bias (bool): Boolean, whether the layer uses a bias vector. Default is True.
        kernel_initializer: Initializer for the `kernel` weights matrix. Default is "glorot_uniform".
        bias_initializer: Initializer for the bias vector. Default is "zeros".
        padded_disjoint: Whether disjoint tensors have padded nodes. Default is False.
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix. Default is None.
        bias_regularizer: Regularizer function applied to the bias vector. Default is None.
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation"). Default is None.
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix. Default is None.
        bias_constraint: Constraint function applied to the bias vector. Default is None.
    """

    def __init__(self, units: int = 64,
                 activation_s="softplus",
                 activation_out="softplus",
                 batch_normalization: bool = False,
                 use_bias: bool = True,
                 kernel_regularizer=None,
                 padded_disjoint: bool = False,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 **kwargs):
        # `call` receives 4 trailing ID tensors:
        # graph_id_node, graph_id_edge, count_nodes, count_edges.
        super(CGCNNLayer, self).__init__(use_id_tensors=4, **kwargs)
        self.units = units
        self.use_bias = use_bias
        self.padded_disjoint = padded_disjoint
        self.batch_normalization = batch_normalization
        kernel_args = {"kernel_regularizer": kernel_regularizer, "bias_regularizer": bias_regularizer,
                       "kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint,
                       "kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer}

        # Fixed sigmoid gate (f) and configurable candidate activation (s), as in CGCNN.
        self.activation_f_layer = Activation(activation="sigmoid", activity_regularizer=activity_regularizer)
        self.activation_s_layer = Activation(activation_s, activity_regularizer=activity_regularizer)
        self.activation_out_layer = Activation(activation_out, activity_regularizer=activity_regularizer)
        if batch_normalization:
            self.batch_norm_f = GraphBatchNormalization(padded_disjoint=padded_disjoint)
            self.batch_norm_s = GraphBatchNormalization(padded_disjoint=padded_disjoint)
            self.batch_norm_out = GraphBatchNormalization(padded_disjoint=padded_disjoint)
        self.f = Dense(self.units, activation="linear", use_bias=use_bias, **kernel_args)
        self.s = Dense(self.units, activation="linear", use_bias=use_bias, **kernel_args)
        self.lazy_mult = Multiply()
        self.lazy_add = Add()
        # FIX: was `Concatenate(axis=2)`, which is out of range for the rank-2
        # disjoint tensors of shape `([M], F)` documented below. `axis=-1` is
        # identical for rank-3 input and also covers the rank-2 disjoint case.
        self.lazy_concat = Concatenate(axis=-1)

    def message_function(self, inputs, **kwargs):
        r"""Prepare messages.

        Args:
            inputs: [nodes_in, nodes_out, edges, graph_id_node, graph_id_edge, count_nodes, count_edges]

                - nodes_in (Tensor): Embedding of sending nodes of shape `([M], F)`
                - nodes_out (Tensor): Embedding of receiving nodes of shape `([M], F)`
                - edges (Tensor): Embedding of edges of shape `([M], E)`
                - graph_id_node (Tensor): ID tensor of batch assignment in disjoint graph of shape `([N], )` .
                - graph_id_edge (Tensor): ID tensor of batch assignment in disjoint graph of shape `([M], )` .
                - nodes_count (Tensor): Tensor of number of nodes for each graph of shape `(batch, )` .
                - edges_count (Tensor): Tensor of number of edges for each graph of shape `(batch, )` .

        Returns:
            Tensor: Messages for updates of shape `([M], units)`.
        """

        nodes_in = inputs[0]  # shape: ([M], F)
        nodes_out = inputs[1]  # shape: ([M], F)
        edge_features = inputs[2]  # shape: ([M], E)
        # graph_id_node and count_nodes are unused here; only edge IDs are needed
        # to normalize per-edge features per graph.
        graph_id_node, graph_id_edge, count_nodes, count_edges = inputs[3:]

        x = self.lazy_concat([nodes_in, nodes_out, edge_features], **kwargs)
        x_s, x_f = self.s(x, **kwargs), self.f(x, **kwargs)
        if self.batch_normalization:
            x_s = self.batch_norm_s([x_s, graph_id_edge, count_edges], **kwargs)
            x_f = self.batch_norm_f([x_f, graph_id_edge, count_edges], **kwargs)
        x_s, x_f = self.activation_s_layer(x_s, **kwargs), self.activation_f_layer(x_f, **kwargs)
        # Sigmoid gate modulates the candidate message element-wise.
        x_out = self.lazy_mult([x_s, x_f], **kwargs)  # shape: ([M], units)
        return x_out

    def update_nodes(self, inputs, **kwargs):
        """Update node embeddings.

        Args:
            inputs: [nodes, nodes_updates, graph_id_node, graph_id_edge, count_nodes, count_edges]

                - nodes (Tensor): Embedding of nodes of previous layer of shape `([N], F)`
                - nodes_updates (Tensor): Aggregated node updates of shape `([N], F)`
                - graph_id_node (Tensor): ID tensor of batch assignment in disjoint graph of shape `([N], )` .
                - graph_id_edge (Tensor): ID tensor of batch assignment in disjoint graph of shape `([M], )` .
                - nodes_count (Tensor): Tensor of number of nodes for each graph of shape `(batch, )` .
                - edges_count (Tensor): Tensor of number of edges for each graph of shape `(batch, )` .

        Returns:
            Tensor: Updated nodes of shape `([N], F)`.
        """
        nodes = inputs[0]
        nodes_update = inputs[1]
        graph_id_node, graph_id_edge, count_nodes, count_edges = inputs[2:]

        if self.batch_normalization:
            nodes_update = self.batch_norm_out([nodes_update, graph_id_node, count_nodes], **kwargs)

        # Residual connection followed by the output activation, as in original CGCNN.
        nodes_updated = self.lazy_add([nodes, nodes_update], **kwargs)
        nodes_updated = self.activation_out_layer(nodes_updated, **kwargs)
        return nodes_updated

    def get_config(self):
        """Update layer config."""
        config = super(CGCNNLayer, self).get_config()
        config.update({
            "units": self.units, "use_bias": self.use_bias, "padded_disjoint": self.padded_disjoint,
            "batch_normalization": self.batch_normalization})
        # Recover activation/regularizer/initializer settings from sub-layer configs
        # so the serialized form round-trips through __init__.
        conf_s = self.activation_s_layer.get_config()
        conf_out = self.activation_out_layer.get_config()
        config.update({"activation_s": conf_s["activation"]})
        config.update({"activation_out": conf_out["activation"]})
        if "activity_regularizer" in conf_out.keys():
            config.update({"activity_regularizer": conf_out["activity_regularizer"]})
        conf_f = self.f.get_config()
        for x in ["kernel_regularizer", "bias_regularizer", "kernel_constraint",
                  "bias_constraint", "kernel_initializer", "bias_initializer"]:
            if x in conf_f.keys():
                config.update({x: conf_f[x]})
        return config

0 commit comments

Comments
 (0)