Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def validate_keys(keys):
assert keys["local_checkpoint_period"] > 0, "A positive local checkpoint period must be specified when using emergency checkpoint"
else:
max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period")
if keys["num_experts"] > 1:
# Currently, Megablox only supports data parallelism
validate_megablox_parallelism(keys)


def validate_data_input(keys):
Expand Down Expand Up @@ -420,9 +423,6 @@ def validate_and_update_keys(raw_keys, model_keys, config_name: str):
"""Validate and update model specific config keys"""
max_logging.log("Updating following parameters in config\n")

if raw_keys["num_experts"] > 1:
# Currently, Megablox only supports data parallelism
validate_megablox_parallelism(raw_keys)

for k in model_keys:
max_logging.log(f"{k}: {model_keys[k]}")
Expand Down