2 changes: 1 addition & 1 deletion examples/models/llama/model.py
@@ -178,7 +178,7 @@ def __init__(self, **kwargs):
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
         else:
-            self.model_.checkpoint_dtype = None
+            self.model_.checkpoint_dtype = torch.float32
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
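In effect, loading without a checkpoint now falls back to a concrete torch.float32 dtype instead of None. A minimal sketch of the new fallback behavior (the helper below is a stand-in for illustration, not the real get_checkpoint_dtype):

import torch

def get_checkpoint_dtype(checkpoint):
    # Stand-in for the real helper in examples/models/llama; assume it
    # reads the dtype out of a loaded state dict.
    return next(iter(checkpoint.values())).dtype

def resolve_checkpoint_dtype(checkpoint):
    # Mirrors the fallback introduced in the diff above.
    return get_checkpoint_dtype(checkpoint) if checkpoint else torch.float32

assert resolve_checkpoint_dtype(None) == torch.float32  # was None before this change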
2 changes: 1 addition & 1 deletion examples/models/llama/model_args.py
@@ -5,7 +5,7 @@
 @dataclass
 class ModelArgs:
     dim: int = 4096
-    n_layers: int = 32
+    n_layers: int = 8
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
     vocab_size: int = 512  # Arbitrary value, should be defined later by tokenizer.
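With this change, a default-constructed ModelArgs describes a much smaller 8-layer model, which is what lets the new test below export a model quickly without loading a checkpoint. A quick sketch (import path inferred from the test file's imports):

from executorch.examples.models.llama.model_args import ModelArgs

args = ModelArgs()
assert args.n_layers == 8  # new default, was 32
assert args.dim == 4096    # other defaults unchanged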
48 changes: 48 additions & 0 deletions examples/models/llama/tests/test_export_llama_lib.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from executorch.devtools.backend_debug import get_delegation_info
+from executorch.examples.models.llama.export_llama_lib import (
+    _export_llama,
+    build_args_parser,
+)
+
+UNWANTED_OPS = [
+    "aten_permute_copy_default",
+    "aten_transpose_copy_default",
+]
+
+
+class ExportLlamaLibTest(unittest.TestCase):
+    def test_has_expected_ops_and_op_counts(self):
+        """
+        Checks for the presence of unwanted expensive ops.
+
+        Serves as a proxy for a performance regression test, as performance
+        is directly tied to which ops are in the graph and how many of each.
+
+        If this test breaks, please ensure that the difference in ops
+        is intentional before updating the expected ops.
+        """
+        # Since we aren't loading a checkpoint, it doesn't
+        # matter what model we specify. Note that
+        # we cannot test quantization args in this way,
+        # since quantization requires promoting meta tensors
+        # to device=cpu, which requires real weights.
+        parser = build_args_parser()
+        args = parser.parse_args([])
+        args.use_sdpa_with_kv_cache = True
+        args.use_kv_cache = True
+        args.verbose = True
+
+        builder = _export_llama(args)
+        graph_module = builder.edge_manager.exported_program().graph_module
+        delegation_info = get_delegation_info(graph_module)
+
+        for op, _op_info in delegation_info.delegation_by_operator.items():
+            self.assertTrue(op not in UNWANTED_OPS)
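For reference, the same delegation summary can be inspected outside of unittest using only the names from the test above; a sketch (the unittest module path is inferred from the file location):

# Run the new test directly:
#   python -m unittest examples.models.llama.tests.test_export_llama_lib
#
# Or inspect the per-operator delegation breakdown interactively:
from executorch.devtools.backend_debug import get_delegation_info
from executorch.examples.models.llama.export_llama_lib import (
    _export_llama,
    build_args_parser,
)

args = build_args_parser().parse_args([])
args.use_kv_cache = True
args.use_sdpa_with_kv_cache = True

builder = _export_llama(args)
graph_module = builder.edge_manager.exported_program().graph_module
for op, op_info in get_delegation_info(graph_module).delegation_by_operator.items():
    print(op, op_info)  # per-operator delegation details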
1 change: 1 addition & 0 deletions examples/models/llava/model.py
@@ -56,6 +56,7 @@ def __init__(
         )
         self.text_model_args = ModelArgs(
             use_kv_cache=True,
+            n_layers=32,
             vocab_size=self.model_.config.text_config.vocab_size,
             hidden_dim=self.model_.config.text_config.intermediate_size,
             max_batch_size=1,  # doesn't work with default batch size 32
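This one-line addition follows from the ModelArgs change above: with the default n_layers lowered from 32 to 8, Llava would otherwise silently build an 8-layer text model, so its original depth is now pinned explicitly. A sketch of the distinction (import path inferred as above):

from executorch.examples.models.llama.model_args import ModelArgs

# Relying on the default now yields the small test configuration:
assert ModelArgs(use_kv_cache=True).n_layers == 8
# Llava keeps its full-size text model by passing the depth explicitly:
assert ModelArgs(use_kv_cache=True, n_layers=32).n_layers == 32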