@@ -247,47 +247,97 @@ def get_config(self):
247247 return config
248248
249249
250- class MultiHeadGATV2Layer (AttentionHeadGATV2 ): # noqa
250+ class MultiHeadGATV2Layer (Layer ): # noqa
251+ r"""Single layer for multiple Attention heads from :obj:`AttentionHeadGATV2` .
252+
253+ Uses concatenation or averaging of heads for final output.
254+ """
251255
252256 def __init__ (self ,
253257 units : int ,
254258 num_heads : int ,
255259 activation : str = "kgcnn>leaky_relu2" ,
256260 use_bias : bool = True ,
257261 concat_heads : bool = True ,
262+ use_edge_features = False ,
263+ use_final_activation = True ,
264+ has_self_loops = True ,
265+ kernel_regularizer = None ,
266+ bias_regularizer = None ,
267+ activity_regularizer = None ,
268+ kernel_constraint = None ,
269+ bias_constraint = None ,
270+ kernel_initializer = 'glorot_uniform' ,
271+ bias_initializer = 'zeros' ,
272+ normalize_softmax : bool = False ,
258273 ** kwargs ):
259- super (MultiHeadGATV2Layer , self ).__init__ (
260- units = units ,
261- activation = activation ,
262- use_bias = use_bias ,
263- ** kwargs
264- )
274+ r"""Initialize layer.
275+
276+ Args:
277+ units (int): Units for the linear trafo of node features before attention.
278+ num_heads: Number of attention heads.
279+ concat_heads: Whether to concatenate heads or average.
280+ use_edge_features (bool): Append edge features to attention computation. Default is False.
281+ use_final_activation (bool): Whether to apply the final activation for the output.
282+ has_self_loops (bool): If the graph has self-loops. Not used here. Default is True.
283+ activation (str): Activation. Default is "kgcnn>leaky_relu2".
284+ use_bias (bool): Use bias. Default is True.
285+ kernel_regularizer: Kernel regularization. Default is None.
286+ bias_regularizer: Bias regularization. Default is None.
287+ activity_regularizer: Activity regularization. Default is None.
288+ kernel_constraint: Kernel constrains. Default is None.
289+ bias_constraint: Bias constrains. Default is None.
290+ kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
291+ bias_initializer: Initializer for bias. Default is 'zeros'.
292+ """
293+ super (MultiHeadGATV2Layer , self ).__init__ (** kwargs )
265294 # Changes in keras serialization behaviour for activations in 3.0.2.
266295 # Keep string at least for default. Also renames to prevent clashes with keras leaky_relu.
267296 if activation in ["kgcnn>leaky_relu" , "kgcnn>leaky_relu2" ]:
268297 activation = {"class_name" : "function" , "config" : "kgcnn>leaky_relu2" }
269298 self .num_heads = num_heads
270299 self .concat_heads = concat_heads
300+ self .use_edge_features = use_edge_features
301+ self .use_final_activation = use_final_activation
302+ self .has_self_loops = has_self_loops
303+ self .units = int (units )
304+ self .normalize_softmax = normalize_softmax
305+ self .use_bias = use_bias
306+ kernel_args = {"kernel_regularizer" : kernel_regularizer ,
307+ "activity_regularizer" : activity_regularizer , "bias_regularizer" : bias_regularizer ,
308+ "kernel_constraint" : kernel_constraint , "bias_constraint" : bias_constraint ,
309+ "kernel_initializer" : kernel_initializer , "bias_initializer" : bias_initializer }
271310
272311 self .head_layers = []
273312 for _ in range (num_heads ):
274- lay_linear = Dense (units , activation = activation , use_bias = use_bias )
275- lay_alpha_activation = Dense (units , activation = activation , use_bias = use_bias )
276- lay_alpha = Dense (1 , activation = 'linear' , use_bias = False )
313+ lay_linear = Dense (units , activation = activation , use_bias = use_bias , ** kernel_args )
314+ lay_alpha_activation = Dense (units , activation = activation , use_bias = use_bias , ** kernel_args )
315+ lay_alpha = Dense (1 , activation = 'linear' , use_bias = False , ** kernel_args )
277316
278317 self .head_layers .append ((lay_linear , lay_alpha_activation , lay_alpha ))
279318
280319 self .lay_concat_alphas = Concatenate (axis = - 2 )
281- self .lay_concat_embeddings = Concatenate (axis = - 2 )
282- self .lay_pool_attention = AggregateLocalEdgesAttention ()
283- # self.lay_pool = AggregateLocalEdges()
320+
321+ # self.lay_linear_trafo = Dense(units, activation="linear", use_bias=use_bias, **kernel_args)
322+ # self.lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
323+ # self.lay_alpha = Dense(1, activation="linear", use_bias=False, **kernel_args)
324+ self .lay_gather_in = GatherNodesIngoing ()
325+ self .lay_gather_out = GatherNodesOutgoing ()
326+ self .lay_concat = Concatenate (axis = - 1 )
327+ self .lay_pool_attention = AggregateLocalEdgesAttention (normalize_softmax = normalize_softmax )
328+ if self .use_final_activation :
329+ self .lay_final_activ = Activation (activation = activation )
284330
285331 if self .concat_heads :
286332 self .lay_combine_heads = Concatenate (axis = - 1 )
287333 else :
288334 self .lay_combine_heads = Average ()
289335
290- def __call__ (self , inputs , ** kwargs ):
336+ def build (self , input_shape ):
337+ """Build layer."""
338+ super (MultiHeadGATV2Layer , self ).build (input_shape )
339+
340+ def call (self , inputs , ** kwargs ):
291341 node , edge , edge_index = inputs
292342
293343 # "a_ij" is a single-channel edge attention logits tensor. "a_ijs" is consequently the list which
@@ -338,6 +388,16 @@ def __call__(self, inputs, **kwargs):
338388 def get_config (self ):
339389 """Update layer config."""
340390 config = super (MultiHeadGATV2Layer , self ).get_config ()
391+ config .update ({"use_edge_features" : self .use_edge_features , "use_bias" : self .use_bias ,
392+ "units" : self .units , "has_self_loops" : self .has_self_loops ,
393+ "normalize_softmax" : self .normalize_softmax ,
394+ "use_final_activation" : self .use_final_activation })
395+ if self .num_heads > 0 :
396+ conf_sub = self .head_layers [0 ][0 ].get_config ()
397+ for x in ["kernel_regularizer" , "activity_regularizer" , "bias_regularizer" , "kernel_constraint" ,
398+ "bias_constraint" , "kernel_initializer" , "bias_initializer" , "activation" ]:
399+ if x in conf_sub :
400+ config .update ({x : conf_sub [x ]})
341401 config .update ({
342402 'num_heads' : self .num_heads ,
343403 'concat_heads' : self .concat_heads
0 commit comments