Commit 6b84155

feat(train): use new HF _do_init api

1 parent: f3a8cbb

2 files changed: +26 / -23 lines
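In a nutshell, the commit moves from the repository's old abstract_init flag to the _do_init argument of the Hugging Face Flax API: from_pretrained now returns (model, params) as separate objects, and a freshly constructed model only exposes an abstract shape tree until init_weights is called. A minimal sketch of the two loading paths (the import path and checkpoint id are illustrative assumptions, not taken from this commit):

    import jax.numpy as jnp
    from dalle_mini.model import DalleBart, DalleBartConfig  # assumed import path

    # 1) loading a checkpoint: weights come back as a separate pytree
    model, params = DalleBart.from_pretrained(
        "some-org/some-dalle-checkpoint",  # placeholder model id
        dtype=jnp.float32,
        _do_init=False,
    )

    # 2) fresh model: nothing is materialized yet
    model = DalleBart(DalleBartConfig(), _do_init=False)
    params = None
    params_shape = model.params_shape_tree  # ShapeDtypeStruct leaves, no device memory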

src/dalle_mini/model/modeling.py (4 additions, 3 deletions)

@@ -1330,10 +1330,11 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
 
     config_class = DalleBartConfig
 
-    @property
-    def num_params(self):
+    def num_params(self, params=None):
+        if params is None:
+            params = self.params
         num_params = jax.tree_map(
-            lambda param: param.size, flatten_dict(unfreeze(self.params))
+            lambda param: param.size, flatten_dict(unfreeze(params))
         ).values()
         return sum(list(num_params))
 
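Since _do_init=False can leave the model without materialized weights, num_params now takes an explicit pytree, typically model.params_shape_tree, instead of reading self.params. A self-contained sketch of the same counting logic (it relies, as the repository does, on the shape-tree leaves exposing .size):

    import jax
    from flax.core.frozen_dict import unfreeze
    from flax.traverse_util import flatten_dict

    def count_params(params):
        # works for real weight arrays and for ShapeDtypeStruct leaves alike,
        # since both expose a .size attribute
        sizes = jax.tree_map(lambda p: p.size, flatten_dict(unfreeze(params))).values()
        return sum(sizes)

    # count_params(model.params_shape_tree)  # counts parameters without allocating them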

tools/train/train.py (22 additions, 20 deletions)
@@ -672,30 +672,32 @@ def main():
 
     # Load or create new model
     if model_args.model_name_or_path:
-        model = DalleBart.from_pretrained(
+        model, params = DalleBart.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
-            abstract_init=True,  # we overwrite them with loaded checkpoint
+            _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
         )
     else:
         model = DalleBart(
             config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
-            abstract_init=True,
+            _do_init=False,
         )
+        params = None
+    params_shape = model.params_shape_tree
 
     # get model metadata
     model_metadata = model_args.get_metadata()
 
     # get PartitionSpec for model params (required to be a dict)
-    param_spec = set_partitions(model.params, model.config.use_scan)
-
-    # convert params to frozen dict
-    model._params = freeze(model.params)
+    param_spec = set_partitions(params_shape, model.config.use_scan)
+    params_shape = freeze(params_shape)
+    if params is not None:
+        params = freeze(params)
 
     # Load tokenizer
     tokenizer = DalleBartTokenizer.from_pretrained(
@@ -736,7 +738,7 @@ def main():
     num_train_steps = (
         steps_per_epoch * num_epochs if steps_per_epoch is not None else None
     )
-    num_params = model.num_params
+    num_params = model.num_params(params_shape)
 
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len_train_dataset}")

@@ -875,7 +877,7 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
 
         optimizer = {}
         opt_fn = {}
-        for k, p in split_params(model.params).items():
+        for k, p in split_params(params_shape).items():
             if "scanned" in k:
                 p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
             optimizer[k] = opt.init(p)

@@ -891,7 +893,7 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
         )
-        optimizer = {k: optimizer for k in split_params(model.params)}
+        optimizer = {k: optimizer for k in split_params(params_shape)}
 
     elif training_args.optim == "adafactor":
         # We use the default parameters here to initialize adafactor,
@@ -900,21 +902,21 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
             learning_rate=learning_rate_fn,
             clipping_threshold=training_args.max_grad_norm,
         )
-        optimizer = {k: optimizer for k in split_params(model.params)}
+        optimizer = {k: optimizer for k in split_params(params_shape)}
 
     # get PartitionSpec for optimizer state
     def get_opt_state_spec_and_shape():
         # get opt_state shape without actual init
         opt_state_shape = {}
-        for k, p in split_params(model.params).items():
+        for k, p in split_params(params_shape).items():
             if "scanned" not in k:
                 opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
             else:
                 opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p)
 
         if training_args.optim == "adafactor":
             # factorized state must be replicated (rank different than params)
-            opt_state_spec = {k: None for k in split_params(model.params)}
+            opt_state_spec = {k: None for k in split_params(params_shape)}
 
         elif training_args.optim in ["adam", "distributed_shampoo"]:
 
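What makes the shape-only tree sufficient here is jax.eval_shape: it traces optimizer.init abstractly, so the optimizer-state structure is known without allocating either the parameters or the state. A standalone sketch, using optax.adamw as a stand-in for the optimizers configured above:

    import jax
    import jax.numpy as jnp
    import optax

    opt = optax.adamw(learning_rate=1e-4)
    # abstract "parameters": shapes and dtypes only, no device memory
    fake_params = {"w": jax.ShapeDtypeStruct((1024, 1024), jnp.float32)}

    opt_state_shape = jax.eval_shape(opt.init, fake_params)
    # opt_state_shape mirrors the real optax state pytree, with ShapeDtypeStruct leaves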

@@ -926,9 +928,9 @@ def _opt_state_spec_per_leaf(x, spec):
                     # other variables such as count
                     return None
 
-            split_spec = split_params(set_partitions(model.params, False))
+            split_spec = split_params(set_partitions(params_shape, False))
             opt_state_spec = {}
-            for k, p in split_params(model.params).items():
+            for k, p in split_params(params_shape).items():
                 if "scanned" in k:
                     p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
                 if training_args.optim == "adam":

@@ -982,12 +984,12 @@ def _opt_state_spec_per_leaf(x, spec):
 
     # init params if not available yet
    def maybe_init_params(params):
-        if model_args.model_name_or_path:
+        if params is not None:
             # model params are correctly loaded
             return params
         else:
             # params have not been initialized yet
-            return model.init_weights()
+            return model.init_weights(model.key, model.input_shape)
 
     with mesh:
         logger.info(" Creating state")
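Because _do_init=False skips initialization at construction time, maybe_init_params must call init_weights explicitly with the model's stored RNG key and input shape. A hedged sketch of how the lazy shape tree and the concrete initializer relate (it reuses the model object from the hunks above and assumes the init_weights(rng, input_shape) signature of the Flax base class):

    from functools import partial
    import jax

    # shapes only, no memory: essentially how params_shape_tree is derived
    abstract_params = jax.eval_shape(
        partial(model.init_weights, input_shape=model.input_shape), model.key
    )

    # real arrays, allocated only when no checkpoint params were loaded
    concrete_params = model.init_weights(model.key, model.input_shape)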
@@ -1008,7 +1010,7 @@ def init_state(params):
                 else None,
                 out_axis_resources=state_spec,
                 donate_argnums=(0,),
-            )(model.params if model_args.model_name_or_path else None)
+            )(params)
 
         else:
             # load opt_state

@@ -1038,13 +1040,13 @@ def restore_state(params, opt_state):
                 ),
                 out_axis_resources=state_spec,
                 donate_argnums=(0, 1),
-            )(model.params, opt_state)
+            )(params, opt_state)
 
         # remove opt_state from CPU
         del opt_state
 
     # free CPU memory
-    del model._params, opt_state_spec, opt_state_shape
+    del params, opt_state_spec, opt_state_shape
 
     # define batch specs
     batch_spec = PartitionSpec("dp")
