Skip to content

Commit 57456a0

Browse files
Kyle Taylor and alimuldal
authored and committed
Minor changes to match trained SOTA model.
PiperOrigin-RevId: 450716647
1 parent d436681 commit 57456a0

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

enformer/attention_module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,8 @@ def positional_features_gamma(positions: tf.Tensor,
453453
tf.abs(tf.cast(positions, dtype=tf.float32))[..., tf.newaxis],
454454
concentration, rate)
455455
probabilities += 1e-8 # To ensure numerical stability.
456-
outputs = probabilities / tf.reduce_max(probabilities)
456+
outputs = probabilities / tf.reduce_max(probabilities,
457+
axis=1, keepdims=True)
457458
tf.TensorShape(outputs.shape).assert_is_compatible_with(
458459
positions.shape + [feature_size])
459460
return outputs

enformer/enformer.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,12 @@ def __init__(self,
8787
# lambda is used in Sequential to construct the module under tf.name_scope.
8888
def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
8989
return Sequential(lambda: [
90-
snt.BatchNorm(create_scale=True,
91-
create_offset=True,
92-
decay_rate=0.9,
93-
scale_init=snt.initializers.Ones()),
90+
snt.distribute.CrossReplicaBatchNorm(
91+
create_scale=True,
92+
create_offset=True,
93+
scale_init=snt.initializers.Ones(),
94+
moving_mean=snt.ExponentialMovingAverage(0.9),
95+
moving_variance=snt.ExponentialMovingAverage(0.9)),
9496
gelu,
9597
snt.Conv1D(filters, width, w_init=w_init, **kwargs)
9698
], name=name)
@@ -184,16 +186,22 @@ def predict_on_batch(self, x):
184186
class TargetLengthCrop1D(snt.Module):
185187
"""Crop sequence to match the desired target length."""
186188

187-
def __init__(self, target_length: int, name='target_length_crop'):
189+
def __init__(self,
190+
target_length: Optional[int],
191+
name: str = 'target_length_crop'):
188192
super().__init__(name=name)
189193
self._target_length = target_length
190194

191195
def __call__(self, inputs):
196+
if self._target_length is None:
197+
return inputs
192198
trim = (inputs.shape[-2] - self._target_length) // 2
193199
if trim < 0:
194200
raise ValueError('inputs longer than target length')
195-
196-
return inputs[..., trim:-trim, :]
201+
elif trim == 0:
202+
return inputs
203+
else:
204+
return inputs[..., trim:-trim, :]
197205

198206

199207
class Sequential(snt.Module):
@@ -209,8 +217,7 @@ def __init__(self,
209217
else:
210218
# layers wrapped in a lambda function to have a common namespace.
211219
if hasattr(layers, '__call__'):
212-
with tf.name_scope(name):
213-
layers = layers()
220+
layers = layers()
214221
self._layers = [layer for layer in layers if layer is not None]
215222

216223
def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):

enformer/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ dm-sonnet==2.0.0
22
kipoiseq==0.5.2
33
numpy==1.19.5
44
pandas==1.2.3
5-
tensorflow==2.4.1
5+
tensorflow==2.5.0
66
tensorflow-hub==0.11.0

0 commit comments

Comments (0)