# CPU Performance Investigation: 30-44× Slowdown
## Problem Statement

Despite AVX-512 SIMD being enabled, the CPU implementation is 30-44× slower than PyTorch:

```text
Backend: SIMD: Avx512, MR=8×NR=8 ✓ (SIMD enabled!)
```

| Size | tropical-gemm | PyTorch CPU | Slowdown |
|---|---|---|---|
| 64×64 | 5.647 ms | 0.128 ms | 44.3× ❌ |
| 128×128 | 40.393 ms | 3.780 ms | 10.7× ❌ |
| 256×256 | 312.253 ms | 10.159 ms | 30.7× ❌ |
| 512×512 | 2130.398 ms | 65.827 ms | 32.4× ❌ |
## Root Causes Identified

### 🚨 #1: SCALAR LOOPS IN SIMD CODE (SMOKING GUN!)
File: `crates/tropical-gemm/src/simd/kernels/avx2.rs:41-63`
```rust
// Current implementation - TERRIBLE!
for p in 0..k {
    let mut a_vals = [0.0f32; 8];
    for i in 0..mr { // ❌ SCALAR LOOP (8 iterations)
        a_vals[i] = *a.add(p * Self::MR + i);
    }
    let mut b_vals = [0.0f32; 8];
    for j in 0..nr { // ❌ SCALAR LOOP (8 iterations)
        b_vals[j] = *b.add(p * Self::NR + j);
    }
    let b_vec = f32x8::from(b_vals);
    for i in 0..mr { // ❌ SCALAR LOOP (8 iterations)
        let a_broadcast = f32x8::splat(a_vals[i]);
        let product = a_broadcast + b_vec;
        acc[i] = acc[i].max(product);
    }
}
```

Impact: For a 64×64 matrix, the inner loops run 1,536 scalar iterations per microkernel call:
- K=64 iterations × (8+8+8) scalar operations = 1,536 scalar ops
- PyTorch does the equivalent work with a single vectorized broadcast

Expected Loss: 3-5× slower
### 🚨 #2: EXCESSIVE PACKING OVERHEAD
File: `crates/tropical-gemm/src/core/gemm.rs:78-79`

```rust
// Allocated on EVERY call!
let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(...)];
let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(...)];
```

Impact: For a 64×64 matrix (16 KB of data per operand):
- packed_a: 512 KB
- packed_b: 512 KB
- 1 MB of heap allocation plus zero-initialization for 16 KB of actual data!

Expected Loss: 2-3× slower for small matrices
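The arithmetic behind those buffer sizes, assuming `packed_a_size`/`packed_b_size` come out to roughly mc×kc and kc×nc elements under the `F32_AVX2` tiling shown in issue #4 (padding ignored):

```rust
// Back-of-envelope sizes for the pack buffers under F32_AVX2 tiling.
const MC: usize = 256;
const KC: usize = 512;
const NC: usize = 256;
const F32_BYTES: usize = 4;

fn main() {
    let packed_a = MC * KC * F32_BYTES; // 524,288 B = 512 KB
    let packed_b = KC * NC * F32_BYTES; // 524,288 B = 512 KB
    let operand_64 = 64 * 64 * F32_BYTES; // 16,384 B = 16 KB per 64×64 matrix
    println!(
        "allocated {} KB for {} KB of input per operand",
        (packed_a + packed_b) / 1024,
        operand_64 / 1024
    );
}
```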
### 🚨 #3: NO REAL AVX-512 KERNELS
File: `crates/tropical-gemm/src/simd/dispatch.rs:60`

```rust
SimdLevel::Avx2 | SimdLevel::Avx512 => {
    let kernel = Avx2MaxPlusF32Kernel; // ❌ Using AVX2 for AVX-512!
```

Impact: AVX-512 is detected, but the 256-bit AVX2 code runs instead:
- Wastes 50% of the available SIMD width
- The backend reports "AVX-512" but actually runs AVX2

Expected Loss: 1.5-2× slower
### 🚨 #4: WRONG TILING FOR SMALL MATRICES
File: `crates/tropical-gemm/src/core/tiling.rs`

```rust
pub const F32_AVX2: Self = Self {
    mc: 256, // For a 64×64 matrix, this causes 75% padding waste!
    nc: 256,
    kc: 512,
    mr: 8,
    nr: 8,
};
```

Impact:
- Parameters are tuned for large matrices (>1024×1024)
- For 64×64: mc=256 > m=64 → wasted padding and cache thrashing
- Packed buffers (1 MB) exceed the L2 cache (~256 KB)

Expected Loss: 1.5-2× slower
### 🚨 #5: BLIS BLOCKING OVERHEAD
Issue: 5-level loop nesting vs PyTorch's simple broadcasting

PyTorch approach:

```python
a_expanded = a.unsqueeze(2)  # (m, k, 1)
b_expanded = b.unsqueeze(0)  # (1, k, n)
result = torch.max(a_expanded + b_expanded, dim=1)[0]
```

- 4 vectorized operations
- Leverages Intel MKL

tropical-gemm approach (see the sketch after this list):
- 5-level BLIS blocking (jc, pc, ic, jr, ir loops)
- Memory traffic: 3× (original + 2 packed copies)

Expected Loss: 1.5-2× slower for small matrices
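For reference, a self-contained sketch of that five-loop structure for max-plus, with illustrative block sizes and no packing or SIMD (the real code additionally packs a B panel at loop 2 and an A block at loop 3, which is where the 3× memory traffic comes from):

```rust
// Schematic 5-level BLIS loop nest (jc, pc, ic, jr, ir) for max-plus GEMM.
// Block sizes are illustrative; `c` must be pre-filled with f32::NEG_INFINITY.
const MC: usize = 64;
const NC: usize = 64;
const KC: usize = 64;
const MR: usize = 8;
const NR: usize = 8;

fn maxplus_gemm_blis(m: usize, n: usize, k: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    for jc in (0..n).step_by(NC) {                        // 1. nc-wide column panels of B/C
        let nc = NC.min(n - jc);
        for pc in (0..k).step_by(KC) {                    // 2. kc-deep K panels (B panel packed here)
            let kc = KC.min(k - pc);
            for ic in (0..m).step_by(MC) {                // 3. mc-tall row blocks (A block packed here)
                let mc = MC.min(m - ic);
                for jr in (jc..jc + nc).step_by(NR) {     // 4. nr-wide slivers of the B panel
                    let nr = NR.min(jc + nc - jr);
                    for ir in (ic..ic + mc).step_by(MR) { // 5. mr-tall slivers -> microkernel
                        let mr = MR.min(ic + mc - ir);
                        // "Microkernel": update an mr×nr tile of C over kc steps.
                        for i in ir..ir + mr {
                            for j in jr..jr + nr {
                                for p in pc..pc + kc {
                                    c[i * n + j] = c[i * n + j].max(a[i * k + p] + b[p * n + j]);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
```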
## Cumulative Impact
| Issue | Loss Factor | Cumulative |
|---|---|---|
| Scalar loops in SIMD | 3-5× | 3-5× |
| Packing overhead | 2-3× | 6-15× |
| No AVX-512 | 1.5-2× | 9-30× |
| Wrong tiling | 1.5-2× | 13-60× |
| BLIS overhead | 1.5-2× | 20-120× |
Observed: 30-44× slower ✓ (consistent with analysis!)
## Fix Plan

### Priority 1: Fix Microkernel (3-5× gain) 🔥
File: `crates/tropical-gemm/src/simd/kernels/avx2.rs`
Replace the scalar loops with vectorized loads, broadcasting each packed-A element directly from memory (`_mm256_set1_ps` lowers to a single `vbroadcastss`):

```rust
// Use SIMD intrinsics directly
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[target_feature(enable = "avx")]
unsafe fn microkernel_maxplus_f32_avx2(a: *const f32, b: *const f32, k: usize, acc: &mut [__m256; 8]) {
    for p in 0..k {
        // Load 8 packed-B elements at once (no scalar gather loop!)
        let b_vec = _mm256_loadu_ps(b.add(p * 8));
        for i in 0..8 {
            // Broadcast one packed-A element, then do the max-plus update
            // 8 lanes at a time.
            let a_broadcast = _mm256_set1_ps(*a.add(p * 8 + i));
            let product = _mm256_add_ps(a_broadcast, b_vec);
            acc[i] = _mm256_max_ps(acc[i], product);
        }
    }
}
```

Alternative: use the `wide` crate properly (if it supports direct loads).

Expected: 3-5× speedup
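A minimal sketch of that alternative with the `wide` crate, assuming the packed-panel layout above; `read_unaligned` of a `[f32; 8]` stands in for a dedicated load API, and the per-row broadcast loop remains but now does 8 lanes of work per iteration:

```rust
use wide::f32x8;

// Safety contract: `a` and `b` must each point at packed panels containing
// at least k*8 valid f32 values.
unsafe fn microkernel_maxplus_f32_wide(a: *const f32, b: *const f32, k: usize, acc: &mut [f32x8; 8]) {
    for p in 0..k {
        // Whole-row reads: one unaligned [f32; 8] load each, which the
        // compiler lowers to a single vector load.
        let a_row = (a.add(p * 8) as *const [f32; 8]).read_unaligned();
        let b_vec = f32x8::from((b.add(p * 8) as *const [f32; 8]).read_unaligned());
        for i in 0..8 {
            // max-plus: broadcast a[i], add, take the elementwise max.
            acc[i] = acc[i].max(f32x8::splat(a_row[i]) + b_vec);
        }
    }
}
```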
### Priority 2: Size-Adaptive Tiling (2× gain)
File: `crates/tropical-gemm/src/core/tiling.rs`

Add adaptive parameter selection:

```rust
impl TilingParams {
    pub fn for_size(m: usize, n: usize, k: usize) -> Self {
        if m <= 128 && n <= 128 {
            // Small matrices: reduce overhead
            Self {
                mc: 64,
                nc: 64,
                kc: 64,
                mr: 8,
                nr: 8,
            }
        } else if m <= 512 && n <= 512 {
            // Medium matrices
            Self {
                mc: 128,
                nc: 128,
                kc: 256,
                mr: 8,
                nr: 8,
            }
        } else {
            // Large matrices: use current parameters
            Self::F32_AVX2
        }
    }
}
```

Update dispatch in `src/simd/dispatch.rs` to use `TilingParams::for_size(m, n, k)`.

Expected: 1.5-2× speedup for small matrices
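A quick sanity check on the small-matrix parameters (sizes computed by hand as mc×kc and kc×nc f32 elements, ignoring the crate's padding):

```rust
// Hypothetical test: with for_size(64, 64, 64), both packed panels together
// should fit comfortably in a ~256 KB L2 cache.
#[test]
fn small_tiling_fits_in_l2() {
    let p = TilingParams::for_size(64, 64, 64);
    let packed_a_bytes = p.mc * p.kc * std::mem::size_of::<f32>(); // 64*64*4 = 16 KB
    let packed_b_bytes = p.kc * p.nc * std::mem::size_of::<f32>(); // 64*64*4 = 16 KB
    assert!(packed_a_bytes + packed_b_bytes <= 256 * 1024);
}
```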
### Priority 3: Workspace API (2-3× gain)
File: `crates/tropical-gemm/src/core/gemm.rs`

Implement a reusable workspace (TODO #34 exists):

```rust
pub struct GemmWorkspace<T> {
    packed_a: Vec<T>,
    packed_b: Vec<T>,
}

impl<T: Default + Clone> GemmWorkspace<T> {
    pub fn new(mc: usize, kc: usize, nc: usize, mr: usize, nr: usize) -> Self {
        Self {
            packed_a: vec![T::default(); packed_a_size(mc, kc, mr)],
            packed_b: vec![T::default(); packed_b_size(kc, nc, nr)],
        }
    }

    pub fn resize_if_needed(&mut self, mc: usize, kc: usize, nc: usize, mr: usize, nr: usize) {
        // Only reallocate if the required size increased.
        let a_len = packed_a_size(mc, kc, mr);
        if self.packed_a.len() < a_len {
            self.packed_a.resize(a_len, T::default());
        }
        let b_len = packed_b_size(kc, nc, nr);
        if self.packed_b.len() < b_len {
            self.packed_b.resize(b_len, T::default());
        }
    }
}

pub fn tropical_gemm_with_workspace<T, K>(
    workspace: &mut GemmWorkspace<T::Scalar>,
    m: usize, n: usize, k: usize,
    // ... other params
) {
    workspace.resize_if_needed(...);
    // Use workspace.packed_a and workspace.packed_b instead of allocating
}
```

Expected: 2× speedup for repeated calls
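The call pattern this enables, combined with Priority 2's adaptive tiling (hypothetical usage; the actual GEMM call stays commented out because its parameter list is only sketched above):

```rust
// Allocate pack buffers once, then reuse them across a batch of GEMMs,
// growing only when a larger tiling demands it.
let mut ws = GemmWorkspace::<f32>::new(64, 64, 64, 8, 8);
for (m, n, k) in [(64, 64, 64), (256, 256, 256), (512, 512, 512)] {
    let t = TilingParams::for_size(m, n, k);
    ws.resize_if_needed(t.mc, t.kc, t.nc, t.mr, t.nr);
    // tropical_gemm_with_workspace(&mut ws, m, n, k, /* a, b, c, strides */);
}
```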
### Priority 4: Implement AVX-512 Kernels (1.5-2× gain)
File: `crates/tropical-gemm/src/simd/kernels/avx512.rs` (new)

Create a true AVX-512 kernel:

```rust
pub struct Avx512MaxPlusF32Kernel;

impl MicroKernel<TropicalMaxPlus<f32>> for Avx512MaxPlusF32Kernel {
    const MR: usize = 16; // 512-bit / 32-bit = 16 elements
    const NR: usize = 16;

    unsafe fn execute(...) {
        // Use _mm512_* intrinsics
        // 16-wide SIMD operations
    }
}
```

Update dispatch in `src/simd/dispatch.rs`:

```rust
SimdLevel::Avx512 => {
    let kernel = Avx512MaxPlusF32Kernel; // ✓ Use real AVX-512
    let params = TilingParams::F32_AVX512;
    // ...
}
```

Expected: 1.5-2× speedup on AVX-512 CPUs
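A sketch of what the `execute` inner loop could look like, assuming the same packed layout as the AVX2 kernel but with 16-wide panels:

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

// Safety contract: caller has verified avx512f support, and `a`/`b` point
// at packed panels holding at least k*16 valid f32 values each.
#[target_feature(enable = "avx512f")]
unsafe fn maxplus_inner_avx512(a: *const f32, b: *const f32, k: usize, acc: &mut [__m512; 16]) {
    for p in 0..k {
        // One 16-wide load of the packed-B row.
        let b_vec = _mm512_loadu_ps(b.add(p * 16));
        for i in 0..16 {
            // Broadcast a[i], then add and take the elementwise max over
            // 16 lanes at once.
            let a_broadcast = _mm512_set1_ps(*a.add(p * 16 + i));
            acc[i] = _mm512_max_ps(acc[i], _mm512_add_ps(a_broadcast, b_vec));
        }
    }
}
```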
### Priority 5: Naive Algorithm for Small Matrices (1.5-2× gain)
File: `crates/tropical-gemm/src/core/gemm.rs`

Add a simple triple-loop implementation:

```rust
fn tropical_gemm_naive<T: TropicalSemiring>(
    m: usize, n: usize, k: usize,
    a: &[T::Scalar], lda: usize,
    b: &[T::Scalar], ldb: usize,
    c: &mut [T], ldc: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = T::neg_infinity();
            for p in 0..k {
                let a_val = T::from_scalar(a[i * lda + p]);
                let b_val = T::from_scalar(b[p * ldb + j]);
                acc = acc.tropical_add(a_val.tropical_mul(b_val));
            }
            c[i * ldc + j] = acc;
        }
    }
}

// In tropical_gemm_inner:
pub unsafe fn tropical_gemm_inner<T, K>(...) {
    if m <= 64 && n <= 64 {
        return tropical_gemm_naive(m, n, k, a, lda, b, ldb, c, ldc);
    }
    // ... existing BLIS code
}
```

Expected: 1.5-2× speedup for very small matrices (≤64×64)
## Implementation Roadmap

### Phase 1: Quick Wins (1-2 days)
- Fix microkernel scalar loops (Priority 1)
- Add size-adaptive tiling (Priority 2)
- Expected gain: 5-10× improvement
### Phase 2: Memory Optimization (2-3 days)
- Implement workspace API (Priority 3)
- Add naive algorithm for small matrices (Priority 5)
- Expected gain: Additional 2-3× improvement
### Phase 3: AVX-512 (1-2 weeks)
- Implement true AVX-512 kernels (Priority 4)
- Optimize for cache layout
- Expected gain: Additional 1.5-2× improvement
## Total Expected Improvement
- Phase 1: 5-10× (brings to 3-9× slower than PyTorch)
- Phase 1+2: 10-30× (brings to 1-3× slower than PyTorch)
- Phase 1+2+3: 15-60× (potentially faster than PyTorch!)
## Profiling Commands

### Flamegraph

```bash
cargo install cargo-flamegraph
cargo flamegraph --bench gemm_bench
```

### Perf Stats

```bash
perf stat -e cycles,instructions,cache-references,cache-misses,L1-dcache-loads \
    cargo bench --bench gemm_bench
```

### Assembly Analysis

```bash
cargo install cargo-asm
cargo asm tropical_gemm::simd::kernels::avx2::Avx2MaxPlusF32Kernel::execute --rust
```

### Micro-benchmark (create new file)

```rust
// benches/microkernel_bench.rs
// Isolate microkernel performance
// Test with different K values
// Compare packed vs unpacked
```
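A minimal criterion skeleton for that file (hypothetical: assumes criterion as a dev-dependency, and uses a plain scalar max-plus loop as a stand-in until it is wired to the crate's real microkernel):

```rust
// benches/microkernel_bench.rs -- skeleton; swap the inner loop for the
// actual MicroKernel::execute call once the packed-buffer setup is wired up.
use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn bench_microkernel(cr: &mut Criterion) {
    for &k in &[8usize, 64, 256, 512] {
        // Packed panels: k rows of 8 f32 each (MR = NR = 8).
        let a = vec![1.0f32; k * 8];
        let b = vec![1.0f32; k * 8];
        cr.bench_function(&format!("microkernel_k{k}"), |bench| {
            bench.iter(|| {
                // Stand-in workload: 8×8 max-plus rank-1 updates over k steps.
                let mut acc = [[f32::NEG_INFINITY; 8]; 8];
                for p in 0..k {
                    for i in 0..8 {
                        for j in 0..8 {
                            acc[i][j] = acc[i][j].max(a[p * 8 + i] + b[p * 8 + j]);
                        }
                    }
                }
                black_box(acc)
            })
        });
    }
}

criterion_group!(benches, bench_microkernel);
criterion_main!(benches);
```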
## Critical Files to Modify

- `crates/tropical-gemm/src/simd/kernels/avx2.rs` - Fix scalar loops
- `crates/tropical-gemm/src/core/tiling.rs` - Add adaptive parameters
- `crates/tropical-gemm/src/core/gemm.rs` - Workspace API + naive algorithm
- `crates/tropical-gemm/src/simd/dispatch.rs` - Use adaptive tiling
- `crates/tropical-gemm/src/simd/kernels/avx512.rs` - New AVX-512 kernels
## Conclusion
The 30-44× slowdown is caused by:
- ❌ Scalar loops in SIMD code (worst offender)
- ❌ Excessive memory allocation
- ❌ Missing AVX-512 implementation
- ❌ Wrong parameters for small matrices
- ❌ Algorithm overhead vs PyTorch's optimized path
All issues are fixable with targeted optimizations.
Recommendation: Start with Phase 1 (microkernel + adaptive tiling) for immediate 5-10× improvement.