From 034e078af9fd0db12a1a6961d43a37b27d95deb4 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 15 May 2025 10:54:18 +0200 Subject: [PATCH 1/3] feat: enhance algorithm group management in SmashConfig --- src/pruna/config/smash_config.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index f5995969..6a2edb32 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -172,6 +172,27 @@ def load_from_json(self, path: str) -> None: setattr(self, name, config_dict.pop(name)) + # ensure all algorithm groups are present and remove extra ones + extra_groups = set(config_dict.keys()) - set(ALGORITHM_GROUPS) + if extra_groups: + for group in extra_groups: + if config_dict[group] is not None: + pruna_logger.warning( + f"Removing non-existing algorithm group: {group}, with value: {config_dict[group]}.\n" + "This is likely due to a version difference between the saved model and the current library.\n" + "You can use an older version of Pruna to load the model or reconfigure the model." + ) + del config_dict[group] + + for group in ALGORITHM_GROUPS: + if group not in config_dict: + pruna_logger.info( + f"Adding missing algorithm group: {group}\n" + "This is likely due to a version difference between the saved model and the current library.\n" + "This is not an error, but it is recommended to update the model." + ) + config_dict[group] = None + self._configuration = Configuration(SMASH_SPACE, values=config_dict) if os.path.exists(os.path.join(path, TOKENIZER_SAVE_PATH)): From db5a9d8c7c85a0969ad89ff898e7658c473ad0a1 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 15 May 2025 10:57:02 +0200 Subject: [PATCH 2/3] refactor: improve algorithm group normalization in SmashConfig --- src/pruna/config/smash_config.py | 38 +++++++++++++++++--------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 6a2edb32..eca70f20 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -172,26 +172,28 @@ def load_from_json(self, path: str) -> None: setattr(self, name, config_dict.pop(name)) - # ensure all algorithm groups are present and remove extra ones - extra_groups = set(config_dict.keys()) - set(ALGORITHM_GROUPS) - if extra_groups: - for group in extra_groups: - if config_dict[group] is not None: - pruna_logger.warning( - f"Removing non-existing algorithm group: {group}, with value: {config_dict[group]}.\n" - "This is likely due to a version difference between the saved model and the current library.\n" - "You can use an older version of Pruna to load the model or reconfigure the model." - ) - del config_dict[group] - - for group in ALGORITHM_GROUPS: - if group not in config_dict: - pruna_logger.info( - f"Adding missing algorithm group: {group}\n" + # Normalize algorithm groups in config dict + current_groups = set(config_dict.keys()) + expected_groups = set(ALGORITHM_GROUPS) + + # Remove extra groups with warning if they have values + for group in current_groups - expected_groups: + if config_dict[group] is not None: + pruna_logger.warning( + f"Removing non-existing algorithm group: {group}, with value: {config_dict[group]}.\n" "This is likely due to a version difference between the saved model and the current library.\n" - "This is not an error, but it is recommended to update the model." + "You can use an older version of Pruna to load the model or reconfigure the model." ) - config_dict[group] = None + del config_dict[group] + + # Add missing groups with info message + for group in expected_groups - current_groups: + pruna_logger.info( + f"Adding missing algorithm group: {group}\n" + "This is likely due to a version difference between the saved model and the current library.\n" + "This is not an error, but it is recommended to update the model." + ) + config_dict[group] = None self._configuration = Configuration(SMASH_SPACE, values=config_dict) From 1fb61af1e0537e88806359d79f87c5edfda1a246 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 15 May 2025 11:05:03 +0200 Subject: [PATCH 3/3] refactor: remove redundant logging for missing algorithm groups in SmashConfig --- src/pruna/config/smash_config.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index eca70f20..20866253 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -188,11 +188,6 @@ def load_from_json(self, path: str) -> None: # Add missing groups with info message for group in expected_groups - current_groups: - pruna_logger.info( - f"Adding missing algorithm group: {group}\n" - "This is likely due to a version difference between the saved model and the current library.\n" - "This is not an error, but it is recommended to update the model." - ) config_dict[group] = None self._configuration = Configuration(SMASH_SPACE, values=config_dict)