Skip to content

Adds GEMM Profiling Guide to TE#2863

Open
jomitchellnv wants to merge 1 commit intoNVIDIA:mainfrom
jomitchellnv:jm/gemm-blog
Open

Adds GEMM Profiling Guide to TE#2863
jomitchellnv wants to merge 1 commit intoNVIDIA:mainfrom
jomitchellnv:jm/gemm-blog

Conversation

@jomitchellnv
Copy link
Copy Markdown
Contributor

Description

Adds a GEMM profiling guide to the Transformer Engine documentation and a companion benchmark tool. The guide
explains how to derive all 12 per-layer GEMM shapes (Fprop, Dgrad, Wgrad) from transformer model
hyperparameters, benchmark them across precisions (BF16, FP8 Block, MXFP8, NVFP4), and interpret the resulting
speedup estimates.

The benchmark tool supports two modes: model config mode (derives shapes automatically from hidden_size,
intermediate_size, etc.) and manual shape mode (explicit MxKxN triplets). It measures both autocast performance
(realistic end-to-end with quantization overhead) and pre-quantized kernel-only throughput, using CUDA events
or torch.profiler timing backends.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add benchmarks/gemm/benchmark_gemm.py — standalone GEMM benchmark tool supporting BF16, FP8 Block, MXFP8, and
    NVFP4 precisions with autocast and pre-quantized modes, CUDA event and torch.profiler timing, Nsight Systems
    integration, and bar-chart output

  • Add docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst — documentation covering GEMM
    shape derivation from model configs, forward/backward pass shape conventions, precision mapping per GEMM pass,
    speedup calculation methodology, and a worked example on B300

  • Add benchmark result plots (img/model_config_speedup.png, img/model_config_speedup_prequant.png)

  • Update docs/features/low_precision_training/index.rst toctree to include the new guide
    Please list the changes introduced in this PR:

  • Change A

  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jomitchellnv jomitchellnv changed the title adds blog post Adds GEMM Profiling Guide to TE Apr 9, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 9, 2026

Greptile Summary

This PR adds a GEMM profiling guide to the Transformer Engine documentation and a companion benchmark tool (benchmarks/gemm/benchmark_gemm.py) that derives transformer GEMM shapes from model hyperparameters and benchmarks them across BF16, FP8 Block, MXFP8, and NVFP4 precisions.

  • P1: When --verify-dgrad is used, create_model_config_plot is never given dgrad_results and always computes bar heights as 2 × Fprop, silently producing a chart that contradicts the printed table's measured Dgrad times. The fix requires threading dgrad_results and a verify_dgrad flag into the plot function.

Confidence Score: 4/5

Safe to merge with one P1 fix: the verify-dgrad plot discrepancy should be resolved before the tool is used for benchmarking guidance.

One P1 logic bug (plot ignores measured Dgrad with --verify-dgrad) caps the score at 4. The FP8Block omission in shape mode was already flagged in a prior thread. No security or data-corruption concerns.

benchmarks/gemm/benchmark_gemm.py — specifically create_model_config_plot and its call site in run_model_config_benchmarks.

Important Files Changed

Filename Overview
benchmarks/gemm/benchmark_gemm.py New 1609-line GEMM benchmark tool; --verify-dgrad plot uses 2×Fprop approximation instead of measured Dgrad data, and FP8Block precision is omitted from run_benchmarks() shape mode (already flagged).
docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst New RST documentation guide covering GEMM shape derivation, precision mapping, speedup calculation, and a worked B300 example; content is accurate and well-structured.
docs/features/low_precision_training/index.rst Adds gemm_profiling/gemm_profiling.rst to the toctree; straightforward, correct change.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[CLI: main] --> B{has_model_config?}
    B -- Yes --> C[run_model_config_benchmarks]
    B -- No --> D[run_benchmarks shape/profile mode]
    C --> E[compute_gemm_shapes fprop/dgrad/wgrad]
    E --> F[_benchmark_single_shape per shape x precision]
    F --> G{pre_quantize?}
    G -- Yes --> H[benchmark_*_prequantized tex.generic_gemm]
    G -- No --> I[benchmark_* te.Linear autocast]
    C --> M{verify_dgrad?}
    M -- Yes --> N[benchmark dgrad_shapes use measured sums]
    M -- No --> O[assume Dgrad = Fprop x 2]
    C --> P[print per-layer / full-model summary]
    C --> Q[create_model_config_plot ALWAYS uses Fprop x 2 for Fprop+Dgrad bars]
    D --> R[run BF16 / MXFP8 / NVFP4 NOTE: FP8Block omitted]
    D --> S[create_plot]
    style Q fill:#ffcccc
    style R fill:#ffcccc
Loading

Reviews (4): Last reviewed commit: "adds blog post" | Re-trigger Greptile

Comment on lines +794 to +799
results: dict[str, list[float]] = {"BF16": [], "MXFP8": [], "NVFP4": []}
time_results: dict[str, list[float]] = {"BF16": [], "MXFP8": [], "NVFP4": []}

has_blackwell = is_blackwell_available()
run_fp8 = include_fp8 and TE_AVAILABLE
run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 FP8Block silently omitted in shape mode

run_benchmarks() (used for both default square-shape benchmarks and explicit --shapes invocations) never calls benchmark_fp8_block / benchmark_fp8_block_prequantized. The results dict is initialized with only "BF16", "MXFP8", and "NVFP4", and the function has no include_fp8_block parameter — so the --no-fp8-block flag parsed in main() is only forwarded to run_model_config_benchmarks (line 1579) and has no effect here.

Users who run the tool in shape mode (no model-config flags) will silently receive BF16/MXFP8/NVFP4 data only, even though the module docstring advertises "BF16, FP8 Block, MXFP8, and NVFP4 precisions."

To fix, add include_fp8_block: bool = True to run_benchmarks, initialise results["FP8Block"] = [], select fp8_block_fn the same way model-config mode does, and forward the flag from main().

Comment on lines +1355 to +1367
color=op_color,
alpha=0.9,
label=f"{op_label} (Fprop+Dgrad)" if i == 0 or True else "",
)
ax.bar(
x,
wgrad_ms,
bar_width,
bottom=all_fprop_total + total_wgrad_bottom,
color=op_color,
alpha=0.5,
label=f"{op_label} (Wgrad)" if i == 0 or True else "",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Dead condition if i == 0 or True always evaluates to True

Both label= expressions use if i == 0 or True, which unconditionally takes the True branch. This is dead code — or True makes the condition tautological. The intent was likely either True (always label, which is fine for a stacked bar chart) or if i == 0 (label only the first series). Clean it up to express intent clearly:

Suggested change
color=op_color,
alpha=0.9,
label=f"{op_label} (Fprop+Dgrad)" if i == 0 or True else "",
)
ax.bar(
x,
wgrad_ms,
bar_width,
bottom=all_fprop_total + total_wgrad_bottom,
color=op_color,
alpha=0.5,
label=f"{op_label} (Wgrad)" if i == 0 or True else "",
)
label=f"{op_label} (Fprop+Dgrad)",

and

Suggested change
color=op_color,
alpha=0.9,
label=f"{op_label} (Fprop+Dgrad)" if i == 0 or True else "",
)
ax.bar(
x,
wgrad_ms,
bar_width,
bottom=all_fprop_total + total_wgrad_bottom,
color=op_color,
alpha=0.5,
label=f"{op_label} (Wgrad)" if i == 0 or True else "",
)
label=f"{op_label} (Wgrad)",

Comment on lines +18 to +21
* **profiler** -- ``torch.profiler`` (CUPTI) kernel timestamps.
Only the matched GEMM compute kernels (nvjet, xmma, cutlass, cublas)
are summed, giving a kernel-only measurement.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Docstring lists "cublas" but the pattern tuple uses "gemm" instead

The module docstring (line 19) lists the matched kernel patterns as (nvjet, xmma, cutlass, cublas), but GEMM_KERNEL_PATTERNS at line 70 is ("gemm", "nvjet", "xmma", "cutlass")"cublas" is absent and "gemm" was added in its place. In practice "gemm" does catch cuBLAS kernels (their names contain gemm), so the behaviour is correct, but the docstring is inaccurate and may confuse users auditing kernel coverage.

Suggested change
* **profiler** -- ``torch.profiler`` (CUPTI) kernel timestamps.
Only the matched GEMM compute kernels (nvjet, xmma, cutlass, cublas)
are summed, giving a kernel-only measurement.
* **profiler** -- ``torch.profiler`` (CUPTI) kernel timestamps.
Only the matched GEMM compute kernels (gemm, nvjet, xmma, cutlass)
are summed, giving a kernel-only measurement.

@pggPL pggPL self-requested a review April 10, 2026 14:00
@pggPL
Copy link
Copy Markdown
Collaborator

pggPL commented Apr 13, 2026

Hi @jomitchellnv, I see that this PR is open, but "Documentation" job is failing. If you fix it, please ping me and I'll review it.

@jomitchellnv
Copy link
Copy Markdown
Contributor Author

@pggPL they should be fixed now I hope

@jomitchellnv
Copy link
Copy Markdown
Contributor Author

/te-ci L1 pytorch

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Comment on lines +1405 to +1407
loc="upper right",
fontsize=8,
ncol=2,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 --verify-dgrad plot silently uses approximation instead of measured values

When --verify-dgrad is passed, run_model_config_benchmarks benchmarks and records actual Dgrad timings into dgrad_results, and the printed table correctly shows those measured values. However, create_model_config_plot is never given dgrad_results — the call site only passes fprop_results and wgrad_results. Inside the plot function, Fprop+Dgrad bar height is always computed as fp.avg_time_ms * 2 (the approximation), so the chart silently contradicts the table when --verify-dgrad is used.

Fix: add dgrad_results and verify_dgrad parameters to create_model_config_plot, and when verify_dgrad=True, use fprop_ms[j] + dgrad_ms[j] instead of fprop_ms[j] * 2 for each op bar.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants