🐛 Describe the bug
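In `arm_quantizer.py`, the module-name filter (`module_name_filter`) assumes that module names start with `L['self'].` and strips that prefix before comparing. The names it receives do not contain that prefix, so stripping it removes the entire name (for example, `conv1` becomes an empty string), and the node is never matched by the filter. As a result, the per-module configs set with `quantizer.set_module_name()` are not applied to those nodes.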
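The mismatch can be illustrated with a small standalone sketch. It assumes the filter slices off `len("L['self'].")` characters unconditionally and compares the result against the paths in a node's `nn_module_stack` metadata, modeled here as a plain dict of `(path, type)` tuples. `buggy_strip`, `fixed_strip`, and `make_module_name_filter` are hypothetical names for illustration, not the code in `arm_quantizer.py`.

```python
# Hypothetical model of the filter logic described above; not the actual
# arm_quantizer.py implementation.
PREFIX = "L['self']."


def buggy_strip(path: str) -> str:
    # Assumed current behavior: drop len(PREFIX) characters whether or not
    # the prefix is present, so short names like "conv1" become "".
    return path[len(PREFIX):]


def fixed_strip(path: str) -> str:
    # Only remove the prefix when the path actually starts with it.
    return path.removeprefix(PREFIX)


def make_module_name_filter(module_name, strip=fixed_strip):
    # nn_module_stack is modeled as {key: (module_path, module_type)}.
    def module_name_filter(nn_module_stack: dict) -> bool:
        names = [strip(path) for path, _ in nn_module_stack.values()]
        return module_name in names

    return module_name_filter


# Example stack for a conv2d node inside self.conv1, recorded without the prefix.
stack = {"conv1": ("conv1", "torch.nn.modules.conv.Conv2d")}
print(make_module_name_filter("conv1", buggy_strip)(stack))  # False
print(make_module_name_filter("conv1", fixed_strip)(stack))  # True
```

When the prefix is present (`"L['self'].conv1"`), both variants return `conv1`, so a prefix-aware strip such as `str.removeprefix` stays compatible with graphs that do carry the prefix.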
```python
from dataclasses import replace

import torch
import torch.nn as nn

# Import paths assume executorch 1.0 / torchao 0.14 (see Versions below).
from executorch.backends.arm.ethosu import EthosUCompileSpec
from executorch.backends.arm.quantizer import (
    EthosUQuantizer,
    get_symmetric_quantization_config,
)
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e


class SimpleConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        return x


def create_int8_int8_config():
    """INT8 activations, INT8 weights, INT32 bias"""
    return get_symmetric_quantization_config(is_per_channel=True)


def create_int16_int8_config():
    """INT16 activations, INT8 weights, INT32 bias"""
    base_config = get_symmetric_quantization_config(is_per_channel=True)
    # Replace activation specs with INT16
    new_fields = {}
    if hasattr(base_config, 'input_activation') and base_config.input_activation:
        new_fields['input_activation'] = replace(
            base_config.input_activation,
            dtype=torch.int16,
            quant_min=-32768,
            quant_max=32767,
        )
    if hasattr(base_config, 'output_activation') and base_config.output_activation:
        new_fields['output_activation'] = replace(
            base_config.output_activation,
            dtype=torch.int16,
            quant_min=-32768,
            quant_max=32767,
        )
    return replace(base_config, **new_fields)


# Create model
model = SimpleConvModel()
model.eval()

# Create dummy input
batch_size = 4
example_input = torch.randn(batch_size, 3, 32, 32)

# Export model
print("Exporting model...")
exported_program = torch.export.export(model, (example_input,))
graph_module = exported_program.module(check_guards=False)

# Setup compile spec
compile_spec = EthosUCompileSpec(
    target="ethos-u55-128",
    system_config="Ethos_U55_High_End_Embedded",
    memory_mode="Shared_Sram",
    extra_flags=["--output-format=raw", "--debug-force-regor"],
)

# Create quantizer
quantizer = EthosUQuantizer(compile_spec)
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

# Configure conv1: INT8 activations, INT8 weights
int8_config = create_int8_int8_config()
quantizer.set_module_name('conv1', int8_config)
print("\nConv1 config (INT8/INT8):")
print(f"  Input activation dtype: {getattr(int8_config.input_activation, 'dtype', None)}")
print(f"  Output activation dtype: {getattr(int8_config.output_activation, 'dtype', None)}")
print(f"  Weight dtype: {getattr(int8_config.weight, 'dtype', None)}")

# Configure conv2: INT16 activations, INT8 weights
int16_config = create_int16_int8_config()
quantizer.set_module_name('conv2', int16_config)
# quantizer.set_module_name('aten.conv2d.default', int16_config)
print("\nConv2 config (INT16/INT8):")
print(f"  Input activation dtype: {getattr(int16_config.input_activation, 'dtype', None)}")
print(f"  Output activation dtype: {getattr(int16_config.output_activation, 'dtype', None)}")
print(f"  Weight dtype: {getattr(int16_config.weight, 'dtype', None)}")

# Prepare model for quantization
print("\nPreparing model for quantization...")
prepared = prepare_pt2e(graph_module, quantizer)
```
Versions
PyTorch version: 2.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.10 (v3.12.10:0cc81280367, Apr 8 2025, 08:46:59) [Clang 13.0.0 (clang-1300.0.29.30)] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M4 Max
Versions of relevant libraries:
[pip3] executorch==1.0.0
[pip3] numpy==2.3.4
[pip3] torch==2.9.0
[pip3] torchao==0.14.0
[pip3] torchaudio==2.9.0
[pip3] torchcodec==0.8.1
[pip3] torchvision==0.24.0
[conda] Could not collect
cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai