@@ -89,17 +89,29 @@ def nmf_gradient(structures, oh):
     """
     choice = K.argmax(oh, axis=-1)
    prob = K.softmax(K.real(structures), axis=-1)
-    indices = K.transpose(
-        K.stack([K.cast(torch.arange(structures.shape[0]), "int64"), choice])
-    )
-    prob = torch.gather(prob, 0, indices.unsqueeze(0)).squeeze(0)
-    prob = K.reshape(prob, [-1, 1])
-    prob = K.tile(prob, [1, structures.shape[-1]])
+
+    # In vmap, structures shape is [nlayers * n, 7], oh shape is [nlayers * n, 7]
+    # choice shape is [nlayers * n]
+    # prob shape is [nlayers * n, 7]
+
+    # Create indices for gathering
+    seq_len = structures.shape[0]
+    seq_indices = torch.arange(seq_len, device=structures.device)
+    indices = torch.stack([seq_indices, choice], dim=-1)
+
+    # Gather the selected probabilities
+    prob_gathered = torch.gather(prob, 1, choice.unsqueeze(-1)).squeeze(-1)
+
+    prob_gathered = K.reshape(prob_gathered, [-1, 1])
+    prob_gathered = K.tile(prob_gathered, [1, structures.shape[-1]])
 
     # warning pytorch might be unable to do this exactly
-    result = torch.zeros_like(structures, dtype=ctype)
-    result.scatter_add_(0, indices, torch.ones([structures.shape[0]], dtype=ctype))
-    return K.real(result - prob)
+    result = torch.zeros_like(structures, dtype=torch.complex64)
+
+    # Use scatter_add_ to accumulate gradients
+    result.scatter_add_(1, choice.unsqueeze(-1), torch.ones_like(choice, dtype=torch.complex64).unsqueeze(-1))
+
+    return K.real(result - prob_gathered)
 
 
 # warning pytorch might be unable to do this exactly
@@ -182,9 +194,9 @@ def main(stddev=0.05, lr=None, epochs=2000, debug_step=50, batch=256, verbose=Fa
 
 
 if __name__ == "__main__":
-    tries = 5
+    tries = 1  # reduce the number of tries
     rs = []
     for _ in range(tries):
-        ee, _, _ = main()
+        ee, _, _ = main(epochs=10, batch=32)  # reduce epochs and batch size
        rs.append(-K.numpy(ee))
    print(np.min(rs))