Skip to content

Commit 7891c0a

Browse files
lingvo-botcopybara-github
authored andcommitted
Remove unused model_p.decoder_inputs_split_mapping and model_p.decoder_states_split_mapping.
PiperOrigin-RevId: 415406130
1 parent 483d896 commit 7891c0a

File tree

2 files changed

+0
-25
lines changed

2 files changed

+0
-25
lines changed

lingvo/jax/model.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,22 +132,6 @@ def Params(cls) -> InstantiableParams:
132132
'all the inputs are replicated. For sharding inputs, this is a '
133133
'`NestedMap` with keys `map_1d`, `map_2d`, ..., etc.,'
134134
'which specifies how to shard the inputs of that dimension.')
135-
tp.Define(
136-
'decoder_inputs_split_mapping', None, 'The PartitionSpec for decoder'
137-
'inputs such as partially completed sequence. This is only relevant'
138-
'for SPMD sharded models. By default it is None, which means all the'
139-
'inputs are replicated. For sharding the decoder inputs, this is a '
140-
'`NestedMap` with keys `map_1d`, `map_2d` ..., etc., which specifies'
141-
'how to shard the decoder inputs corresponding to that dimension.')
142-
tp.Define(
143-
'decoder_states_split_mapping', None, 'The PartitionSpec for cached'
144-
'decoder states such as keys, values, steps etc. This is only relevant'
145-
'for SPMD sharded models. By default it is None, which means all the'
146-
'inputs are replicated. For sharding the decoder states, this is a '
147-
'`NestedMap` with keys `map_1d`, `map_2d` ..., etc., which specifies'
148-
'how to shard the decoder states corresponding to that dimension.')
149-
150-
# TODO(yonghui): Add other hyper-params.
151135
return p
152136

153137
def __init__(self, params: InstantiableParams) -> None:

lingvo/jax/tasks/lm/model_params.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,6 @@ def set_sharding_annotations_v1(model_p: InstantiableParams,
5151
model_p.train.inputs_split_mapping = NestedMap(
5252
map_1d=((replica_axis, data_axis),),
5353
map_2d=((replica_axis, data_axis), None))
54-
model_p.train.decoder_inputs_split_mapping = NestedMap(
55-
map_1d=((replica_axis, data_axis),))
56-
model_p.train.decoder_states_split_mapping = NestedMap(
57-
map_0d=None,
58-
map_4d=(None, (replica_axis, data_axis), mdl_axis, None),
59-
# 5d inputs are for the decoder states of shape [layers, seq_len,
60-
# batch_size, num_heads, dims_per_head]
61-
map_5d=(None, None, (replica_axis, data_axis), mdl_axis, None),
62-
)
6354
model_p.mesh_axis_names = mesh_axis_names
6455
model_p.lm = model_p.lm.cls.set_sharding_params_v1(
6556
model_p.lm,

0 commit comments

Comments
 (0)