Skip to content

Commit 2f66ada

Browse files
marcromeyn and akoumpa authored
Adding serialization to all Auto* objects in HuggingFace transformers (NVIDIA-NeMo#11645)
* Adding serialization to all Auto* objects in HuggingFace transformers
  Signed-off-by: Marc Romeyn <mromeijn@nvidia.com>
* Apply isort and black reformatting
  Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
* Adding docs
  Signed-off-by: Marc Romeyn <mromeijn@nvidia.com>
* Apply isort and black reformatting
  Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
* Adding more doc-strings
  Signed-off-by: Marc Romeyn <mromeijn@nvidia.com>
* Adding more doc-strings
  Signed-off-by: Marc Romeyn <mromeijn@nvidia.com>
* Address comments
  Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>
* Apply isort and black reformatting
  Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
* fix?
  Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
---------
Signed-off-by: Marc Romeyn <mromeijn@nvidia.com>
Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Co-authored-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent ea5ed67 commit 2f66ada

File tree

8 files changed

+314
-20
lines changed

8 files changed

+314
-20
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from nemo.lightning.io.artifact.base import Artifact
22
from nemo.lightning.io.artifact.file import DirArtifact, DirOrStringArtifact, FileArtifact, PathArtifact
3+
from nemo.lightning.io.artifact.hf_auto import HFAutoArtifact
34

4-
__all__ = ["Artifact", "FileArtifact", "PathArtifact", "DirArtifact", "DirOrStringArtifact"]
5+
__all__ = ["Artifact", "FileArtifact", "PathArtifact", "DirArtifact", "DirOrStringArtifact", "HFAutoArtifact"]

nemo/lightning/io/artifact/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, attr: str, required: bool = True, skip: bool = False):
2626
self.skip = skip
2727

2828
@abstractmethod
29-
def dump(self, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT:
29+
def dump(self, instance, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT:
3030
pass
3131

3232
@abstractmethod

nemo/lightning/io/artifact/file.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
class PathArtifact(Artifact[Path]):
25-
def dump(self, value: Path, absolute_dir: Path, relative_dir: Path) -> Path:
25+
def dump(self, instance, value: Path, absolute_dir: Path, relative_dir: Path) -> Path:
2626
new_value = copy_file(value, absolute_dir, relative_dir)
2727
return new_value
2828

@@ -31,7 +31,7 @@ def load(self, path: Path) -> Path:
3131

3232

3333
class FileArtifact(Artifact[str]):
34-
def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str:
34+
def dump(self, instance, value: str, absolute_dir: Path, relative_dir: Path) -> str:
3535
if not pathize(value).exists():
3636
# This is Artifact is just a string.
3737
return fdl.Config(FileArtifact, attr=value, skip=True)
@@ -58,7 +58,7 @@ def copy_file(src: Union[Path, str], path: Union[Path, str], relative_dst: Union
5858

5959

6060
class DirArtifact(Artifact[str]):
61-
def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str:
61+
def dump(self, instance, value: str, absolute_dir: Path, relative_dir: Path) -> str:
6262
value = pathize(value)
6363
absolute_dir = pathize(absolute_dir)
6464
relative_dir = pathize(relative_dir)
@@ -76,11 +76,11 @@ def load(self, path: str) -> str:
7676

7777

7878
class DirOrStringArtifact(DirArtifact):
79-
def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str:
79+
def dump(self, instance, value: str, absolute_dir: Path, relative_dir: Path) -> str:
8080
if not pathize(value).exists():
8181
# This is Artifact is just a string.
8282
return fdl.Config(DirOrStringArtifact, attr=value, skip=True)
83-
return super().dump(value, absolute_dir, relative_dir)
83+
return super().dump(instance, value, absolute_dir, relative_dir)
8484

8585
def load(self, path: str) -> str:
8686
return path
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
"""HuggingFace model serialization support for NeMo's configuration system.
17+
18+
This module provides integration between NeMo's configuration system and HuggingFace's
19+
pretrained models. It enables automatic serialization and deserialization of HuggingFace
20+
models within NeMo's configuration framework.
21+
22+
The integration works by:
23+
1. Detecting HuggingFace models through their characteristic methods (save_pretrained/from_pretrained)
24+
2. Converting them to Fiddle configurations that preserve the model's class and path
25+
3. Providing an artifact handler (HFAutoArtifact) that manages the actual model files
26+
27+
Example:
28+
```python
29+
from transformers import AutoModel
30+
31+
# This model will be automatically handled by the HFAutoArtifact system
32+
model = AutoModel.from_pretrained("bert-base-uncased")
33+
34+
# When serialized, the model files will be saved to the artifacts directory
35+
# When deserialized, the model will be loaded from the saved files
36+
```
37+
"""
38+
39+
import contextlib
40+
import inspect
41+
import threading
42+
from pathlib import Path
43+
44+
import fiddle as fdl
45+
46+
from nemo.lightning.io.artifact import Artifact
47+
from nemo.lightning.io.to_config import to_config
48+
49+
_local = threading.local()
50+
51+
52+
class HFAutoArtifact(Artifact):
    """Artifact handler for HuggingFace pretrained objects (model, processor, tokenizer, ...).

    Saving is delegated to the object's own ``save_pretrained`` method, which
    writes into an ``artifacts`` subdirectory of the dump directory. Loading
    simply hands the stored path back so it can be given to ``from_pretrained``.
    """

    def dump(self, instance, value: Path, absolute_dir: Path, relative_dir: Path) -> str:
        """Save a HuggingFace object to disk.

        Args:
            instance: The HuggingFace object to save; it must provide ``save_pretrained``.
            value: Original path value (unused).
            absolute_dir: Absolute path to the save directory.
            relative_dir: Relative path from the config file to the save directory.

        Returns:
            str: ``"./"`` followed by ``relative_dir / "artifacts"``, i.e. the
            relative location of the saved files.
        """
        # The files themselves go under the absolute directory ...
        instance.save_pretrained(Path(absolute_dir) / "artifacts")
        # ... while the value handed back to the caller is the relative location, as a string.
        return "./" + str(Path(relative_dir) / "artifacts")

    def load(self, path: Path) -> Path:
        """Return the path to load a HuggingFace object from.

        Args:
            path: Path to the saved artifacts.

        Returns:
            Path: The same path, unchanged, to be used with ``from_pretrained``.
        """
        return path
85+
86+
87+
@contextlib.contextmanager
def from_pretrained_kwargs(**kwargs):
    """Temporarily add keyword arguments that ``from_pretrained`` will receive.

    The arguments are kept on the module's thread-local ``_local`` object.
    They are merged over whatever is already set there for the duration of the
    ``with`` block, and the earlier set is put back on exit, including when
    the block raises.

    Args:
        **kwargs: Keyword arguments to pass to from_pretrained

    Example:
        with from_pretrained_kwargs(trust_remote_code=True):
            io.load_context("path/to/checkpoint")
    """
    try:
        active = _local.kwargs
    except AttributeError:
        # First use on this thread: start from an empty mapping.
        active = _local.kwargs = {}

    # Remember what was set before layering the new arguments on top.
    snapshot = dict(active)
    active.update(kwargs)
    try:
        yield
    finally:
        _local.kwargs = snapshot
106+
107+
108+
def from_pretrained(auto_cls, pretrained_model_name_or_path="dummy"):
    """Load a HuggingFace object through ``auto_cls.from_pretrained``.

    This function is the serialization target for HuggingFace objects: the
    stored configuration points here, and building it recreates the object.
    Any keyword arguments currently set on the thread-local ``_local.kwargs``
    (see ``from_pretrained_kwargs``) are forwarded to the call.

    Args:
        auto_cls: The HuggingFace class (e.g., AutoModel, AutoTokenizer).
        pretrained_model_name_or_path: Path to the saved files or a model identifier.

    Returns:
        Whatever ``auto_cls.from_pretrained`` returns.
    """
    # No kwargs have been set on this thread -> forward nothing extra.
    extra_kwargs = getattr(_local, "kwargs", {})
    loader = auto_cls.from_pretrained
    return loader(pretrained_model_name_or_path, **extra_kwargs)
123+
124+
125+
def _is_hf_pretrained_instance(value) -> bool:
    """Return True if ``value`` is an instance (not a class) from ``transformers`` with save/from_pretrained."""
    if inspect.isclass(value):
        return False

    # ``__module__`` can be present but set to ``None`` (for example on bound
    # built-in methods), so a ``getattr`` default alone is not enough: the
    # predicate must never raise for an arbitrary value.
    module_name = getattr(value, "__module__", None)
    if not isinstance(module_name, str) or not module_name.startswith("transformers"):
        return False

    return hasattr(value, "save_pretrained") and hasattr(value, "from_pretrained")


@to_config.register(_is_hf_pretrained_instance)
def handle_hf_pretrained(value):
    """Convert a HuggingFace object instance to a Fiddle configuration.

    The handler applies to instances whose class lives in a ``transformers``
    module and that expose both ``save_pretrained`` and ``from_pretrained``.
    The resulting configuration targets the module-level ``from_pretrained``
    factory with the instance's class as ``auto_cls``.

    Args:
        value: A HuggingFace object instance.

    Returns:
        fdl.Config: A Fiddle configuration that will recreate the object.
    """
    return fdl.Config(
        from_pretrained,
        auto_cls=value.__class__,
        # Placeholder only; the artifact registered for this parameter is
        # expected to replace it with the saved location at dump time.
        pretrained_model_name_or_path="dummy",
    )
149+
150+
151+
# Register the HFAutoArtifact handler for the pretrained_model_name_or_path parameter
152+
from_pretrained.__io_artifacts__ = [HFAutoArtifact("pretrained_model_name_or_path")]

nemo/lightning/io/artifact/pickle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
class PickleArtifact(Artifact[Any]):
24-
def dump(self, absolute_dir: Path, relative_dir: Path) -> Path:
24+
def dump(self, instance, absolute_dir: Path, relative_dir: Path) -> Path:
2525
relative_file = self.file_path(relative_dir)
2626
with open(Path(absolute_dir) / relative_file, "wb") as f:
2727
dump(value, f)

nemo/lightning/io/fdl_torch.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,20 @@
1919
"""
2020

2121
import types
22-
from functools import partial
2322

24-
import fiddle as fdl
2523
import libcst as cst
2624
import torch
2725
import torch.nn as nn
2826
from fiddle._src import daglish_extensions
2927
from fiddle._src.codegen import import_manager, py_val_to_cst_converter, special_value_codegen
3028
from fiddle._src.experimental import serialization
3129

30+
from nemo.lightning.io.artifact import * # noqa: F403
31+
from nemo.lightning.io.to_config import to_config
32+
3233

3334
def _make_torch_importable(name: str) -> special_value_codegen.Importable:
    """Build an importable that renders ``name`` as an attribute of the ``torch`` module."""

    def _render(torch_name):
        # ``torch_name`` is whatever name the generated code uses for ``torch``.
        return f"{torch_name}.{name}"

    return special_value_codegen.SingleImportable("torch", _render)
3537

3638

@@ -67,6 +69,7 @@ def _make_torch_importable(name: str) -> special_value_codegen.Importable:
6769

6870

6971
def _make_torch_nn_importable(name: str) -> special_value_codegen.Importable:
    """Build an importable that renders ``name`` as an attribute of ``torch.nn``."""

    def _render(torch_mod_name):
        # Only ``torch`` is imported; ``nn`` is reached as an attribute of it.
        return f"{torch_mod_name}.nn.{name}"

    return special_value_codegen.SingleImportable("torch", _render)
7174

7275

@@ -88,6 +91,7 @@ def is_torch_tensor(value):
8891

8992

9093
def convert_torch_tensor_to_cst(value, convert_child):
94+
"""Convert a PyTorch tensor to a CST node."""
9195
return cst.Call(
9296
func=cst.Attribute(value=convert_child(torch), attr=cst.Name("tensor")),
9397
args=[
@@ -124,11 +128,10 @@ def enable():
124128

125129
# Monkey-patch the Serialization class to handle things like activation-functions
126130
def _modified_serialize(self, value, current_path, all_paths=None):
131+
"""Serialize a value to a Fiddle configuration."""
127132
if isinstance(value, types.BuiltinFunctionType):
128133
return self._pyref(value, current_path)
129-
if isinstance(value, partial):
130-
value = fdl.Partial(value.func, *value.args, **value.keywords)
131-
return self._original_serialize(value, current_path, all_paths)
134+
return self._original_serialize(to_config(value), current_path, all_paths)
132135

133136
serialization.Serialization._original_serialize = serialization.Serialization._serialize
134137
serialization.Serialization._serialize = _modified_serialize

nemo/lightning/io/mixin.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from nemo.lightning.io.capture import IOProtocol
3939
from nemo.lightning.io.connector import ModelConnector
4040
from nemo.lightning.io.fdl_torch import enable as _enable_ext
41+
from nemo.lightning.io.to_config import to_config
4142
from nemo.utils import logging
4243

4344
ConnT = TypeVar("ConnT", bound=ModelConnector)
@@ -233,8 +234,7 @@ def io_dump(self, output: Path, yaml_attrs: list[str]):
233234

234235
config_path = output_path / "io.json"
235236
with open(config_path, "w") as f:
236-
io = deepcopy(self.__io__)
237-
_artifact_transform_save(io, output_path, local_artifacts_dir)
237+
io = _artifact_transform_save(self, deepcopy(self.__io__), output_path, local_artifacts_dir)
238238
json = serialization.dump_json(io)
239239
f.write(json)
240240

@@ -632,8 +632,10 @@ def _io_path_elements_fn(x):
632632
return x.__io__.__path_elements__()
633633

634634

635-
def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "."):
636-
for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
635+
def _artifact_transform_save(instance, cfg: fdl.Config, output_path: Path, relative_dir: Path = Path(".")):
636+
artifacts = getattr(cfg.__fn_or_cls__, "__io_artifacts__", [])
637+
638+
for artifact in artifacts:
637639
# Allow optional artifacts
638640
if artifact.skip or (not hasattr(cfg, artifact.attr) and not artifact.required):
639641
continue
@@ -647,16 +649,29 @@ def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: P
647649
raise ValueError(f"Artifact '{artifact.attr}' is required but not provided")
648650
continue
649651
## dump artifact and return the relative path
650-
new_val = artifact.dump(current_val, output_path, relative_dir)
652+
new_val = artifact.dump(instance, current_val, output_path, relative_dir)
651653
setattr(cfg, artifact.attr, new_val)
652654

653655
for attr in dir(cfg):
656+
child = to_config(getattr(cfg, attr))
657+
654658
try:
655-
if isinstance(getattr(cfg, attr), fdl.Config):
656-
_artifact_transform_save(getattr(cfg, attr), output_path=output_path, relative_dir=relative_dir)
659+
if isinstance(child, (fdl.Config, fdl.Partial)):
660+
setattr(
661+
cfg,
662+
attr,
663+
_artifact_transform_save(
664+
getattr(instance, attr, None),
665+
child,
666+
output_path=output_path,
667+
relative_dir=relative_dir,
668+
),
669+
)
657670
except ValueError:
658671
pass
659672

673+
return cfg
674+
660675

661676
def _artifact_transform_load(cfg: fdl.Config, path: Path):
662677
for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):

0 commit comments

Comments
 (0)