From 698612f0ec2dff92bcb6b62e5f696fc05d047f65 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Tue, 25 Mar 2025 16:24:27 -0700
Subject: [PATCH 1/2] Update Llama2 to use zero weights when no checkpoint is provided

---
 examples/models/llama/model.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index ec10ae5a649..75334f4980c 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -259,15 +259,22 @@ def __init__(self, **kwargs):
                     assign=True,
                 )  # self.model_ = Transformer(gptconf)
             else:
-                print("Checkpoint not provided, defaulting to uninitialized weights.")
+                print("Checkpoint not provided, defaulting weights to zeros.")
                 self.model_.to_empty(device="cpu")
+                for p in self.model_.parameters():
+                    p.data.fill_(0)
+                for b in self.model_.buffers():
+                    b.data.fill_(0)
         except RuntimeError as e:
             print(
-                f"Could not load checkpoint into mode and will default to uninitialized weights due to error: {e}."
+                f"Could not load checkpoint into model and will default weights to zeros due to error: {e}."
             )
             # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
             self.model_.to_empty(device="cpu")
-
+            for p in self.model_.parameters():
+                p.data.fill_(0)
+            for b in self.model_.buffers():
+                b.data.fill_(0)
         if missing:
             missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
             if missing_weights:

From ccf664e188b58224dfffb638de371f27c1fdcd61 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Tue, 25 Mar 2025 16:24:55 -0700
Subject: [PATCH 2/2] Specify Quant Type in AoT Compiler for better results

---
 .ci/scripts/gather_test_models.py      |  4 +--
 examples/xnnpack/__init__.py           | 49 ++++++++++++++++----------
 examples/xnnpack/aot_compiler.py       |  6 ++--
 examples/xnnpack/quantization/utils.py | 16 +++++++--
 4 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/.ci/scripts/gather_test_models.py b/.ci/scripts/gather_test_models.py
index 3f22d7699de..802aee4b53c 100755
--- a/.ci/scripts/gather_test_models.py
+++ b/.ci/scripts/gather_test_models.py
@@ -14,7 +14,7 @@
 from typing import Any
 
 from examples.models import MODEL_NAME_TO_MODEL
-from examples.xnnpack import MODEL_NAME_TO_OPTIONS
+from examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
 
 DEFAULT_RUNNERS = {
     "linux": "linux.2xlarge",
@@ -154,7 +154,7 @@ def export_models_for_ci() -> dict[str, dict]:
         if backend == "xnnpack":
             if name not in MODEL_NAME_TO_OPTIONS:
                 continue
-            if MODEL_NAME_TO_OPTIONS[name].quantization:
+            if MODEL_NAME_TO_OPTIONS[name].quantization != QuantType.NONE:
                 backend += "-quantization"
 
             if MODEL_NAME_TO_OPTIONS[name].delegation:
diff --git a/examples/xnnpack/__init__.py b/examples/xnnpack/__init__.py
index d8de9f6a36e..e78e1fec5be 100644
--- a/examples/xnnpack/__init__.py
+++ b/examples/xnnpack/__init__.py
@@ -7,33 +7,44 @@
 # pyre-unsafe
 
 from dataclasses import dataclass
+from enum import Enum
+
+
+class QuantType(Enum):
+    NONE = 1
+    # Used for operations that don't have weights
+    STATIC_PER_TENSOR = 2
+    # Best suited for CNN/RNN models with Conv layers
+    STATIC_PER_CHANNEL = 3
+    # Used for Linear layers and Transformer-based models
+    DYNAMIC_PER_CHANNEL = 4
 
 
 @dataclass
 class XNNPACKOptions(object):
-    quantization: bool
+    quantization: QuantType
     delegation: bool
 
 
 MODEL_NAME_TO_OPTIONS = {
-    "linear": XNNPACKOptions(True, True),
-    "add": XNNPACKOptions(True, True),
-    "add_mul": XNNPACKOptions(True, True),
-    "dl3": XNNPACKOptions(True, True),
-    "ic3": XNNPACKOptions(True, True),
-    "ic4": XNNPACKOptions(True, True),
-    "mv2": XNNPACKOptions(True, True),
-    "mv3": XNNPACKOptions(True, True),
-    "resnet18": XNNPACKOptions(True, True),
-    "resnet50": XNNPACKOptions(True, True),
-    "vit": XNNPACKOptions(True, True),
-    "w2l": XNNPACKOptions(True, True),
-    "edsr": XNNPACKOptions(True, True),
-    "mobilebert": XNNPACKOptions(True, True),
-    "llama2": XNNPACKOptions(False, True),
-    "emformer_join": XNNPACKOptions(True, True),
-    "emformer_predict": XNNPACKOptions(True, True),
-    "emformer_transcribe": XNNPACKOptions(True, True),
+    "linear": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "add": XNNPACKOptions(QuantType.STATIC_PER_TENSOR, True),
+    "add_mul": XNNPACKOptions(QuantType.STATIC_PER_TENSOR, True),
+    "dl3": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "ic3": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "ic4": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "mv2": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "mv3": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "resnet18": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "resnet50": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "vit": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "w2l": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "edsr": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "mobilebert": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "llama2": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "emformer_join": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "emformer_predict": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "emformer_transcribe": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
 }
diff --git a/examples/xnnpack/aot_compiler.py b/examples/xnnpack/aot_compiler.py
index e1542245aca..6db0d82a274 100644
--- a/examples/xnnpack/aot_compiler.py
+++ b/examples/xnnpack/aot_compiler.py
@@ -66,7 +66,7 @@
 
     args = parser.parse_args()
 
-    if not args.delegate:
+    if not args.delegate and args.quantize:
         raise NotImplementedError(
             "T161880157: Quantization-only without delegation is not supported yet"
         )
@@ -79,6 +79,8 @@
             f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
         )
 
+    quant_type = MODEL_NAME_TO_OPTIONS[args.model_name].quantization
+
     model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
@@ -91,7 +93,7 @@
     if args.quantize:
         logging.info("Quantizing Model...")
         # TODO(T165162973): This pass shall eventually be folded into quantizer
-        model = quantize(model, example_inputs)
+        model = quantize(model, example_inputs, quant_type)
 
     ep = torch.export.export_for_training(model, example_inputs)
     edge = to_edge_transform_and_lower(
diff --git a/examples/xnnpack/quantization/utils.py b/examples/xnnpack/quantization/utils.py
index de59c076a8f..9e49f15a99d 100644
--- a/examples/xnnpack/quantization/utils.py
+++ b/examples/xnnpack/quantization/utils.py
@@ -13,13 +13,25 @@
 
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
+from .. import QuantType
 
-def quantize(model, example_inputs):
+
+def quantize(
+    model, example_inputs, quant_type: QuantType = QuantType.STATIC_PER_TENSOR
+):
     """This is the official recommended flow for quantization in pytorch 2.0 export"""
     logging.info(f"Original model: {model}")
     quantizer = XNNPACKQuantizer()
     # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
-    operator_config = get_symmetric_quantization_config(is_per_channel=False)
+    is_per_channel = (
+        quant_type == QuantType.STATIC_PER_CHANNEL
+        or quant_type == QuantType.DYNAMIC_PER_CHANNEL
+    )
+    is_dynamic = quant_type == QuantType.DYNAMIC_PER_CHANNEL
+    operator_config = get_symmetric_quantization_config(
+        is_per_channel=is_per_channel,
+        is_dynamic=is_dynamic,
+    )
     quantizer.set_global(operator_config)
     m = prepare_pt2e(model, quantizer)
     # calibration
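
Usage note: a minimal sketch of how the new QuantType plumbing fits together
end to end. Several pieces here are assumptions, not shown in the hunks above:
the import path for EagerModelFactory, the eval/capture step before quantize(),
and the behavior of quantize() past the truncated "# calibration" line (assumed
to calibrate and convert_pt2e before returning, per its imports). The model
name "mv2" is an illustrative stand-in for any key in MODEL_NAME_TO_OPTIONS.

    import torch

    from examples.models import MODEL_NAME_TO_MODEL
    from examples.models.model_factory import EagerModelFactory  # assumed path
    from examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
    from examples.xnnpack.quantization.utils import quantize

    name = "mv2"  # illustrative; any key in MODEL_NAME_TO_OPTIONS works
    # Per-model quant type introduced by PATCH 2/2: STATIC_PER_CHANNEL for
    # "mv2", DYNAMIC_PER_CHANNEL for transformer-style models, NONE to skip.
    quant_type = MODEL_NAME_TO_OPTIONS[name].quantization

    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[name]
    )
    model = model.eval()
    # PT2E quantization operates on a captured module; this capture step is
    # an assumption here, not shown in the aot_compiler.py hunk.
    model = torch.export.export_for_training(model, example_inputs).module()

    if quant_type != QuantType.NONE:
        # quantize() maps the enum onto is_per_channel/is_dynamic for
        # XNNPACKQuantizer's symmetric config, as in the utils.py hunk.
        model = quantize(model, example_inputs, quant_type)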