Skip to content

Commit 8ae9176

Browse files
authored
fix: allow non-scanned models (borisdayma#168)
1 parent 3500e67 commit 8ae9176

File tree

2 files changed

+5
-33
lines changed

2 files changed

+5
-33
lines changed

tools/train/config/medium/config.json

Lines changed: 0 additions & 31 deletions
This file was deleted.

tools/train/train.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -536,6 +536,8 @@ def split_params(data):
536536
split["scanned_decoder"][k] = v
537537
else:
538538
split["standard"][k] = v
539+
# remove empty keys
540+
split = {k: v for k, v in split.items() if v}
539541
for k, v in split.items():
540542
split[k] = freeze(traverse_util.unflatten_dict(v))
541543
return split
@@ -544,7 +546,8 @@ def split_params(data):
544546
def unsplit_params(data):
545547
flat = {}
546548
for k in ["standard", "scanned_encoder", "scanned_decoder"]:
547-
flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
549+
if k in data:
550+
flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
548551
return freeze(traverse_util.unflatten_dict(flat))
549552

550553

@@ -1483,7 +1486,7 @@ def run_save_model(state, eval_metrics=None):
14831486
logger.info(" Ready to start training")
14841487
with mesh:
14851488
for epoch in epochs:
1486-
state.replace(epoch=epoch)
1489+
state = state.replace(epoch=epoch)
14871490
local_state["epoch"] = epoch
14881491
# ======================== Training ================================
14891492
metrics_logger.update_state_metrics(local_state)

0 commit comments

Comments
 (0)