File tree Expand file tree Collapse file tree 1 file changed +5
-10
lines changed
Expand file tree Collapse file tree 1 file changed +5
-10
lines changed Original file line number Diff line number Diff line change @@ -130,17 +130,12 @@ def step(self, closure=None):
130130 #begin computations
131131 exp_avg , exp_avg_sq = state ['exp_avg' ], state ['exp_avg_sq' ]
132132 beta1 , beta2 = group ['betas' ]
133-
134-
135-
136- #gc work #3 = conv only, 1 = conv +fc
137- if len (list (grad .size ()))> self .gc_gradient_threshold :
138- grad .add_ (- grad .mean (dim = tuple (range (1 ,len (list (grad .size ())))), keepdim = True ))
133+
139134
140- #GC operation for Conv layers and FC layers
141- if len ( list ( grad .size ())) > self .gc_gradient_threshold :
142- grad .add_ (- grad .mean (dim = tuple (range (1 ,len ( list ( grad .size ()) ))), keepdim = True ))
143-
135+ #GC operation for Conv layers and FC layers
136+ if grad .dim () > self .gc_gradient_threshold :
137+ grad .add_ (- grad .mean (dim = tuple (range (1 ,grad .dim ( ))), keepdim = True ))
138+
144139
145140
146141 state ['step' ] += 1
You can’t perform that action at this time.
0 commit comments