From 817919e5d08f31857604f0c71ea9b50b235100f3 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 18 Apr 2024 09:46:17 -0700 Subject: [PATCH] Preserve modelname (#3122) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/3122 Reviewed By: mikekgfb Differential Revision: D56212361 --- examples/models/llama2/builder.py | 4 ++++ examples/models/llama2/export_llama_lib.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py index 00d71a5b014..b05dc19bfc0 100644 --- a/examples/models/llama2/builder.py +++ b/examples/models/llama2/builder.py @@ -62,6 +62,7 @@ def to_torch_dtype(self) -> torch.dtype: def load_llama_model( *, + modelname: str = "llama2", checkpoint: Optional[str] = None, checkpoint_dir: Optional[str] = None, params_path: str, @@ -114,6 +115,7 @@ def load_llama_model( return LlamaEdgeManager( model=model, + modelname=modelname, weight_type=weight_type, dtype=dtype, use_kv_cache=use_kv_cache, @@ -131,6 +133,7 @@ class LlamaEdgeManager: def __init__( self, model, + modelname, weight_type, dtype, use_kv_cache, @@ -139,6 +142,7 @@ def __init__( verbose: bool = False, ): self.model = model + self.modelname = modelname self.weight_type = weight_type self.dtype = dtype self.example_inputs = example_inputs diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 6a34c9bd889..9a8baf5aa1f 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -485,7 +485,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: ) params_path = canonical_path(args.params) output_dir_path = canonical_path(args.output_dir, dir=True) - modelname = "llama2" weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA # dtype override @@ -552,6 +551,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: return ( load_llama_model( + modelname=modelname, checkpoint=checkpoint_path, checkpoint_dir=checkpoint_dir, params_path=params_path, @@ -599,6 +599,8 @@ def _export_llama(modelname, args) -> str: # noqa: C901 modelname, args ).export_to_edge(quantizers) + modelname = builder_exported_to_edge.modelname + # to_backend partitioners = [] if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None: