WIP
daniil-lyakhov committed Feb 12, 2025
commit 7c66314296db63523872df6407bfbc271d4d8e4c
87 changes: 38 additions & 49 deletions backends/openvino/quantizer/quantizer.py
@@ -10,12 +10,11 @@
# limitations under the License.

from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple

import torch.fx
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer.quantizer import EdgeOrNode
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
@@ -24,25 +23,11 @@
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec

import nncf
import nncf.common.quantization as q
import nncf.experimental.torch.fx as nncf_fx
import nncf.parameters as p
import nncf.quantization.advanced_parameters as advanced_p
from nncf.common.graph.graph import NNCFGraph
from nncf.common.logging import nncf_logger
from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationRule
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.common.quantization.structs import QuantizationPreset
from nncf.common.quantization.structs import QuantizationScheme
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.transformations import fold_constant_except_qdq
from nncf.parameters import ModelType
from nncf.parameters import QuantizationMode
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import FP8QuantizationParameters
from nncf.quantization.advanced_parameters import OverflowFix
from nncf.quantization.advanced_parameters import QuantizationParameters
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.scopes import IgnoredScope
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids

QUANT_ANNOTATION_KEY = "quantization_annotation"

@@ -56,16 +41,15 @@ class OpenVINOQuantizer(Quantizer):
def __init__(
self,
*,
mode: Optional[QuantizationMode] = None,
preset: Optional[QuantizationPreset] = None,
target_device: TargetDevice = TargetDevice.ANY,
model_type: Optional[ModelType] = None,
ignored_scope: Optional[IgnoredScope] = None,
overflow_fix: Optional[OverflowFix] = None,
mode: Optional[p.QuantizationMode] = None,
preset: Optional[q.structs.QuantizationPreset] = None,
target_device: p.TargetDevice = p.TargetDevice.ANY,
transformer_model: bool = False,
ignored_scope: Optional[nncf.IgnoredScope] = None,
overflow_fix: Optional[advanced_p.OverflowFix] = None,
quantize_outputs: bool = False,
activations_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None,
weights_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None,
quantizer_propagation_rule: QuantizerPropagationRule = QuantizerPropagationRule.MERGE_ALL_IN_ONE,
activations_quantization_params: Optional[advanced_p.QuantizationParameters] = None,
weights_quantization_params: Optional[advanced_p.QuantizationParameters] = None,
):
"""
:param mode: Defines optimization mode for the algorithm. None by default.
@@ -89,29 +73,28 @@ def __init__(
:param activations_quantization_params: Quantization parameters for model
activations.
:param weights_quantization_params: Quantization parameters for model weights.
:param quantizer_propagation_rule: The strategy to be used while propagating and merging quantizers.
MERGE_ALL_IN_ONE by default.
"""
self._min_max_algo = MinMaxQuantization(
self._min_max_algo = nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization(
mode=mode,
preset=preset,
target_device=target_device,
model_type=model_type,
model_type=p.ModelType.TRANSFORMER if transformer_model else None,
ignored_scope=ignored_scope,
overflow_fix=overflow_fix,
quantize_outputs=quantize_outputs,
activations_quantization_params=activations_quantization_params,
weights_quantization_params=weights_quantization_params,
quantizer_propagation_rule=quantizer_propagation_rule,
)
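For orientation, a minimal end-to-end sketch of how this quantizer is meant to be driven (reviewer note, not part of the diff; assumes the standard torch.ao PT2E flow and a torchvision model):

import torch
import torchvision.models as models
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.openvino import OpenVINOQuantizer

model = models.mobilenet_v2(weights=None).eval()
example_args = (torch.randn(1, 3, 224, 224),)
captured = torch.export.export(model, example_args).module()

quantizer = OpenVINOQuantizer()  # transformer_model=True for transformer-like models
prepared = prepare_pt2e(captured, quantizer)  # runs transform_for_annotation() and annotate()
prepared(*example_args)  # one calibration pass over example data
quantized = convert_pt2e(prepared, fold_quantize=False)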

def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
def get_nncf_quantization_setup(
self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph
) -> q.quantizer_setup.SingleConfigQuantizerSetup:
self._min_max_algo._set_backend_entity(model)
return self._min_max_algo.find_quantization_setup(model, nncf_graph)

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
nncf_graph = GraphConverter.create_nncf_graph(model)
quantization_setup = self.get_quantization_setup(model, nncf_graph)
nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model)
quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph)

graph = model.graph
node_vs_torch_annotation = defaultdict(QuantizationAnnotation)
@@ -138,7 +121,9 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
)
raise nncf.InternalError(msg)

root_target_node = get_graph_node_by_name(graph, root_qp.insertion_point.target_node_name)
root_target_node = nncf_fx.node_utils.get_graph_node_by_name(
graph, root_qp.insertion_point.target_node_name
)
root_edge_or_node = self._get_edge_or_node(root_target_node, root_qp, nncf_graph)

for quantizer_id in quantizer_ids:
@@ -155,10 +140,11 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node, annotation in node_vs_torch_annotation.items():
assert QUANT_ANNOTATION_KEY not in node.meta
node.meta[QUANT_ANNOTATION_KEY] = annotation
return model
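Continuing the sketch above: the annotations written by annotate() land in node.meta under QUANT_ANNOTATION_KEY ("quantization_annotation"), which torch.ao's prepare step consumes. A rough way to inspect them directly (hypothetical; prepare_pt2e normally drives this):

annotated = OpenVINOQuantizer().annotate(captured)
for node in annotated.graph.nodes:
    annotation = node.meta.get("quantization_annotation")
    if annotation is not None:
        print(node.name, annotation.input_qspec_map, annotation.output_qspec)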

@staticmethod
def _get_unified_scales_root_quantizer_id(
nncf_graph: NNCFGraph, quantizer_ids: List[int], quantizer_setup: SingleConfigQuantizerSetup
nncf_graph: NNCFGraph, quantizer_ids: List[int], quantizer_setup: q.quantizer_setup.SingleConfigQuantizerSetup
) -> int:
"""
Identifies the earliest quantizer node ID based on the corresponding `nncf_node.node_id`
@@ -184,7 +170,7 @@ def _get_unified_scales_root_quantizer_id(
def _get_edge_or_node_and_annotation(
graph: torch.fx.Graph,
nncf_graph: NNCFGraph,
qp: QuantizationPointBase,
qp: q.quantizer_setup.QuantizationPointBase,
node_vs_torch_annotation: Dict[torch.fx.Node, QuantizationAnnotation],
) -> Tuple[EdgeOrNode, QuantizationAnnotation]:
"""
@@ -198,13 +184,15 @@ def _get_edge_or_node_and_annotation(
QuantizationAnnotations.
:return: A tuple containing the EdgeOrNode and its associated QuantizationAnnotation.
"""
target_node = get_graph_node_by_name(graph, qp.insertion_point.target_node_name)
target_node = nncf_fx.node_utils.get_graph_node_by_name(graph, qp.insertion_point.target_node_name)
annotation = node_vs_torch_annotation[target_node]
edge_or_node = OpenVINOQuantizer._get_edge_or_node(target_node, qp, nncf_graph)
return edge_or_node, annotation

@staticmethod
def _get_edge_or_node(target_node: torch.fx.Node, qp: QuantizationPointBase, nncf_graph: NNCFGraph) -> EdgeOrNode:
def _get_edge_or_node(
target_node: torch.fx.Node, qp: q.quantizer_setup.QuantizationPointBase, nncf_graph: NNCFGraph
) -> EdgeOrNode:
"""
Returns the edge or node based on the given target node and quantization point.

@@ -216,10 +204,10 @@ def _get_edge_or_node(target_node: torch.fx.Node, qp: QuantizationPointBase, nnc
ip = qp.insertion_point
if qp.is_weight_quantization_point():
nncf_node = nncf_graph.get_node_by_name(target_node.name)
weights_ports_ids = get_weight_tensor_port_ids(nncf_node, nncf_graph)
weights_ports_ids = nncf.torch.model_graph_manager.get_weight_tensor_port_ids(nncf_node, nncf_graph)
if len(weights_ports_ids) > 1:
# TODO(dlyakhov): support quantization for nodes with several weights
nncf_logger.warning(
nncf.common.logging.nncf_logger.warning(
f"Quantization of the weighted node {target_node.name}"
" is not yet supported by the OpenVINOQuantizer."
f" Only the weight on port ID {weights_ports_ids[0]} will be quantized."
@@ -253,7 +241,7 @@ def _fill_torch_ao_annotation(
annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec

@staticmethod
def _get_torch_ao_qspec_from_qp(qp: QuantizationPointBase) -> QuantizationSpec:
def _get_torch_ao_qspec_from_qp(qp: q.quantizer_setup.QuantizationPointBase) -> QuantizationSpec:
"""
Retrieves the quantization configuration from the given quantization point and
converts it into a QuantizationSpec.
@@ -269,15 +257,16 @@ def _get_torch_ao_qspec_from_qp(qp: QuantizationPointBase) -> QuantizationSpec:
if qconfig.per_channel:
torch_qscheme = (
torch.per_channel_symmetric
if qconfig.mode is QuantizationScheme.SYMMETRIC
if qconfig.mode is q.structs.QuantizationScheme.SYMMETRIC
else torch.per_channel_affine
)
else:
torch_qscheme = (
torch.per_tensor_symmetric if qconfig.mode is QuantizationScheme.SYMMETRIC else torch.per_tensor_affine
torch.per_tensor_symmetric
if qconfig.mode is q.structs.QuantizationScheme.SYMMETRIC
else torch.per_tensor_affine
)
if is_weight:
observer = PerChannelMinMaxObserver if qconfig.per_channel else MinMaxObserver
observer = PerChannelMinMaxObserver
quant_min = -128
quant_max = 127
Expand Down Expand Up @@ -307,5 +296,5 @@ def validate(self, model: torch.fx.GraphModule) -> None:
pass

def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
fold_constant_except_qdq(model)
nncf_fx.transformations.fold_constant_except_qdq(model)
return model
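To make the _get_torch_ao_qspec_from_qp() conversion above concrete: for an int8 symmetric per-channel weight quantization point it produces a spec along these lines (a hand-written sketch assuming channel axis 0, not output captured from this code):

import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec

weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    is_dynamic=False,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(),
)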
132 changes: 103 additions & 29 deletions examples/openvino/aot/aot_openvino_compiler.py
@@ -22,11 +22,15 @@
from torch.export.exported_program import ExportedProgram
import argparse
from executorch.backends.openvino import OpenVINOQuantizer
#from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
)

from sklearn.metrics import accuracy_score
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Function to load a model based on the selected suite
def load_model(suite: str, model_name: str):
@@ -42,20 +46,17 @@ def load_model(suite: str, model_name: str):
raise ValueError(f"Unsupported model suite: {suite}")


def load_calibration_dataset(dataset_path: str):
def load_calibration_dataset(dataset_path: str, suite: str, model: torch.nn.Module):
val_dir = f"{dataset_path}/val"

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
if suite == "torchvision":
transform = torchvision_models.get_model_weights(model.name).transforms()
else:
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

val_dataset = datasets.ImageFolder(
val_dir,
transforms.Compose(
[
transforms.Resize(64), # for tiny imagenet
transforms.ToTensor(),
normalize,
]
),
transform=transform
)

calibration_dataset = torch.utils.data.DataLoader(
@@ -65,21 +66,6 @@ def load_calibration_dataset(dataset_path: str):
return calibration_dataset


def quantize_model(model: torch.fx.GraphModule, example_args, subset_size=300):
#quantizer = OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(types=["__getitem__", "layer_norm"]))
quantizer = OpenVINOQuantizer()

print("PTQ: Annotate the model...")
annotated_model = prepare_pt2e(model, quantizer)

print("PTQ: Calibrate the model...")
annotated_model(*example_args)

print("PTQ: Convert the quantized model...")
quantized_model = convert_pt2e(annotated_model, fold_quantize=False)
return quantized_model


def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path: str, device: str):
# Ensure input_shape is a tuple
if isinstance(input_shape, list):
@@ -98,15 +84,24 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path:
aten_dialect: ExportedProgram = export(model, example_args)

if quantize:
if suite == "huggingface":
raise ValueError("Quantization of {suite} models did not support yet.")

# Quantize model
if not dataset_path:
raise ValueError("Quantization requires a calibration dataset.")
#calibration_dataset = load_calibration_dataset(dataset_path)
calibration_dataset = load_calibration_dataset(dataset_path, suite, model)

captured_model = aten_dialect.module()
#visualize_fx_model(captured_model, f"{model_name}_fp32.svg")
quantized_model = quantize_model(captured_model, example_args)
#visualize_fx_model(quantized_model, f"{model_name}_int8.svg")
quantizer = OpenVINOQuantizer()

print("PTQ: Quantize the model")
def transform(x):
return x[0]

quantized_model = quantize_pt2e(captured_model, quantizer, calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=transform), fold_quantize=False)
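# Note on the calibration plumbing above: each DataLoader item is an
# (image_batch, target) pair, and nncf.Dataset applies transform_func to every
# item before it reaches the model, so transform(x) == x[0] strips the labels.
# fold_quantize=False keeps explicit quantize/dequantize nodes in the graph
# for the edge-dialect conversion below.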

aten_dialect: ExportedProgram = export(quantized_model, example_args)

# Convert to edge dialect
@@ -121,16 +116,95 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path:
exec_prog = lowered_module.to_executorch(config=executorch.exir.ExecutorchBackendConfig())

# Serialize and save it to a file
model_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte"
model_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte"
with open(model_name, "wb") as file:
exec_prog.write_to_file(file)
print(f"Model exported and saved as {model_name} on {device}.")

if quantize:
print("Start validation of the quantized model:")

# 1: Dump inputs
import os
import shutil

dest_path = "tmp_inputs"
out_path = "tmp_outputs"
targets, input_files = [], []
for d in [dest_path, out_path]:
if os.path.exists(d):
shutil.rmtree(d)
os.makedirs(d)
input_list = ""
for idx, data in enumerate(calibration_dataset):
feature, target = data
targets.append(target)
file_name = f"{dest_path}/input_{idx}_0.raw"
input_list += file_name + " "
if not isinstance(feature, torch.Tensor):
feature = torch.tensor(feature)
feature.detach().numpy().tofile(file_name)
input_files.append(file_name)

inp_list_file = os.path.join(dest_path, "in_list.txt")
with open(inp_list_file, "w") as f:
input_list = input_list.strip() + "\n"
f.write(input_list)
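# Format note (assumption, not verified against the runner): each *.raw file
# is a headerless float32 blob written by numpy's tofile(), tensor shapes are
# recovered from the compiled model, and in_list.txt lists the input file
# names separated by whitespace.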

# 2: Run the executor
print("Run openvino_executor_runner...")
import subprocess
subprocess.run(["../../../cmake-openvino-out/examples/openvino/openvino_executor_runner",
f"--model_path={model_name}",
f"--input_list_path={inp_list_file}",
f"--output_folder_path={out_path}",
#f"--num_iter={len(input_files)}"
])

# 3: load the outputs and compare with the targets
import numpy as np
logits = []
for i in range(len(input_files)):
logits.append(
np.fromfile(
os.path.join(out_path, f"output_{i}.raw"), dtype=np.float32
)
)

# The runner dumps raw logits, so take the argmax before scoring;
# accuracy_score expects (y_true, y_pred), and batch size 1 is assumed here.
predictions = [int(np.argmax(l)) for l in logits]
acc_top1 = accuracy_score(targets, predictions)
print(f"acc@1: {acc_top1}")


from torch.fx.passes.graph_drawer import FxGraphDrawer
def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str):
g = FxGraphDrawer(model, output_svg_path)
g.get_dot_graph().write_svg(output_svg_path)

def generate_inputs(dest_path: str, file_name: str, inputs=None, input_list=None):
input_list_file = None
input_files = []

# Prepare input list
if input_list is not None:
input_list_file = f"{dest_path}/{file_name}"
with open(input_list_file, "w") as f:
f.write(input_list)
f.flush()

# Prepare input data
if inputs is not None:
for idx, data in enumerate(inputs):
for i, d in enumerate(data):
file_name = f"{dest_path}/input_{idx}_{i}.raw"
if not isinstance(d, torch.Tensor):
d = torch.tensor(d)
d.detach().numpy().tofile(file_name)
input_files.append(file_name)

return input_list_file, input_files

if __name__ == "__main__":
# Argument parser for dynamic inputs
parser = argparse.ArgumentParser(description="Export models with executorch.")