diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 99c15d1b9f..f4ad421871 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -92,6 +92,9 @@ def validate_keys(keys): assert keys["local_checkpoint_period"] > 0, "A positive local checkpoint period must be specified when using emergency checkpoint" else: max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period") + if keys["num_experts"] > 1: + # Currently, Megablox only supports data parallelism + validate_megablox_parallelism(keys) def validate_data_input(keys): @@ -420,9 +423,6 @@ def validate_and_update_keys(raw_keys, model_keys, config_name: str): """Validate and update model specific config keys""" max_logging.log("Updating following parameters in config\n") - if raw_keys["num_experts"] > 1: - # Currently, Megablox only supports data parallelism - validate_megablox_parallelism(raw_keys) for k in model_keys: max_logging.log(f"{k}: {model_keys[k]}")