File tree Expand file tree Collapse file tree 2 files changed +5
-33
lines changed
Expand file tree Collapse file tree 2 files changed +5
-33
lines changed Load Diff This file was deleted.
Original file line number Diff line number Diff line change @@ -536,6 +536,8 @@ def split_params(data):
536536 split ["scanned_decoder" ][k ] = v
537537 else :
538538 split ["standard" ][k ] = v
539+ # remove empty keys
540+ split = {k : v for k , v in split .items () if v }
539541 for k , v in split .items ():
540542 split [k ] = freeze (traverse_util .unflatten_dict (v ))
541543 return split
@@ -544,7 +546,8 @@ def split_params(data):
544546def unsplit_params (data ):
545547 flat = {}
546548 for k in ["standard" , "scanned_encoder" , "scanned_decoder" ]:
547- flat .update (traverse_util .flatten_dict (unfreeze (data [k ])))
549+ if k in data :
550+ flat .update (traverse_util .flatten_dict (unfreeze (data [k ])))
548551 return freeze (traverse_util .unflatten_dict (flat ))
549552
550553
@@ -1483,7 +1486,7 @@ def run_save_model(state, eval_metrics=None):
14831486 logger .info (" Ready to start training" )
14841487 with mesh :
14851488 for epoch in epochs :
1486- state .replace (epoch = epoch )
1489+ state = state .replace (epoch = epoch )
14871490 local_state ["epoch" ] = epoch
14881491 # ======================== Training ================================
14891492 metrics_logger .update_state_metrics (local_state )
You can’t perform that action at this time.
0 commit comments