diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py index 00d71a5b014..b05dc19bfc0 100644 --- a/examples/models/llama2/builder.py +++ b/examples/models/llama2/builder.py @@ -62,6 +62,7 @@ def to_torch_dtype(self) -> torch.dtype: def load_llama_model( *, + modelname: str = "llama2", checkpoint: Optional[str] = None, checkpoint_dir: Optional[str] = None, params_path: str, @@ -114,6 +115,7 @@ def load_llama_model( return LlamaEdgeManager( model=model, + modelname=modelname, weight_type=weight_type, dtype=dtype, use_kv_cache=use_kv_cache, @@ -131,6 +133,7 @@ class LlamaEdgeManager: def __init__( self, model, + modelname, weight_type, dtype, use_kv_cache, @@ -139,6 +142,7 @@ def __init__( verbose: bool = False, ): self.model = model + self.modelname = modelname self.weight_type = weight_type self.dtype = dtype self.example_inputs = example_inputs diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 6a34c9bd889..9a8baf5aa1f 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -485,7 +485,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: ) params_path = canonical_path(args.params) output_dir_path = canonical_path(args.output_dir, dir=True) - modelname = "llama2" weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA # dtype override @@ -552,6 +551,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: return ( load_llama_model( + modelname=modelname, checkpoint=checkpoint_path, checkpoint_dir=checkpoint_dir, params_path=params_path, @@ -599,6 +599,8 @@ def _export_llama(modelname, args) -> str: # noqa: C901 modelname, args ).export_to_edge(quantizers) + modelname = builder_exported_to_edge.modelname + # to_backend partitioners = [] if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None: