@@ -672,30 +672,32 @@ def main():
 
     # Load or create new model
     if model_args.model_name_or_path:
-        model = DalleBart.from_pretrained(
+        model, params = DalleBart.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
-            abstract_init=True,  # we overwrite them with loaded checkpoint
+            _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
         )
     else:
         model = DalleBart(
             config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
-            abstract_init=True,
+            _do_init=False,
         )
+        params = None
+    params_shape = model.params_shape_tree
 
     # get model metadata
     model_metadata = model_args.get_metadata()
 
     # get PartitionSpec for model params (required to be a dict)
-    param_spec = set_partitions(model.params, model.config.use_scan)
-
-    # convert params to frozen dict
-    model._params = freeze(model.params)
+    param_spec = set_partitions(params_shape, model.config.use_scan)
+    params_shape = freeze(params_shape)
+    if params is not None:
+        params = freeze(params)
 
     # Load tokenizer
     tokenizer = DalleBartTokenizer.from_pretrained(
@@ -736,7 +738,7 @@ def main():
     num_train_steps = (
         steps_per_epoch * num_epochs if steps_per_epoch is not None else None
     )
-    num_params = model.num_params
+    num_params = model.num_params(params_shape)
 
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len_train_dataset}")
@@ -875,7 +877,7 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
 
         optimizer = {}
         opt_fn = {}
-        for k, p in split_params(model.params).items():
+        for k, p in split_params(params_shape).items():
             if "scanned" in k:
                 p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
             optimizer[k] = opt.init(p)
@@ -891,7 +893,7 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
         )
-        optimizer = {k: optimizer for k in split_params(model.params)}
+        optimizer = {k: optimizer for k in split_params(params_shape)}
 
     elif training_args.optim == "adafactor":
         # We use the default parameters here to initialize adafactor,
@@ -900,21 +902,21 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
             learning_rate=learning_rate_fn,
             clipping_threshold=training_args.max_grad_norm,
         )
-        optimizer = {k: optimizer for k in split_params(model.params)}
+        optimizer = {k: optimizer for k in split_params(params_shape)}
 
     # get PartitionSpec for optimizer state
     def get_opt_state_spec_and_shape():
         # get opt_state shape without actual init
         opt_state_shape = {}
-        for k, p in split_params(model.params).items():
+        for k, p in split_params(params_shape).items():
             if "scanned" not in k:
                 opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
             else:
                 opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p)
 
         if training_args.optim == "adafactor":
             # factorized state must be replicated (rank different than params)
-            opt_state_spec = {k: None for k in split_params(model.params)}
+            opt_state_spec = {k: None for k in split_params(params_shape)}
 
         elif training_args.optim in ["adam", "distributed_shampoo"]:
 
@@ -926,9 +928,9 @@ def _opt_state_spec_per_leaf(x, spec):
                     # other variables such as count
                     return None
 
-            split_spec = split_params(set_partitions(model.params, False))
+            split_spec = split_params(set_partitions(params_shape, False))
             opt_state_spec = {}
-            for k, p in split_params(model.params).items():
+            for k, p in split_params(params_shape).items():
                 if "scanned" in k:
                     p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
                 if training_args.optim == "adam":
@@ -982,12 +984,12 @@ def _opt_state_spec_per_leaf(x, spec):
 
     # init params if not available yet
     def maybe_init_params(params):
-        if model_args.model_name_or_path:
+        if params is not None:
             # model params are correctly loaded
             return params
         else:
             # params have not been initialized yet
-            return model.init_weights()
+            return model.init_weights(model.key, model.input_shape)
 
     with mesh:
         logger.info("  Creating state")
@@ -1008,7 +1010,7 @@ def init_state(params):
                 else None,
                 out_axis_resources=state_spec,
                 donate_argnums=(0,),
-            )(model.params if model_args.model_name_or_path else None)
+            )(params)
 
         else:
             # load opt_state
@@ -1038,13 +1040,13 @@ def restore_state(params, opt_state):
                 ),
                 out_axis_resources=state_spec,
                 donate_argnums=(0, 1),
-            )(model.params, opt_state)
+            )(params, opt_state)
 
             # remove opt_state from CPU
             del opt_state
 
     # free CPU memory
-    del params, opt_state_spec, opt_state_shape
+    del params, opt_state_spec, opt_state_shape
 
     # define batch specs
     batch_spec = PartitionSpec("dp")
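
A minimal, self-contained sketch of the deferred-initialization pattern this diff adopts. It is not the dalle-mini code: it assumes a recent `transformers` release whose Flax models support `_do_init=False`, and `FlaxBertModel` plus a toy `optax.adamw` stand in for `DalleBart` and the script's real optimizer setup.

```python
# Sketch only: illustrates _do_init=False / params_shape_tree, not the dalle-mini code.
import jax
import numpy as np
import optax
from flax.core.frozen_dict import freeze
from transformers import FlaxBertModel  # stand-in for DalleBart

# With _do_init=False, from_pretrained returns (model, params): checkpoint
# weights stay on the host and are not stored on the model object.
model, params = FlaxBertModel.from_pretrained("bert-base-uncased", _do_init=False)

# The shape tree describes every parameter without allocating device memory;
# it is enough to build partition specs and count parameters.
params_shape = freeze(model.params_shape_tree)
num_params = sum(int(np.prod(leaf.shape)) for leaf in jax.tree_util.tree_leaves(params_shape))
print(f"parameters: {num_params:,}")

# Optimizer state shapes can likewise be derived abstractly with jax.eval_shape,
# mirroring get_opt_state_spec_and_shape() in the diff above.
opt = optax.adamw(learning_rate=1e-4)
opt_state_shape = jax.eval_shape(opt.init, params_shape)

# When training from scratch (params is None), weights are only materialized
# on demand, as maybe_init_params() does in the diff.
if params is None:
    params = model.init_weights(model.key, model.input_shape)
```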