1- # Copyright 2024-2025 Arm Limited and/or its affiliates.
1+ # Copyright 2024-2026 Arm Limited and/or its affiliates.
22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
77from typing import List, Tuple, Union
88
99import torch
10+ from executorch.backends.arm.quantizer.arm_quantizer import (
11+     get_symmetric_a8w4_quantization_config,
12+ )
1013from executorch.backends.arm.test import common
1114from executorch.backends.arm.test.tester.test_pipeline import (
1215    EthosU55PipelineINT,
1720    VgfPipeline,
1821)
1922
23+
2024aten_op = "torch.ops.aten.conv2d.default"
2125exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default"
2226
@@ -162,8 +166,8 @@ def forward(self, x):
162166    batches=1,
163167)
164168
165- conv2d_2x2_1x1x14x13_st2 = Conv2d(
166-     in_channels=1,
169+ conv2d_2x2_2x1x14x13_st2 = Conv2d(
170+     in_channels=2,
167171    out_channels=1,
168172    kernel_size=(2, 2),
169173    stride=2,
@@ -363,7 +367,7 @@ def forward(self, x):
363367    "3x3_1x3x24x24_st1": lambda: conv2d_3x3_1x3x24x24_st1,
364368    "3x3_1x3x12x12_st2_pd1": lambda: conv2d_3x3_1x3x12x12_st2_pd1,
365369    "1x1_1x2x16x16_st1": lambda: conv2d_1x1_1x2x16x16_st1,
366-     "2x2_1x1x14x13_st2_needs_adjust_pass ": lambda: conv2d_2x2_1x1x14x13_st2,
370+     "2x2_2x1x14x13_st2_needs_adjust_pass ": lambda: conv2d_2x2_2x1x14x13_st2,
367371    "5x5_1x3x14x15_st3_pd1_needs_adjust_pass": lambda: conv2d_5x5_1x3x14x15_st3_pd1,
368372    "7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": lambda: conv2d_7x7_1x3x16x16_st2_pd1_dl2,
369373    "7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass": lambda: conv2d_7x7_1x3x15x15_st1_pd0_dl1,
@@ -391,6 +395,15 @@ def forward(self, x):
391395input_t = Tuple[torch.Tensor]
392396
393397
398+ def _get_dtype_count(model: torch.nn.Module):
399+     nbr_convs: int = model.nbr_convs  # noqa
400+     return {
401+         "CONST": {"INT4": nbr_convs * 2},  # One for the weight, one for the zp.
402+         "CONV2D": {"INT32": nbr_convs},
403+         "RESCALE": {"INT8": nbr_convs},
404+     }
405+
406+
394407@common.parametrize("test_data", test_data_FP)
395408def test_convolution_2d_tosa_FP(test_data):
396409    model = test_data()
@@ -417,6 +430,36 @@ def test_convolution_2d_tosa_INT(test_data):
417430    pipeline.run()
418431
419432
433+ @common.parametrize(
434+     "test_data",
435+     test_data_INT,
436+     xfails={
437+         "groups,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726",
438+         "groups,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726",
439+         "groups_bias,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726",
440+         "groups_bias,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726",
441+     },
442+ )
443+ def test_convolution_2d_tosa_INT_a8w4(test_data):
444+     model, per_channel_quantization = test_data()
445+     pipeline = TosaPipelineINT[input_t](
446+         model,
447+         model.get_inputs(),
448+         aten_op,
449+         exir_op,
450+         tosa_extensions=["int4"],
451+     )
452+     pipeline.quantizer.set_global(
453+         get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization)
454+     )
455+     pipeline.add_stage_after(
456+         "to_edge_transform_and_lower",
457+         pipeline.tester.check_dtype_count,
458+         _get_dtype_count(model),
459+     )
460+     pipeline.run()
461+
462+
420463@common.parametrize("test_data", test_data_INT)
421464@common.XfailIfNoCorstone300
422465def test_convolution_2d_u55_INT(test_data):
@@ -431,6 +474,21 @@ def test_convolution_2d_u55_INT(test_data):
431474    pipeline.run()
432475
433476
477+ @common.parametrize("test_data", test_data_INT)
478+ def test_convolution_2d_u55_INT_a8w4(test_data):
479+     model, per_channel_quantization = test_data()
480+     pipeline = EthosU55PipelineINT[input_t](
481+         model,
482+         model.get_inputs(),
483+         aten_op,
484+         exir_op,
485+     )
486+     pipeline.quantizer.set_global(
487+         get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization)
488+     )
489+     pipeline.run()
490+
491+
434492@common.parametrize("test_data", test_data_INT)
435493@common.XfailIfNoCorstone320
436494def test_convolution_u85_INT(test_data):
@@ -445,6 +503,21 @@ def test_convolution_u85_INT(test_data):
445503    pipeline.run()
446504
447505
506+ @common.parametrize("test_data", test_data_INT)
507+ def test_convolution_2d_u85_INT_a8w4(test_data):
508+     model, per_channel_quantization = test_data()
509+     pipeline = EthosU85PipelineINT[input_t](
510+         model,
511+         model.get_inputs(),
512+         aten_op,
513+         exir_op,
514+     )
515+     pipeline.quantizer.set_global(
516+         get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization)
517+     )
518+     pipeline.run()
519+
520+
448521@common.parametrize("test_data", test_data_FP)
449522@common.SkipIfNoModelConverter
450523def test_convolution_2d_vgf_no_quant(test_data):