
perf: CPU 30-44× slower than PyTorch despite AVX-512 - Investigation & Fix Plan #38

@isPANN

Description


CPU Performance Investigation: 30-44× Slowdown

Problem Statement

Despite AVX-512 SIMD being enabled, the CPU implementation is 30-44× slower than PyTorch:

Backend: SIMD: Avx512, MR=8×NR=8 ✓ (SIMD enabled!)

Size      tropical-gemm    PyTorch CPU    Slowdown
64×64     5.647ms          0.128ms        44.3× ❌
128×128   40.393ms         3.780ms        10.7× ❌
256×256   312.253ms        10.159ms       30.7× ❌
512×512   2130.398ms       65.827ms       32.4× ❌

Root Causes Identified

🚨 #1: SCALAR LOOPS IN SIMD CODE (SMOKING GUN!)

File: crates/tropical-gemm/src/simd/kernels/avx2.rs:41-63

// Current implementation - TERRIBLE!
for p in 0..k {
    let mut a_vals = [0.0f32; 8];
    for i in 0..mr {  // ❌ SCALAR LOOP (8 iterations)
        a_vals[i] = *a.add(p * Self::MR + i);
    }

    let mut b_vals = [0.0f32; 8];
    for j in 0..nr {  // ❌ SCALAR LOOP (8 iterations)
        b_vals[j] = *b.add(p * Self::NR + j);
    }
    let b_vec = f32x8::from(b_vals);

    for i in 0..mr {  // ❌ SCALAR LOOP (8 iterations)
        let a_broadcast = f32x8::splat(a_vals[i]);
        let product = a_broadcast + b_vec;  // tropical "multiply" (elementwise add)
        acc[i] = acc[i].max(product);       // tropical "add" (elementwise max)
    }
}

Impact: For a 64×64 matrix, each microkernel call performs 1,536 scalar operations!

  • K=64 iterations × (8+8+8) scalar operations = 1,536 scalar ops
  • PyTorch does equivalent with 1 vectorized broadcast

Expected Loss: 3-5× slower


🚨 #2: EXCESSIVE PACKING OVERHEAD

File: crates/tropical-gemm/src/core/gemm.rs:78-79

// Allocated on EVERY call!
let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(...)];
let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(...)];

Impact: For 64×64 matrix (16 KB data):

  • packed_a: 512 KB
  • packed_b: 512 KB
  • 1 MB heap allocation + zero-init for 16 KB of actual data!

Expected Loss: 2-3× slower for small matrices


🚨 #3: NO REAL AVX-512 KERNELS

File: crates/tropical-gemm/src/simd/dispatch.rs:60

SimdLevel::Avx2 | SimdLevel::Avx512 => {
    let kernel = Avx2MaxPlusF32Kernel;  // ❌ Using AVX2 for AVX-512!

Impact: AVX-512 detected but using 256-bit AVX2 code

  • Wasting 50% of SIMD width
  • Backend reports "AVX-512" but actually runs AVX2

Expected Loss: 1.5-2× slower


🚨 #4: WRONG TILING FOR SMALL MATRICES

File: crates/tropical-gemm/src/core/tiling.rs

pub const F32_AVX2: Self = Self {
    mc: 256,  // For 64×64 matrix, this causes 75% padding waste!
    nc: 256,
    kc: 512,
    mr: 8,
    nr: 8,
};

Impact:

  • Parameters optimized for large matrices (>1024×1024)
  • For 64×64: mc=256 > m=64 → wasted padding and cache thrashing
  • Packed buffers (1 MB) exceed L2 cache (~256 KB)

Expected Loss: 1.5-2× slower


🚨 #5: BLIS BLOCKING OVERHEAD

Issue: 5-level loop nesting vs PyTorch's simple broadcasting

PyTorch approach:

a_expanded = a.unsqueeze(2)  # (m, k, 1)
b_expanded = b.unsqueeze(0)  # (1, k, n)
result = torch.max(a_expanded + b_expanded, dim=1)[0]
  • 4 vectorized operations
  • Leverages Intel MKL

tropical-gemm approach:

  • 5-level BLIS blocking (jc, pc, ic, jr, ir loops)
  • Memory traffic: 3× (original + 2 packed copies)

Expected Loss: 1.5-2× slower for small matrices


Cumulative Impact

Issue                  Loss Factor   Cumulative
Scalar loops in SIMD   3-5×          3-5×
Packing overhead       2-3×          6-15×
No AVX-512             1.5-2×        9-30×
Wrong tiling           1.5-2×        13-60×
BLIS overhead          1.5-2×        20-120×

Observed: 30-44× slower ✓ (consistent with analysis!)


Fix Plan

Priority 1: Fix Microkernel (3-5× gain) 🔥

File: crates/tropical-gemm/src/simd/kernels/avx2.rs

Replace scalar loops with vectorized loads:

// Use SIMD intrinsics directly
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

unsafe fn microkernel_maxplus_f32_avx2(...) {
    for p in 0..k {
        // Load 8 packed B elements with one vector load (no scalar loop!)
        let b_vec = _mm256_loadu_ps(b.add(p * 8));

        for i in 0..8 {
            // Broadcast one A element straight from memory (compiles to vbroadcastss).
            // Note: _mm256_extract_ps does not exist, and lane extracts need a
            // const index anyway, so broadcasting from memory is the right move.
            let a_broadcast = _mm256_set1_ps(*a.add(p * 8 + i));
            let product = _mm256_add_ps(a_broadcast, b_vec);  // tropical "multiply"
            acc[i] = _mm256_max_ps(acc[i], product);          // tropical "add"
        }
    }
}

Alternative: Use the `wide` crate properly (if it supports direct vector loads)

Expected: 3-5× speedup
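
To sanity-check the vectorized formulation outside the crate, here is a self-contained sketch (function names are hypothetical) that computes one broadcast-and-accumulate step of the max-plus update with AVX2 intrinsics and verifies it against scalar code:

```rust
// One max-plus update step: acc[j] = max(acc[j], a + b[j]) for 8 lanes.
fn tropical_step_scalar(a: f32, b: &[f32; 8], acc: &mut [f32; 8]) {
    for j in 0..8 {
        acc[j] = acc[j].max(a + b[j]);
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn tropical_step_avx2(a: f32, b: &[f32; 8], acc: &mut [f32; 8]) {
    use std::arch::x86_64::*;
    let a_broadcast = _mm256_set1_ps(a);            // broadcast, no per-lane extract
    let b_vec = _mm256_loadu_ps(b.as_ptr());        // one vector load
    let acc_vec = _mm256_loadu_ps(acc.as_ptr());
    let product = _mm256_add_ps(a_broadcast, b_vec); // tropical "multiply"
    let result = _mm256_max_ps(acc_vec, product);    // tropical "add"
    _mm256_storeu_ps(acc.as_mut_ptr(), result);
}

fn main() {
    let b = [1.0, -2.0, 3.0, 0.5, -1.5, 4.0, 2.5, -0.5f32];
    let mut acc_scalar = [f32::NEG_INFINITY; 8];
    tropical_step_scalar(2.0, &b, &mut acc_scalar);

    // Only run the SIMD path where the CPU actually supports AVX2
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx2") {
        let mut acc_simd = [f32::NEG_INFINITY; 8];
        unsafe { tropical_step_avx2(2.0, &b, &mut acc_simd) };
        assert_eq!(acc_scalar, acc_simd);
    }
    println!("{acc_scalar:?}");
}
```

The runtime feature check keeps the sketch portable; the real kernel would be selected once at dispatch time instead.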


Priority 2: Size-Adaptive Tiling (2× gain)

File: crates/tropical-gemm/src/core/tiling.rs

Add adaptive parameter selection:

impl TilingParams {
    pub fn for_size(m: usize, n: usize, k: usize) -> Self {
        if m <= 128 && n <= 128 {
            // Small matrices: reduce overhead
            Self {
                mc: 64,
                nc: 64,
                kc: 64,
                mr: 8,
                nr: 8,
            }
        } else if m <= 512 && n <= 512 {
            // Medium matrices
            Self {
                mc: 128,
                nc: 128,
                kc: 256,
                mr: 8,
                nr: 8,
            }
        } else {
            // Large matrices: use current parameters
            Self::F32_AVX2
        }
    }
}

Update dispatch in src/simd/dispatch.rs to use TilingParams::for_size(m, n, k)

Expected: 1.5-2× speedup for small matrices
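
As a quick check of the selection boundaries, the same logic can be exercised standalone (struct fields copied from the proposed for_size above, F32_AVX2 values from tiling.rs):

```rust
// Standalone copy of the adaptive selection logic, to test the size boundaries.
#[derive(Debug, PartialEq)]
struct TilingParams { mc: usize, nc: usize, kc: usize, mr: usize, nr: usize }

impl TilingParams {
    const F32_AVX2: Self = Self { mc: 256, nc: 256, kc: 512, mr: 8, nr: 8 };

    fn for_size(m: usize, n: usize, _k: usize) -> Self {
        if m <= 128 && n <= 128 {
            Self { mc: 64, nc: 64, kc: 64, mr: 8, nr: 8 }     // small: reduce overhead
        } else if m <= 512 && n <= 512 {
            Self { mc: 128, nc: 128, kc: 256, mr: 8, nr: 8 }  // medium
        } else {
            Self::F32_AVX2                                    // large: current params
        }
    }
}

fn main() {
    // 64×64 now packs into 64×64×4 B = 16 KB buffers instead of 512 KB each
    assert_eq!(TilingParams::for_size(64, 64, 64).mc, 64);
    assert_eq!(TilingParams::for_size(256, 256, 256).kc, 256);
    assert_eq!(TilingParams::for_size(1024, 1024, 1024).mc, 256);
    println!("boundaries ok");
}
```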


Priority 3: Workspace API (2-3× gain)

File: crates/tropical-gemm/src/core/gemm.rs

Implement reusable workspace (TODO #34 exists):

pub struct GemmWorkspace<T> {
    packed_a: Vec<T>,
    packed_b: Vec<T>,
}

impl<T: Default + Clone> GemmWorkspace<T> {
    pub fn new(mc: usize, kc: usize, nc: usize, mr: usize, nr: usize) -> Self {
        Self {
            packed_a: vec![T::default(); packed_a_size(mc, kc, mr)],
            packed_b: vec![T::default(); packed_b_size(kc, nc, nr)],
        }
    }

    pub fn resize_if_needed(&mut self, mc: usize, kc: usize, nc: usize, mr: usize, nr: usize) {
        // Only reallocate (grow) when the required size increased
        let a_needed = packed_a_size(mc, kc, mr);
        if a_needed > self.packed_a.len() {
            self.packed_a.resize(a_needed, T::default());
        }
        let b_needed = packed_b_size(kc, nc, nr);
        if b_needed > self.packed_b.len() {
            self.packed_b.resize(b_needed, T::default());
        }
    }
}

pub fn tropical_gemm_with_workspace<T, K>(
    workspace: &mut GemmWorkspace<T::Scalar>,
    m: usize, n: usize, k: usize,
    // ... other params
) {
    workspace.resize_if_needed(...);
    // Use workspace.packed_a and workspace.packed_b instead of allocating
}

Expected: 2× speedup for repeated calls
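
Usage would look roughly like this. A simplified standalone model of the workspace (packed sizes reduced to plain `rows × cols` for illustration) shows the allocate-once, grow-only behavior:

```rust
// Simplified standalone model of the workspace: allocate once, reuse across
// calls, and only grow the buffers when a larger problem arrives.
struct GemmWorkspace {
    packed_a: Vec<f32>,
    packed_b: Vec<f32>,
}

impl GemmWorkspace {
    fn new(mc: usize, kc: usize, nc: usize) -> Self {
        Self {
            packed_a: vec![0.0; mc * kc],
            packed_b: vec![0.0; kc * nc],
        }
    }

    fn resize_if_needed(&mut self, mc: usize, kc: usize, nc: usize) {
        // Grow only; shrinking would defeat the purpose of reuse
        if mc * kc > self.packed_a.len() {
            self.packed_a.resize(mc * kc, 0.0);
        }
        if kc * nc > self.packed_b.len() {
            self.packed_b.resize(kc * nc, 0.0);
        }
    }
}

fn main() {
    let mut ws = GemmWorkspace::new(64, 64, 64);
    let cap_before = ws.packed_a.capacity();
    ws.resize_if_needed(64, 64, 64);     // same size: no reallocation
    assert_eq!(ws.packed_a.capacity(), cap_before);
    ws.resize_if_needed(128, 128, 128);  // larger: buffers grow once
    assert_eq!(ws.packed_a.len(), 128 * 128);
    println!("workspace reuse ok");
}
```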


Priority 4: Implement AVX-512 Kernels (1.5-2× gain)

File: crates/tropical-gemm/src/simd/kernels/avx512.rs (new)

Create true AVX-512 kernel:

pub struct Avx512MaxPlusF32Kernel;

impl MicroKernel<TropicalMaxPlus<f32>> for Avx512MaxPlusF32Kernel {
    const MR: usize = 16;  // 512-bit / 32-bit = 16 elements
    const NR: usize = 16;

    unsafe fn execute(...) {
        // Use _mm512_* intrinsics
        // 16-wide SIMD operations
    }
}

Update dispatch in src/simd/dispatch.rs:

SimdLevel::Avx512 => {
    let kernel = Avx512MaxPlusF32Kernel;  // ✓ Use real AVX-512
    let params = TilingParams::F32_AVX512;
    // ...
}

Expected: 1.5-2× speedup on AVX-512 CPUs


Priority 5: Naive Algorithm for Small Matrices (1.5-2× gain)

File: crates/tropical-gemm/src/core/gemm.rs

Add simple triple-loop implementation:

fn tropical_gemm_naive<T: TropicalSemiring>(
    m: usize, n: usize, k: usize,
    a: &[T::Scalar], lda: usize,
    b: &[T::Scalar], ldb: usize,
    c: &mut [T], ldc: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = T::neg_infinity();  // tropical additive identity (−∞ for max-plus)
            for p in 0..k {
                let a_val = T::from_scalar(a[i * lda + p]);
                let b_val = T::from_scalar(b[p * ldb + j]);
                acc = acc.tropical_add(a_val.tropical_mul(b_val));
            }
            c[i * ldc + j] = acc;
        }
    }
}

// In tropical_gemm_inner:
pub unsafe fn tropical_gemm_inner<T, K>(...) {
    if m <= 64 && n <= 64 {
        return tropical_gemm_naive(m, n, k, a, lda, b, ldb, c, ldc);
    }
    // ... existing BLIS code
}

Expected: 1.5-2× speedup for very small matrices (<64×64)
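
A concrete f32 max-plus version of the same triple loop (standalone, without the crate's `TropicalSemiring` trait) makes the semantics easy to verify:

```rust
// Max-plus matrix product: C[i][j] = max over p of (A[i][p] + B[p][j]).
// Row-major slices; leading dimensions are the row lengths.
fn maxplus_gemm_naive(
    m: usize, n: usize, k: usize,
    a: &[f32], lda: usize,
    b: &[f32], ldb: usize,
    c: &mut [f32], ldc: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = f32::NEG_INFINITY; // additive identity of max-plus
            for p in 0..k {
                acc = acc.max(a[i * lda + p] + b[p * ldb + j]);
            }
            c[i * ldc + j] = acc;
        }
    }
}

fn main() {
    // 2×2 example: C[0][0] = max(1+5, 2+7) = 9
    let a = [1.0, 2.0, 3.0, 4.0];
    let b = [5.0, 6.0, 7.0, 8.0];
    let mut c = [0.0f32; 4];
    maxplus_gemm_naive(2, 2, 2, &a, 2, &b, 2, &mut c, 2);
    assert_eq!(c, [9.0, 10.0, 11.0, 12.0]);
    println!("{c:?}");
}
```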


Implementation Roadmap

Phase 1: Quick Wins (1-2 days)

  1. Fix microkernel scalar loops (Priority 1)
  2. Add size-adaptive tiling (Priority 2)
  3. Expected gain: 5-10× improvement

Phase 2: Memory Optimization (2-3 days)

  1. Implement workspace API (Priority 3)
  2. Add naive algorithm for small matrices (Priority 5)
  3. Expected gain: Additional 2-3× improvement

Phase 3: AVX-512 (1-2 weeks)

  1. Implement true AVX-512 kernels (Priority 4)
  2. Optimize for cache layout
  3. Expected gain: Additional 1.5-2× improvement

Total Expected Improvement

  • Phase 1: 5-10× (brings to 3-9× slower than PyTorch)
  • Phase 1+2: 10-30× (brings to 1-3× slower than PyTorch)
  • Phase 1+2+3: 15-60× (potentially faster than PyTorch!)

Profiling Commands

Flamegraph

cargo install cargo-flamegraph
cargo flamegraph --bench gemm_bench

Perf Stats

perf stat -e cycles,instructions,cache-references,cache-misses,L1-dcache-loads \
    cargo bench --bench gemm_bench

Assembly Analysis

cargo install cargo-asm
cargo asm tropical_gemm::simd::kernels::avx2::Avx2MaxPlusF32Kernel::execute --rust

Micro-benchmark (create new file)

// benches/microkernel_bench.rs
// Isolate microkernel performance
// Test with different K values
// Compare packed vs unpacked
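
A dependency-free starting point for that file (hypothetical; uses `std::time::Instant` instead of criterion so it runs without extra dev-dependencies):

```rust
// Micro-benchmark sketch using only the standard library.
// Times the naive max-plus inner loop for a few K values.
use std::time::Instant;

// One max-plus "dot product": max over i of (a[i] + b[i]).
fn maxplus_dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(f32::NEG_INFINITY, |acc, (&x, &y)| acc.max(x + y))
}

fn main() {
    for k in [64usize, 256, 1024] {
        let a: Vec<f32> = (0..k).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..k).map(|i| -(i as f32)).collect();

        let start = Instant::now();
        let mut sink = f32::NEG_INFINITY;
        for _ in 0..10_000 {
            sink = sink.max(maxplus_dot(&a, &b));
        }
        let elapsed = start.elapsed();
        // Print the result too, so the loop is not optimized away
        println!("K={k:5}  {elapsed:?}  (result {sink})");
    }
}
```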

Critical Files to Modify

  1. crates/tropical-gemm/src/simd/kernels/avx2.rs - Fix scalar loops
  2. crates/tropical-gemm/src/core/tiling.rs - Add adaptive parameters
  3. crates/tropical-gemm/src/core/gemm.rs - Workspace API + naive algorithm
  4. crates/tropical-gemm/src/simd/dispatch.rs - Use adaptive tiling
  5. crates/tropical-gemm/src/simd/kernels/avx512.rs - New AVX-512 kernels

Conclusion

The 30-44× slowdown is caused by:

  1. ❌ Scalar loops in SIMD code (worst offender)
  2. ❌ Excessive memory allocation
  3. ❌ Missing AVX-512 implementation
  4. ❌ Wrong parameters for small matrices
  5. ❌ Algorithm overhead vs PyTorch's optimized path

All issues are fixable with targeted optimizations.

Recommendation: Start with Phase 1 (microkernel + adaptive tiling) for immediate 5-10× improvement.
