@@ -132,22 +132,6 @@ def Params(cls) -> InstantiableParams:
132132 'all the inputs are replicated. For sharding inputs, this is a '
133133 '`NestedMap` with keys `map_1d`, `map_2d`, ..., etc.,'
134134 'which specifies how to shard the inputs of that dimension.' )
135- tp .Define (
136- 'decoder_inputs_split_mapping' , None , 'The PartitionSpec for decoder'
137- 'inputs such as partially completed sequence. This is only relevant'
138- 'for SPMD sharded models. By default it is None, which means all the'
139- 'inputs are replicated. For sharding the decoder inputs, this is a '
140- '`NestedMap` with keys `map_1d`, `map_2d` ..., etc., which specifies'
141- 'how to shard the decoder inputs corresponding to that dimension.' )
142- tp .Define (
143- 'decoder_states_split_mapping' , None , 'The PartitionSpec for cached'
144- 'decoder states such as keys, values, steps etc. This is only relevant'
145- 'for SPMD sharded models. By default it is None, which means all the'
146- 'inputs are replicated. For sharding the decoder states, this is a '
147- '`NestedMap` with keys `map_1d`, `map_2d` ..., etc., which specifies'
148- 'how to shard the decoder states corresponding to that dimension.' )
149-
150- # TODO(yonghui): Add other hyper-params.
151135 return p
152136
153137 def __init__ (self , params : InstantiableParams ) -> None :
0 commit comments