diff --git a/exir/program/_program.py b/exir/program/_program.py index 8b6e84b000b..c62214f051c 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -743,7 +743,7 @@ class EdgeProgramManager: def __init__( self, - edge_programs: Dict[str, ExportedProgram], + edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]], constant_methods: Optional[Dict[str, Any]] = None, compile_config: Optional[EdgeCompileConfig] = None, ): @@ -753,6 +753,8 @@ def __init__( Constructs an EdgeProgramManager from an existing set of exported programs in edge dialect. """ config = compile_config or EdgeCompileConfig() + if not isinstance(edge_programs, dict): + edge_programs = {"forward": edge_programs} for name, program in edge_programs.items(): try: EXIREdgeDialectVerifier( @@ -763,7 +765,7 @@ def __init__( logging.info(f"Input program {name} is not in aten dialect.") raise e - self._edge_programs = edge_programs + self._edge_programs: Dict[str, ExportedProgram] = edge_programs self._config_methods = constant_methods @property diff --git a/exir/serde/TARGETS b/exir/serde/TARGETS index fceebe7869f..10c970867d7 100644 --- a/exir/serde/TARGETS +++ b/exir/serde/TARGETS @@ -3,6 +3,8 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library") oncall("executorch") python_library( + # @autodeps-skip for some reason autodeps thinks this target + # needs to depend on exir:lib which it doesn't. name = "serialize", srcs = [ "export_serialize.py", diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index 34d55252d83..5826a52b01f 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -33,7 +33,6 @@ from executorch.exir.lowered_backend_module import ( LoweredBackendModule as ExirLoweredBackendModule, ) -from executorch.exir.serde.export_serialize import SerializedArtifact from executorch.exir.serde.schema import ( CompileSpec, LoweredBackendModule as SerdeLoweredBackendModule, @@ -680,7 +679,7 @@ def deserialize( root=state_dict, graph=dummy_g, graph_signature=ep.ExportGraphSignature(input_specs=[], output_specs=[]), - state_dict={}, # TODO(T157676982) + state_dict=state_dict, # TODO(T157676982) range_constraints=range_constraints, module_call_graph=module_call_graph, verifier=load_verifier( @@ -765,7 +764,7 @@ def save( if not isinstance(ep_save, ep.ExportedProgram): raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}") - artifact: SerializedArtifact = serialize(ep_save, opset_version) + artifact: export_serialize.SerializedArtifact = serialize(ep_save, opset_version) if isinstance(f, (str, os.PathLike)): f = os.fspath(f) @@ -836,10 +835,12 @@ def load( assert serialized_exported_program is not None assert serialized_state_dict is not None assert serialized_constants is not None - artifact: SerializedArtifact = SerializedArtifact( - serialized_exported_program, - serialized_state_dict, - serialized_constants, + artifact: export_serialize.SerializedArtifact = ( + export_serialize.SerializedArtifact( + serialized_exported_program, + serialized_state_dict, + serialized_constants, + ) ) # Deserialize ExportedProgram diff --git a/exir/tests/test_serde.py b/exir/tests/test_serde.py index 1ada5169479..2c68920ff34 100644 --- a/exir/tests/test_serde.py +++ b/exir/tests/test_serde.py @@ -159,6 +159,28 @@ def forward(self, x): edge_new = deserialize(serialize(edge.exported_program())) self.check_ep(edge.exported_program(), edge_new, model_inputs) + def test_model_with_weights(self) -> None: + class LinearAdd(nn.Module): + def __init__(self, M: int, N: int): + super().__init__() + self.M = M + self.N = N + self.linear = torch.nn.Linear(M, N) + + def forward(self, x, y): + x = self.linear(x) + y = self.linear(y) + return torch.add(x, y) + + @classmethod + def _get_random_inputs(cls): + return (torch.rand(128, 20), torch.rand(128, 20)) + + linear_add = LinearAdd(20, 30) + model_inputs = LinearAdd._get_random_inputs() + + self.check_serde(linear_add, model_inputs) + def test_delegate_partitioner(self) -> None: class Model(torch.nn.Module): def __init__(self):