|
import keras_core as ks
from kgcnn.layers.message import MessagePassingBase
from kgcnn.layers.norm import GraphBatchNormalization
from keras_core.layers import Activation, Multiply, Concatenate, Add, Dense
| 5 | + |
| 6 | + |
class CGCNNLayer(MessagePassingBase):
    r"""Message Passing Layer used in the Crystal Graph Convolutional Neural Network:

    `CGCNN <https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301>`__ .

    Based on the original code in pytorch (<https://github.com/txie-93/cgcnn>).

    Args:
        units (int): Units for Dense layer.
        activation_s (str): Activation applied before gating the message.
        activation_out (str): Activation applied at the very end of the layer (after gating).
        batch_normalization (bool): Whether to use batch normalization (:obj:`GraphBatchNormalization`) or not.
        use_bias (bool): Boolean, whether the layer uses a bias vector. Default is True.
        kernel_initializer: Initializer for the `kernel` weights matrix. Default is "glorot_uniform".
        bias_initializer: Initializer for the bias vector. Default is "zeros".
        padded_disjoint: Whether disjoint tensors have padded nodes. Default is False.
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix. Default is None.
        bias_regularizer: Regularizer function applied to the bias vector. Default is None.
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation"). Default is None.
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix. Default is None.
        bias_constraint: Constraint function applied to the bias vector. Default is None.
    """

    def __init__(self, units: int = 64,
                 activation_s="softplus",
                 activation_out="softplus",
                 batch_normalization: bool = False,
                 use_bias: bool = True,
                 kernel_regularizer=None,
                 padded_disjoint: bool = False,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 **kwargs):
        super(CGCNNLayer, self).__init__(use_id_tensors=4, **kwargs)
        self.units = units
        self.use_bias = use_bias
        self.padded_disjoint = padded_disjoint
        self.batch_normalization = batch_normalization
        # Shared sub-layer arguments for both dense transforms `f` and `s`.
        kernel_args = {"kernel_regularizer": kernel_regularizer, "bias_regularizer": bias_regularizer,
                       "kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint,
                       "kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer}

        # Sigmoid gate `f` and candidate transform `s` as in the original CGCNN formulation.
        self.activation_f_layer = Activation(activation="sigmoid", activity_regularizer=activity_regularizer)
        self.activation_s_layer = Activation(activation_s, activity_regularizer=activity_regularizer)
        self.activation_out_layer = Activation(activation_out, activity_regularizer=activity_regularizer)
        if batch_normalization:
            self.batch_norm_f = GraphBatchNormalization(padded_disjoint=padded_disjoint)
            self.batch_norm_s = GraphBatchNormalization(padded_disjoint=padded_disjoint)
            self.batch_norm_out = GraphBatchNormalization(padded_disjoint=padded_disjoint)
        self.f = Dense(self.units, activation="linear", use_bias=use_bias, **kernel_args)
        self.s = Dense(self.units, activation="linear", use_bias=use_bias, **kernel_args)
        self.lazy_mult = Multiply()
        self.lazy_add = Add()
        # NOTE(review): concatenate along the last (feature) axis. The previous hard-coded
        # `axis=2` only works for rank-3 (padded) inputs, but the documented disjoint edge
        # tensors `([M], F)` are rank 2; `axis=-1` is equivalent for rank 3 and also valid
        # for rank 2 — confirm against the MessagePassingBase tensor layout.
        self.lazy_concat = Concatenate(axis=-1)

    def message_function(self, inputs, **kwargs):
        r"""Prepare messages.

        Args:
            inputs: [nodes_in, nodes_out, edges, graph_id_node, graph_id_edge, count_nodes, count_edges]

                - nodes_in (Tensor): Embedding of sending nodes of shape `([M], F)`
                - nodes_out (Tensor): Embedding of receiving nodes of shape `([M], F)`
                - edges (Tensor): Embedding of edges of shape `([M], E)`
                - graph_id_node (Tensor): ID tensor of batch assignment in disjoint graph of shape `([N], )` .
                - graph_id_edge (Tensor): ID tensor of batch assignment in disjoint graph of shape `([M], )` .
                - nodes_count (Tensor): Tensor of number of nodes for each graph of shape `(batch, )` .
                - edges_count (Tensor): Tensor of number of edges for each graph of shape `(batch, )` .

        Returns:
            Tensor: Messages for updates of shape `([M], units)`.
        """

        nodes_in = inputs[0]  # shape: ([M], F)
        nodes_out = inputs[1]  # shape: ([M], F)
        edge_features = inputs[2]  # shape: ([M], E)
        graph_id_node, graph_id_edge, count_nodes, count_edges = inputs[3:]

        # Joint edge representation: z = (v_i || v_j || e_ij), then gated message
        # m = sigma(W_f z) * g(W_s z) as in the CGCNN paper.
        x = self.lazy_concat([nodes_in, nodes_out, edge_features], **kwargs)
        x_s, x_f = self.s(x, **kwargs), self.f(x, **kwargs)
        if self.batch_normalization:
            # Normalization is per-graph, hence the edge-to-graph assignment tensors.
            x_s = self.batch_norm_s([x_s, graph_id_edge, count_edges], **kwargs)
            x_f = self.batch_norm_f([x_f, graph_id_edge, count_edges], **kwargs)
        x_s, x_f = self.activation_s_layer(x_s, **kwargs), self.activation_f_layer(x_f, **kwargs)
        x_out = self.lazy_mult([x_s, x_f], **kwargs)  # shape: ([M], units)
        return x_out

    def update_nodes(self, inputs, **kwargs):
        """Update node embeddings.

        Args:
            inputs: [nodes, nodes_updates, graph_id_node, graph_id_edge, count_nodes, count_edges]

                - nodes (Tensor): Embedding of nodes of previous layer of shape `([N], F)`
                - nodes_updates (Tensor): Aggregated node updates of shape `([N], F)`
                - graph_id_node (Tensor): ID tensor of batch assignment in disjoint graph of shape `([N], )` .
                - graph_id_edge (Tensor): ID tensor of batch assignment in disjoint graph of shape `([M], )` .
                - nodes_count (Tensor): Tensor of number of nodes for each graph of shape `(batch, )` .
                - edges_count (Tensor): Tensor of number of edges for each graph of shape `(batch, )` .

        Returns:
            Tensor: Updated nodes of shape `([N], F)`.
        """
        nodes = inputs[0]
        nodes_update = inputs[1]
        graph_id_node, graph_id_edge, count_nodes, count_edges = inputs[2:]

        if self.batch_normalization:
            nodes_update = self.batch_norm_out([nodes_update, graph_id_node, count_nodes], **kwargs)

        # Residual update: v_i' = act(v_i + aggregated messages).
        nodes_updated = self.lazy_add([nodes, nodes_update], **kwargs)
        nodes_updated = self.activation_out_layer(nodes_updated, **kwargs)
        return nodes_updated

    def get_config(self):
        """Update layer config."""
        config = super(CGCNNLayer, self).get_config()
        config.update({
            "units": self.units, "use_bias": self.use_bias, "padded_disjoint": self.padded_disjoint,
            "batch_normalization": self.batch_normalization})
        # Recover serialized activation/regularizer settings from the sub-layers so that
        # `from_config` can reconstruct the layer with identical constructor arguments.
        conf_s = self.activation_s_layer.get_config()
        conf_out = self.activation_out_layer.get_config()
        config.update({"activation_s": conf_s["activation"]})
        config.update({"activation_out": conf_out["activation"]})
        if "activity_regularizer" in conf_out.keys():
            config.update({"activity_regularizer": conf_out["activity_regularizer"]})
        conf_f = self.f.get_config()
        for x in ["kernel_regularizer", "bias_regularizer", "kernel_constraint",
                  "bias_constraint", "kernel_initializer", "bias_initializer"]:
            if x in conf_f.keys():
                config.update({x: conf_f[x]})
        return config
0 commit comments