@@ -87,10 +87,12 @@ def __init__(self,
     # lambda is used in Sequential to construct the module under tf.name_scope.
     def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
       return Sequential(lambda: [
-          snt.BatchNorm(create_scale=True,
-                        create_offset=True,
-                        decay_rate=0.9,
-                        scale_init=snt.initializers.Ones()),
+          snt.distribute.CrossReplicaBatchNorm(
+              create_scale=True,
+              create_offset=True,
+              scale_init=snt.initializers.Ones(),
+              moving_mean=snt.ExponentialMovingAverage(0.9),
+              moving_variance=snt.ExponentialMovingAverage(0.9)),
           gelu,
           snt.Conv1D(filters, width, w_init=w_init, **kwargs)
       ], name=name)
@@ -184,16 +186,22 @@ def predict_on_batch(self, x):
 class TargetLengthCrop1D(snt.Module):
   """Crop sequence to match the desired target length."""
 
-  def __init__(self, target_length: int, name='target_length_crop'):
+  def __init__(self,
+               target_length: Optional[int],
+               name: str = 'target_length_crop'):
     super().__init__(name=name)
     self._target_length = target_length
 
   def __call__(self, inputs):
+    if self._target_length is None:
+      return inputs
     trim = (inputs.shape[-2] - self._target_length) // 2
     if trim < 0:
       raise ValueError('inputs longer than target length')
-
-    return inputs[..., trim:-trim, :]
+    elif trim == 0:
+      return inputs
+    else:
+      return inputs[..., trim:-trim, :]
 
 
 class Sequential(snt.Module):
@@ -209,8 +217,7 @@ def __init__(self,
     else:
       # layers wrapped in a lambda function to have a common namespace.
       if hasattr(layers, '__call__'):
-        with tf.name_scope(name):
-          layers = layers()
+        layers = layers()
       self._layers = [layer for layer in layers if layer is not None]
 
   def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):