fix(zero): enable vmap on LinearFunctionForZeroStage3 #8023

Open
roycho96 wants to merge 1 commit into deepspeedai:master from roycho96:fix/zero-linear-vmap-rule

@roycho96 roycho96 commented May 22, 2026

Follow-up to #7916.

Adds generate_vmap_rule = True to LinearFunctionForZeroStage3 so torch.func.vmap works on the Function directly. The previous PR covered grad / jacrev via setup_context but not vmap.
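As a rough illustration of what the flag does, here is a minimal sketch using a toy linear Function. `ToyLinear` is a hypothetical stand-in written for this example, not the DeepSpeed implementation; it only assumes the same overall shape as described above (a `forward` without `ctx`, a `setup_context`, and a pure matmul + bias body).

```python
import torch
from torch.func import grad, vmap


class ToyLinear(torch.autograd.Function):
    """Hypothetical stand-in for LinearFunctionForZeroStage3 (not DeepSpeed code)."""

    # Ask PyTorch to derive the vmap rule from forward/setup_context/backward.
    generate_vmap_rule = True

    @staticmethod
    def forward(input, weight, bias):
        # functorch-compatible style: forward takes no ctx argument.
        return input.matmul(weight.t()) + bias

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, _bias = inputs
        ctx.save_for_backward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        # Written for 2D input of shape (N, in_features).
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


torch.manual_seed(0)
weight = torch.randn(3, 4)
bias = torch.randn(3)
xs = torch.randn(5, 2, 4)  # 5 samples, each a (2, 4) input

# vmap over the leading dimension of xs; weight and bias are shared.
batched = vmap(ToyLinear.apply, in_dims=(0, None, None))(xs, weight, bias)
reference = torch.nn.functional.linear(xs, weight, bias)


def loss(x, w, b):
    return ToyLinear.apply(x, w, b).sum()


# vmap(grad): per-sample gradient of the loss with respect to the input.
per_sample = vmap(grad(loss), in_dims=(0, None, None))(xs, weight, bias)
```

Here `batched` should agree with `torch.nn.functional.linear` applied to the whole batch, and each per-sample input gradient should equal the column sums of `weight`, since the loss is a plain sum of the outputs.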

Test:
pytest tests/unit/v1/zero/test_zero_functorch_linear.py::TestLinearFunctionVmap

The forward is a pure tensor op (addmm / matmul + bias) with no closure
state, so PyTorch's auto-generated vmap rule produces correct batched
semantics. Without this, vmap (and vmap(grad)) over the Function raises
a 'does not have vmap support' error, which is the case the
setup_context fix in deepspeedai#7916 left unaddressed.

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
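
For context, the failure mode described in the commit message can be reproduced with a small sketch. `NoVmapRule` is a hypothetical Function written for this example: it uses `setup_context` but neither sets `generate_vmap_rule` nor defines a `vmap` staticmethod. On recent PyTorch 2.x releases this is expected to raise a `RuntimeError`; the exact message wording may vary between versions.

```python
import torch
from torch.func import vmap


class NoVmapRule(torch.autograd.Function):
    """Hypothetical Function with setup_context but no vmap rule."""

    @staticmethod
    def forward(input, weight):
        return input.matmul(weight.t())

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        return grad_output.matmul(weight), grad_output.t().matmul(input)


xs = torch.randn(5, 2, 4)
weight = torch.randn(3, 4)

try:
    vmap(NoVmapRule.apply, in_dims=(0, None))(xs, weight)
    error = None
except RuntimeError as exc:
    # Assumption: PyTorch reports the missing vmap rule as a RuntimeError.
    error = exc
```

Adding `generate_vmap_rule = True` to the class body is the one-line change that would let the same call go through.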