diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index ec10ae5a649..4871e18b00a 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -178,7 +178,7 @@ def __init__(self, **kwargs): if checkpoint: self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint) else: - self.model_.checkpoint_dtype = None + self.model_.checkpoint_dtype = torch.float32 if "int8" in str(checkpoint_path): print("Using int8 weight-only quantization!") diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index fdbd18eca1d..8d2641d9d78 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -5,7 +5,7 @@ @dataclass class ModelArgs: dim: int = 4096 - n_layers: int = 32 + n_layers: int = 8 n_heads: int = 32 n_kv_heads: Optional[int] = None vocab_size: int = 512 # Arbitrary value, should be defined later by tokenizer. diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py new file mode 100644 index 00000000000..b94adb5fa0c --- /dev/null +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from executorch.devtools.backend_debug import get_delegation_info +from executorch.examples.models.llama.export_llama_lib import ( + _export_llama, + build_args_parser, +) + +UNWANTED_OPS = [ + "aten_permute_copy_default", + "aten_transpose_copy_default", +] + + +class ExportLlamaLibTest(unittest.TestCase): + def test_has_expected_ops_and_op_counts(self): + """ + Checks the presence of unwanted expensive ops. + + Serves as a proxy for a performance regression test, as performance + is directly tied to which ops are in the graph and how many of each.
+ + If this test breaks, please ensure that the difference in ops + is intentional before updating the expected ops. + """ + # Since we aren't loading a checkpoint, it doesn't + # matter what model we specify. Note that + # we cannot test quantization args in this way + # since quantization requires promoting meta tensors + # to device=cpu, which requires real weights. + parser = build_args_parser() + args = parser.parse_args([]) + args.use_sdpa_with_kv_cache = True + args.use_kv_cache = True + args.verbose = True + + builder = _export_llama(args) + graph_module = builder.edge_manager.exported_program().graph_module + delegation_info = get_delegation_info(graph_module) + + for op, _op_info in delegation_info.delegation_by_operator.items(): + self.assertTrue(op not in UNWANTED_OPS) diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py index 5e215d1c035..6ce4b701bbe 100644 --- a/examples/models/llava/model.py +++ b/examples/models/llava/model.py @@ -56,6 +56,7 @@ def __init__( ) self.text_model_args = ModelArgs( use_kv_cache=True, + n_layers=32, vocab_size=self.model_.config.text_config.vocab_size, hidden_dim=self.model_.config.text_config.intermediate_size, max_batch_size=1, # doesn't work with default batch size 32