Skip to content

Commit 73811db

Browse files
authored
use dim() to simplify gc
thanks @nestordemeure for pointing out code simplification
1 parent f97c298 commit 73811db

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

ranger/ranger.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,12 @@ def step(self, closure=None):
130130
#begin computations
131131
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
132132
beta1, beta2 = group['betas']
133-
134-
135-
136-
#gc work #3 = conv only, 1 = conv +fc
137-
if len(list(grad.size()))>self.gc_gradient_threshold:
138-
grad.add_(-grad.mean(dim = tuple(range(1,len(list(grad.size())))), keepdim = True))
133+
139134

140-
#GC operation for Conv layers and FC layers
141-
if len(list(grad.size()))>self.gc_gradient_threshold:
142-
grad.add_(-grad.mean(dim = tuple(range(1,len(list(grad.size())))), keepdim = True))
143-
135+
#GC operation for Conv layers and FC layers
136+
if grad.dim() > self.gc_gradient_threshold:
137+
grad.add_(-grad.mean(dim = tuple(range(1,grad.dim())), keepdim = True))
138+
144139

145140

146141
state['step'] += 1

0 commit comments

Comments
 (0)