[WebGPU] Implement tir.dp4a with WGSL built-in function dot4I8Packed#16976
Merged
tqchen merged 8 commits intoapache:mainfrom Jul 4, 2024
Merged
[WebGPU] Implement tir.dp4a with WGSL built-in function dot4I8Packed#16976tqchen merged 8 commits intoapache:mainfrom
tir.dp4a with WGSL built-in function dot4I8Packed#16976tqchen merged 8 commits intoapache:mainfrom
Conversation
This patch adds the support of `__dp4a(int8x4, int8x4)` as a pure
extern method of WebGPU target. In the generated WGSL shader,
`int8x4` will be translated into `u32`, and `__dp4a(int8x4, int8x4)`
will be translated into the WGSL built-in function
`dot4I8Packed(u32, u32)`.
Here is an example to use `__dp4a` in WebGPU target:
```
n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]), name="C")
s = te.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```
Issue: apache#16627
tqchen
approved these changes
May 8, 2024
tqchen
requested changes
May 8, 2024
__dp4a(int8x4, int8x4) as a pure extern methoddot4I8Packed(int8x4, int8x4) as a pure extern method
tqchen
reviewed
May 9, 2024
| // extra dispatch | ||
| TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf); | ||
|
|
||
| TVM_REGISTER_OP("tir.dot4I8Packed").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<Direct>); |
Member
There was a problem hiding this comment.
sorry i was not being clear, for tir, it is better to have a common name dp4a (as this intrinsic shared across backends)
Contributor
Author
Member
There was a problem hiding this comment.
we can add tir.dp4a intrinsic, and use it to lower to various places
Contributor
Author
There was a problem hiding this comment.
Oh sorry I was busy on some other urgent stuffs these days. I will go back to work on this next week. I will follow the steps to add tir.dp4a first.
Contributor
Author
There was a problem hiding this comment.
Hi @tqchen,
Sorry for my late response. I've updated this PR. PTAL, thanks!
dot4I8Packed(int8x4, int8x4) as a pure extern methodbuiltin::dp4a with WGSL built-in function dot4I8Packed
builtin::dp4a with WGSL built-in function dot4I8Packedtir.dp4a with WGSL built-in function dot4I8Packed
Contributor
Author
|
@tqchen Now the PR has passed all the tests. PTAL, thanks! |
tqchen
approved these changes
Jul 4, 2024
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This patch implements
tir.dp4awith WGSL built-in functiondot4I8Packed()on WebGPU backend.Here is an example to use
tir.dp4ain WebGPU target:Issue: #16627