From 4c2330c73a9fdc52e14021bd9032828e11c79a24 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 26 May 2026 15:35:07 -0400 Subject: [PATCH] fix gemv test on avx512bf16 cpu --- pyproject.toml | 1 + tests/test_functional.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cf9e00708..1b004efb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ markers = [ "deprecated: mark test as covering a deprecated feature", "slow: mark test as slow", ] +testpaths = ["tests"] [tool.ruff] src = [ diff --git a/tests/test_functional.py b/tests/test_functional.py index 62585de0c..95d8727f7 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -807,14 +807,20 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind): compress_statistics=double_quant, quant_storage=quant_storage, ) + + # dequant+F.linear reference path. + C1 = torch.nn.functional.linear(A, F.dequantize_4bit(qB, state).to(dtype)) + + # original matmul reference path. C3 = torch.matmul(A, B.t()) + # CPU requires convert weight packed for gemv if device == "cpu" and F.has_avx512bf16(): qB, state = F._convert_weight_packed_for_cpu(qB, state) qB = qB.t() + + # GEMV test C2 = F.gemv_4bit(A, qB.t(), state=state) - # dequant+F.linear reference path - C1 = torch.nn.functional.linear(A, F.dequantize_4bit(qB, state).to(dtype)) err1 = (C1 - C2).abs().float() err2 = (C3 - C2).abs().float()