add torch sdpa to TransformerEncoder#13580

Closed
MahmoudAshraf97 wants to merge 1 commit into
NVIDIA-NeMo:mainfrom
MahmoudAshraf97:sdpa_transformer

Conversation

@MahmoudAshraf97
Contributor

This PR follows the approach of #9590 and adds a torch SDPA implementation to the TransformerEncoder used in Sortformer.
This module also exists in nemo.collections.nlp as duplicated code, so let me know if I should modify it there as well.
The benchmark results below use n_heads=8 and hidden_dim=192, which match the Sortformer transformer encoder.

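| Batch Size | Sequence Length | SDPA Boost (%) | SDPA Backward Boost (%) |
|---|---|---|---|
| 1 | 128 | 9.71 | 28.57 |
| 1 | 512 | 60.41 | 48.02 |
| 1 | 1024 | 71.51 | 36.23 |
| 4 | 128 | 88.42 | -0.15 |
| 4 | 512 | 83.39 | 51.32 |
| 4 | 1024 | 84.84 | 52.10 |
| 16 | 128 | 77.85 | 58.24 |
| 16 | 512 | 77.18 | 53.02 |
| 16 | 1024 | 81.64 | 54.72 |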
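
For reference, the two boost columns are relative time reductions, and the backward time is not measured directly: the benchmark code derives it by subtracting the forward time from the forward+backward time. The sketch below restates that arithmetic. The helper names `boost_percent` and `backward_time` and the timing values are hypothetical, chosen only for illustration; they are not measurements from this PR.

```python
def boost_percent(time_original: float, time_sdpa: float) -> float:
    """Relative time reduction of the SDPA path versus the original, in percent."""
    return (time_original - time_sdpa) / time_original * 100


def backward_time(time_fwd_bwd: float, time_fwd: float) -> float:
    """Backward time estimated as (forward + backward) minus forward."""
    return time_fwd_bwd - time_fwd


# Hypothetical timings in seconds, used only to illustrate the arithmetic.
fwd_orig, fwd_sdpa = 0.004, 0.001
fwd_bwd_orig, fwd_bwd_sdpa = 0.010, 0.004

forward_boost = boost_percent(fwd_orig, fwd_sdpa)
backward_boost = boost_percent(
    backward_time(fwd_bwd_orig, fwd_orig),
    backward_time(fwd_bwd_sdpa, fwd_sdpa),
)
print(f"forward boost: {forward_boost:.2f}%")
print(f"backward boost: {backward_boost:.2f}%")
```

Because the backward figure is a difference of two separately timed runs, it is more sensitive to measurement noise than the forward figure, which may explain a small negative value such as the -0.15 entry.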

Benchmark Code:

import torch
import torch.utils.benchmark as benchmark
from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention

torch.manual_seed(123)

device = "cuda"

batch_sizes = [1,4,16]
seq_lens = [128,512,1024]
num_heads = 8
hidden_dim = 192
masked = False

mha = MultiHeadAttention(hidden_dim, num_heads, 0.1, 0.1).to(device)
mha_sdpa = MultiHeadAttention(hidden_dim, num_heads, 0.1, 0.1, use_pytorch_sdpa=True).to(device)
mha_sdpa.load_state_dict(mha.state_dict())

mha.eval()
mha_sdpa.eval()

def measure_time(attention, query, key, value, mask):
    timer = benchmark.Timer(
        stmt='attention(query, key, value, mask);torch.cuda.synchronize(); torch.cuda.empty_cache()',
        setup='torch.cuda.synchronize(); torch.cuda.empty_cache()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask}
    )

    with torch.no_grad():
        torch.cuda.synchronize()
        results = timer.blocked_autorange(min_run_time=10)
        forward_time = results.mean
        output = attention(query, key, value, mask)
    return forward_time, output


def measure_fwd_bwd_time(attention, query, key, value, mask):
    timer = benchmark.Timer(
        stmt='loss=attention(query, key, value, mask).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize(); torch.cuda.empty_cache()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask}
    )
    torch.cuda.synchronize()
    results = timer.blocked_autorange(min_run_time=10)
    fwd_bwd_time = results.mean
    return fwd_bwd_time

for batch_size in batch_sizes:
    for seq_len in seq_lens:

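        # Use hidden_dim instead of a hard-coded 192 and allocate directly on the target device.
        input_tensor = torch.randn(batch_size, seq_len, hidden_dim, device=device)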
        if masked:
            mask = torch.randint(0, 2, (batch_size, num_heads, seq_len, seq_len)).bool()
            mask = torch.where(mask, torch.tensor(float('-inf')), torch.tensor(0.0)).to(device)
        else:
            mask = None


        time_fwd_original, output_original = measure_time(mha, input_tensor, input_tensor, input_tensor, mask)
        time_fwd_sdpa, output_sdpa = measure_time(mha_sdpa, input_tensor, input_tensor, input_tensor, mask)
        print(f"Batch size: {batch_size}, Sequence length: {seq_len}")
        # print(f"Original implementation time: {time_fwd_original:.6f} seconds")
        # print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
        print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")

        time_fwd_bwd_original = measure_fwd_bwd_time(mha, input_tensor, input_tensor, input_tensor, mask)
        time_fwd_bwd_sdpa = measure_fwd_bwd_time(mha_sdpa, input_tensor, input_tensor, input_tensor, mask)
        time_bwd_original = time_fwd_bwd_original - time_fwd_original
        time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa

        # print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
        # print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
        print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")

        assert torch.allclose(output_original, output_sdpa, atol=1e-5)
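
The final assertion checks that both attention paths produce the same output. As a rough illustration of why that is expected, the sketch below compares an explicit softmax attention with `torch.nn.functional.scaled_dot_product_attention` on small CPU tensors. It is not the NeMo implementation: `manual_attention`, the tensor shapes, and the mask layout are assumptions made for this example, and dropout is disabled so the two paths are deterministic.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Shapes loosely follow the benchmark: 8 heads and hidden_dim=192 give head_dim=24.
batch, heads, seq_len, head_dim = 2, 8, 16, 24
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Additive float mask: 0.0 keeps a key position, -inf removes it.
# Only the last two key positions are masked, so no query row is fully masked.
mask = torch.zeros(batch, 1, seq_len, seq_len)
mask[..., -2:] = float("-inf")


def manual_attention(q, k, v, mask=None):
    """Explicit softmax(Q K^T / sqrt(d) + mask) V (hypothetical helper)."""
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask
    return torch.matmul(torch.softmax(scores, dim=-1), v)


out_manual = manual_attention(q, k, v, mask)
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

max_diff = (out_manual - out_sdpa).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The fused kernel avoids materializing the full attention matrix in some backends, which is one plausible source of the speedups reported in the table; which backend is selected depends on the device, dtype, and mask.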

Collection: ASR and possibly NLP

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

@titu1994, @redoctopus, @jbalam-nv, @okuchaiev, and @pzelasko, since this module is also used in some Canary configs; @tango4j, since it is used in Sortformer.

Signed-off-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
@github-actions
Contributor

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@github-actions github-actions Bot added the stale label May 29, 2025
@MahmoudAshraf97
Contributor Author

Bump

@github-actions github-actions Bot removed the stale label May 30, 2025
@github-actions
Contributor

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@github-actions github-actions Bot added the stale label Jun 13, 2025
@github-actions
Contributor

This PR was closed because it has been inactive for 7 days since being marked as stale.

@github-actions github-actions Bot closed this Jun 21, 2025