diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index f5995969..20866253 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -172,6 +172,24 @@ def load_from_json(self, path: str) -> None: setattr(self, name, config_dict.pop(name)) + # Normalize algorithm groups in config dict + current_groups = set(config_dict.keys()) + expected_groups = set(ALGORITHM_GROUPS) + + # Remove extra groups with warning if they have values + for group in current_groups - expected_groups: + if config_dict[group] is not None: + pruna_logger.warning( + f"Removing non-existing algorithm group: {group}, with value: {config_dict[group]}.\n" + "This is likely due to a version difference between the saved model and the current library.\n" + "You can use an older version of Pruna to load the model or reconfigure the model." + ) + del config_dict[group] + + # Add missing groups with info message + for group in expected_groups - current_groups: + config_dict[group] = None + self._configuration = Configuration(SMASH_SPACE, values=config_dict) if os.path.exists(os.path.join(path, TOKENIZER_SAVE_PATH)):