From f98429a1f4802a30479b7911edda1bbffe7d5cfa Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 16 Oct 2024 10:17:22 -0400 Subject: [PATCH] Fix backward compatibility in normalization --- fast_llm/layers/language_model/config.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9b6c78d8d..3976c69bc 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -81,6 +81,20 @@ def use_absolute_position_embeddings(self): # TODO: Set through num embeddings instead instead. return self.use_position_embeddings + @classmethod + def from_flat_dict( + cls, + default: dict[str], + strict: bool = True, + ): + # The backward compatibility fix in `NormalizationArchitectureConfig` + # won't work for older checkpoints saved with a flat config. + # TODO v0.2: Remove flat format + cls._handle_renamed_field(default, "normalization_type", "type") + cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") + cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") + return super().from_flat_dict(default, strict) + @config_class() class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig):