Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/pruna/config/smash_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,24 @@ def load_from_json(self, path: str) -> None:

setattr(self, name, config_dict.pop(name))

# Normalize algorithm groups in config dict
current_groups = set(config_dict.keys())
expected_groups = set(ALGORITHM_GROUPS)

# Remove extra groups with warning if they have values
for group in current_groups - expected_groups:
if config_dict[group] is not None:
pruna_logger.warning(
f"Removing non-existing algorithm group: {group}, with value: {config_dict[group]}.\n"
"This is likely due to a version difference between the saved model and the current library.\n"
"You can use an older version of Pruna to load the model or reconfigure the model."
)
del config_dict[group]

# Add missing groups with info message
for group in expected_groups - current_groups:
config_dict[group] = None

self._configuration = Configuration(SMASH_SPACE, values=config_dict)

if os.path.exists(os.path.join(path, TOKENIZER_SAVE_PATH)):
Expand Down