-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGAT_Layer.py
More file actions
301 lines (252 loc) · 10.9 KB
/
GAT_Layer.py
File metadata and controls
301 lines (252 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import os; os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
import jax
import jax.numpy as jnp
from jax import random
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import jax.tree_util as jtu
def elu(x):
return jax.nn.elu(x)
def leaky_relu(x, negative_slope=0.2):
return jax.nn.leaky_relu(x, negative_slope)
def layer_norm(x, eps=1e-5):
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.var(x, axis=-1, keepdims=True)
return (x - mean) / jnp.sqrt(var + eps)
# 为 Linear 注册 PyTree 接口
@jtu.register_pytree_node_class
@dataclass
class Linear:
"""
- x: (..., in_features)
- W: (in_features, out_features)
- b: (out_features,)
"""
W: jnp.ndarray
b: Optional[jnp.ndarray] = None
@staticmethod
def init(key, in_features, out_features, bias=True, scale=1.0):
# Xavier uniform
lim = jnp.sqrt(6.0 / (in_features + out_features)) * scale
kW, kb = random.split(key)
W = random.uniform(kW, (in_features, out_features), minval=-lim, maxval=lim)
b_arr = random.uniform(kb, (out_features,), minval=-lim, maxval=lim) if bias else None
if bias:
pass
return Linear(W=W, b=b_arr)
def __call__(self, x):
y = x @ self.W
if self.b is not None:
y = y + self.b
return y
def tree_flatten(self):
# 将 Linear 对象展平为子节点和辅助数据
if self.b is not None:
children = (self.W, self.b)
else:
children = (self.W,)
aux_data = (self.b is not None,) # 存储是否有偏置的信息
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
has_bias = aux_data[0]
if has_bias:
W, b = children
return cls(W=W, b=b)
else:
W = children[0]
return cls(W=W, b=None)
def dropout(rng, x, rate: float, deterministic: bool):
if deterministic or rate == 0.0:
return x
keep_prob = 1.0 - rate
mask = random.bernoulli(rng, p=keep_prob, shape=x.shape)
return jnp.where(mask, x / keep_prob, 0.0)
# 为 GATLayerBaseParams 注册 PyTree 接口
@jtu.register_pytree_node_class
@dataclass
class GATLayerBaseParams:
head_projections: Tuple[Linear, ...] # len = num_heads
scoring_fn_target: jnp.ndarray # (num_heads, out_features, 1)
scoring_fn_source: jnp.ndarray # (num_heads, out_features, 1)
edge_distance_proj: Linear # in=1 out=num_heads, bias=False
edge_bond_proj: Linear # in=1 out=num_heads, bias=False
skip_proj: Optional[Linear] # in=in_features out=num_heads*out_features, bias=False
bias: Optional[jnp.ndarray] # (num_heads*out_features,) if concat else (out_features,)
def tree_flatten(self):
# 将所有属性作为子节点
children = (
self.head_projections,
self.scoring_fn_target,
self.scoring_fn_source,
self.edge_distance_proj,
self.edge_bond_proj,
self.skip_proj,
self.bias,
)
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
(
head_projections,
scoring_fn_target,
scoring_fn_source,
edge_distance_proj,
edge_bond_proj,
skip_proj,
bias,
) = children
return cls(
head_projections=head_projections,
scoring_fn_target=scoring_fn_target,
scoring_fn_source=scoring_fn_source,
edge_distance_proj=edge_distance_proj,
edge_bond_proj=edge_bond_proj,
skip_proj=skip_proj,
bias=bias,
)
# 注意:GATLayerConfig 不需要 PyTree 注册,因为它只包含配置参数,不包含可训练参数
@dataclass
class GATLayerConfig:
num_in_features: int
num_out_features: int
num_heads: int = 1
concat: bool = False
add_skip_connection: bool = True
bias: bool = True
dropout_prob: float = 0.6
activation: Optional[Callable] = elu
log_attention_weights: bool = True
esp: float = 1e-6
def init_gat_layer_params(key, cfg: GATLayerConfig) -> GATLayerBaseParams:
k = key
keys = random.split(k, 3 + cfg.num_heads + (1 if cfg.add_skip_connection else 0) + (1 if cfg.bias else 0))
# head projections (bias=False)
head_proj_keys = keys[:cfg.num_heads]
head_projections = tuple(
Linear.init(hk, cfg.num_in_features, cfg.num_out_features, bias=False)
for hk in head_proj_keys
)
idx = cfg.num_heads
# scoring tensors: Xavier-ish init
k_t = keys[idx]; idx += 1
k_s = keys[idx]; idx += 1
lim = jnp.sqrt(6.0 / (cfg.num_out_features + 1.0))
scoring_fn_target = random.uniform(k_t, (cfg.num_heads, cfg.num_out_features, 1), minval=-lim, maxval=lim)
scoring_fn_source = random.uniform(k_s, (cfg.num_heads, cfg.num_out_features, 1), minval=-lim, maxval=lim)
# edge proj: Linear(1 -> num_heads, bias=False)
k_ed = keys[idx]; idx += 1
k_eb = keys[idx]; idx += 1
edge_distance_proj = Linear.init(k_ed, 1, cfg.num_heads, bias=False)
edge_bond_proj = Linear.init(k_eb, 1, cfg.num_heads, bias=False)
# bias
b_arr = None
if cfg.bias:
kb = keys[idx]; idx += 1
if cfg.concat:
b_arr = jnp.zeros((cfg.num_heads * cfg.num_out_features,), dtype=jnp.float32)
else:
b_arr = jnp.zeros((cfg.num_out_features,), dtype=jnp.float32)
# skip proj
skip_proj = None
if cfg.add_skip_connection:
ks = keys[idx]; idx += 1
skip_proj = Linear.init(ks, cfg.num_in_features, cfg.num_heads * cfg.num_out_features, bias=False)
return GATLayerBaseParams(
head_projections=head_projections,
scoring_fn_target=scoring_fn_target,
scoring_fn_source=scoring_fn_source,
edge_distance_proj=edge_distance_proj,
edge_bond_proj=edge_bond_proj,
skip_proj=skip_proj,
bias=b_arr,
)
def skip_concat_bias(
cfg: GATLayerConfig,
params: GATLayerBaseParams,
attention_coefficients: jnp.ndarray, # (H, N, N)
in_nodes_features: jnp.ndarray, # (N, Fin)
out_nodes_features: jnp.ndarray, # (H, N, Fout) or reshaped equivalent
):
if out_nodes_features.ndim != 3 or out_nodes_features.shape[0] != cfg.num_heads:
out_nodes_features = out_nodes_features.reshape(cfg.num_heads, -1, cfg.num_out_features)
# skip connection
if cfg.add_skip_connection:
if cfg.num_out_features == cfg.num_in_features:
out_nodes_features = out_nodes_features + in_nodes_features[None, :, :]
else:
assert params.skip_proj is not None
skip = params.skip_proj(in_nodes_features) # (N, H*Fout)
skip = skip.reshape(-1, cfg.num_heads, cfg.num_out_features).transpose(1, 0, 2) # (H, N, Fout)
out_nodes_features = out_nodes_features + skip
# concat or average heads
if cfg.concat:
out = out_nodes_features.transpose(1, 0, 2).reshape(-1, cfg.num_heads * cfg.num_out_features) # (N, H*Fout)
else:
out = jnp.mean(out_nodes_features, axis=0) # (N, Fout)
# bias
if params.bias is not None:
out = out + params.bias
# activation
if cfg.activation is None:
return out
return cfg.activation(out)
def gat_forward(
rng,
cfg: GATLayerConfig,
params: GATLayerBaseParams,
nodes_features: jnp.ndarray, # (N, Fin)
degree_matrix: jnp.ndarray, # (N, N)
edges_features_distance: jnp.ndarray, # (N, N) or (N, N, 1)
edges_features_bond: jnp.ndarray, # (N, N) or (N, N, 1)
deterministic: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]]:
N = nodes_features.shape[0]
connectivity_mask = jnp.where(degree_matrix > 0, 0.0, -1e9)
assert connectivity_mask.shape == (N, N)
head_feats = []
for h in range(cfg.num_heads):
hf = params.head_projections[h](nodes_features)
head_feats.append(hf)
nodes_features_proj = jnp.stack(head_feats, axis=0)
# dropout on projected node features
rng, kdrop = random.split(rng)
nodes_features_proj = dropout(kdrop, nodes_features_proj, rate=cfg.dropout_prob, deterministic=deterministic)
# scores_source/target: bmm => (H, N, 1)
scores_source = jnp.matmul(nodes_features_proj, params.scoring_fn_source) # (H, N, 1)
scores_target = jnp.matmul(nodes_features_proj, params.scoring_fn_target) # (H, N, 1)
# all_scores: (H, N, N)
all_scores = leaky_relu(scores_source + jnp.swapaxes(scores_target, 1, 2), 0.2)
ed = edges_features_distance
if ed.ndim == 3 and ed.shape[-1] == 1:
ed = ed[..., 0]
ed_in = (-ed).reshape(-1, 1) # (N*N,1)
ed_out = params.edge_distance_proj(ed_in) # (N*N,H)
edge_distance_contribution = ed_out.reshape(N, N, cfg.num_heads).transpose(2, 0, 1) # (H,N,N)
all_scores = all_scores + edge_distance_contribution
# edge bond contribution
eb = edges_features_bond
if eb.ndim == 3 and eb.shape[-1] == 1:
eb = eb[..., 0]
eb_in = eb.reshape(-1, 1) # (N*N,1)
eb_out = params.edge_bond_proj(eb_in) # (N*N,H)
edge_bond_contribution = eb_out.reshape(N, N, cfg.num_heads).transpose(2, 0, 1) # (H,N,N)
all_scores = all_scores + edge_bond_contribution
masked_scores = all_scores + connectivity_mask[None, :, :]
max_vals = jnp.max(masked_scores, axis=-1, keepdims=True)
stable_scores = masked_scores - max_vals
all_attention_coefficients = jax.nn.softmax(stable_scores, axis=-1) # (H,N,N)
out_nodes_features = jnp.matmul(all_attention_coefficients, nodes_features_proj) # (H,N,Fout)
updated_nodes_features = skip_concat_bias(
cfg, params, all_attention_coefficients, nodes_features, out_nodes_features
)
unf = jax.lax.stop_gradient(updated_nodes_features)
node_similarity = jnp.matmul(unf, unf.T) # (N,N)
distance_decay = -ed # (N,N)
updated_connectivity_mask = jax.nn.sigmoid(node_similarity) * distance_decay * (degree_matrix > 0).astype(jnp.float32)
updated_connectivity_mask = updated_connectivity_mask + cfg.esp
updated_connectivity_mask = layer_norm(updated_connectivity_mask, eps=1e-5)
updated_connectivity_mask = updated_connectivity_mask + updated_connectivity_mask.T
attention_weights = all_attention_coefficients if cfg.log_attention_weights else None
return updated_nodes_features, updated_connectivity_mask, attention_weights