Skip to content

Commit 8eb5513

Browse files
[webgpu] Implement SubGroupMatrix based MatMulNBits for Metal (#23729)
### Description Recent progress with the SubGroupMatrix prototype in Dawn (https://issues.chromium.org/issues/348702031) exposes SIMD-Group Matrix Functions to WebGPU. This shader implements MatMulNBits using that primitive. Observed perf gains in LLM inference speed: prefill perf for Phi 3.5 with a 1K-token prefill sees a ~3x improvement — 5.4s, down from 15s. With Changes ``` ./model_benchmark -i ~/Phi-3.5-mini-instruct-onnx-web -l 1000 Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token): avg (us): 5.42498e+06 <<< SubGroupMatrix 5.4s avg (tokens/s): 184.517 p50 (us): 5.41982e+06 stddev (us): 12023.8 n: 5 * 1001 token(s) Token generation: avg (us): 91138.5 avg (tokens/s): 10.9723 p50 (us): 89488.5 stddev (us): 35136.2 n: 635 * 1 token(s) ``` Baseline ``` ./model_benchmark -i ~/Phi-3.5-mini-instruct-onnx-web -l 1000 Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token): avg (us): 1.45507e+07 <<< Baseline 14.5s avg (tokens/s): 68.7938 p50 (us): 1.45413e+07 stddev (us): 22208.9 n: 5 * 1001 token(s) Token generation: avg (us): 94109.8 avg (tokens/s): 10.6259 p50 (us): 89660 stddev (us): 61579 n: 635 * 1 token(s) ```
1 parent d82604e commit 8eb5513

8 files changed

Lines changed: 304 additions & 5 deletions

File tree

cmake/deps.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
5858
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
5959
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
6060
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1
61-
dawn;https://github.com/google/dawn/archive/b9b4a37041dec3dd62ac92014a6cc1aece48d9f3.zip;e8b8c2ebabdedb7c57d931fc4a19ae22146d31e1
61+
dawn;https://github.com/google/dawn/archive/40a9fa79f76e6c76cca9e2fa69ea07f202f1d2e6.zip;e224563d5ab4a8e53a517b06f721242533bce722
6262
kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/d15722976120710080ca098fe8ddabf4556cb40f/kleidiai-d15722976120710080ca098fe8ddabf4556cb40f.zip;d6c840d00c3b05aedf06e957ddaece1013d1f40b

cmake/patches/dawn/dawn.patch

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,24 @@
1+
diff --git a/src/cmake/DawnCompilerPlatformFlags.cmake b/src/cmake/DawnCompilerPlatformFlags.cmake
2+
index 50638e2456..efa42711e6 100644
3+
--- a/src/cmake/DawnCompilerPlatformFlags.cmake
4+
+++ b/src/cmake/DawnCompilerPlatformFlags.cmake
5+
@@ -63,7 +63,3 @@ endif ()
6+
if (MSVC AND NOT COMPILER_IS_CLANG_CL)
7+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
8+
endif ()
9+
-
10+
-if (TARGET_MACOS)
11+
- set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version" FORCE)
12+
-endif ()
13+
\ No newline at end of file
114
diff --git a/src/emdawnwebgpu/CMakeLists.txt b/src/emdawnwebgpu/CMakeLists.txt
215
index 6e8ae37593..633af91eef 100644
316
--- a/src/emdawnwebgpu/CMakeLists.txt
417
+++ b/src/emdawnwebgpu/CMakeLists.txt
518
@@ -77,9 +77,17 @@ if (${DAWN_ENABLE_EMSCRIPTEN})
619
"${arg_UNPARSED_ARGUMENTS}")
720
endif()
8-
21+
922
+ # since Emscripten 4.0.3, file gen_struct_info.py is moved to outside of directory maint.
1023
+ if (EXISTS "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/gen_struct_info.py")
1124
+ set(EM_GEN_STRUCT_INFO_SCRIPT "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/gen_struct_info.py")

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <string_view>
55

66
#include "contrib_ops/webgpu/quantization/matmul_nbits.h"
7+
#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
78
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
89
#include "core/providers/cpu/math/matmul_helper.h"
910
#include "core/providers/webgpu/shader_helper.h"
@@ -815,6 +816,12 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
815816
uint32_t components = GetMaxComponents(N);
816817

817818
const bool has_zero_points = zero_points != nullptr;
819+
// macOS - Experimental dawn support for subgroup matrix matmul on Metal.
820+
if (M >= kMinMForTileOptimization &&
821+
CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, has_zero_points)) {
822+
return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y);
823+
}
824+
818825
const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
819826
// macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support.
820827
// https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
5+
6+
namespace onnxruntime {
7+
namespace contrib {
8+
namespace webgpu {
9+
10+
// Emits the WGSL for a MatMulNBits kernel built on the experimental
// chromium_experimental_subgroup_matrix feature (SIMD-group matrix ops on Metal).
// Each 128-thread workgroup computes a 32 (rows of A) x 64 (cols of B) output tile,
// staging dequantized A/B tiles in workgroup memory and accumulating with
// 8x8 subgroup matrix multiply-accumulate.
// NOTE(review): assumes B is 4-bit packed with block_size 32 and no zero points —
// gated by CanApplySubgroupMatrixMatMulNBits; confirm against the caller.
Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("input_b", ShaderUsage::UseUniform);
shader.AddInput("scales_b", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

// tile/subtile sizes and work distribution are inspired from metal shaders in llama.cpp (kernel_mul_mm)
// https://github.com/ggml-org/llama.cpp/blob/d04e7163c85a847bc61d58c22f2c503596db7aa8/ggml/src/ggml-metal/ggml-metal.metal#L6066
// Helper section: tile constants, workgroup-shared staging buffers, and the
// load/dequantize/store helper functions used by the main body below.
shader.AdditionalImplementation() << R"ADDNL_FN(
const tile_cols = 64;
const tile_rows = 32;
const tile_k = 32;
const subtile_cols = 32;
const subtile_rows = 16;
const quantization_block_size = 32;
alias compute_precision = output_element_t;

var<workgroup> tile_A: array<compute_precision, tile_rows * tile_k>; // 32 x 32 - RxC
var<workgroup> tile_B: array<compute_precision, tile_cols * tile_k>; // 64 x 32 - RxC
var<workgroup> scratch: array<array<array<compute_precision, 64>, 4>, 4>; // 64 * 4 * 4

fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) {
let a_global = tile_base + row;
if (a_global >= uniforms.M) {
return;
}
// Each call loads 8 columns, starting at col.
var col = c_idx * 8;
// 128 threads need to load 32 x 32. 4 threads per row or 8 col per thread.
for (var col_offset:u32 = 0; col_offset < 8; col_offset++)
{
tile_A[row * tile_k + col + col_offset] = compute_precision(input_a[a_global*uniforms.K + k_idx + col + col_offset]);
}
}

fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) {
let b_global = tile_base + row;
if (b_global >= uniforms.N) {
return;
}
// Each call loads 16 columns, starting at col.
var col = c_idx * 16;
// 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread.
// Stored in column major fashion.
let b_idx = u32((b_global*uniforms.K + k_idx + col)/8);
let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]);
for (var step:u32 = 0; step < 2; step++)
{
var b_value = input_b[b_idx+step];
var b_value_lower = (vec4<compute_precision>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
var b_value_upper = (vec4<compute_precision>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
let tile_b_base = row * tile_k + col + step * 8;
tile_B[tile_b_base] = b_value_lower[0];
tile_B[tile_b_base + 1] = b_value_upper[0];
tile_B[tile_b_base + 2] = b_value_lower[1];
tile_B[tile_b_base + 3] = b_value_upper[1];
tile_B[tile_b_base + 4] = b_value_lower[2];
tile_B[tile_b_base + 5] = b_value_upper[2];
tile_B[tile_b_base + 6] = b_value_lower[3];
tile_B[tile_b_base + 7] = b_value_upper[3];
}
}

fn storeOutput(offset:u32, row: u32, col:u32, src_slot:u32, row_limit:i32) {
if (row_limit > 0 && row < u32(row_limit))
{
output[offset + row * uniforms.N + col] = output_element_t(scratch[src_slot][0][row * 8 + col]);
output[offset + row * uniforms.N + col + 8] = output_element_t(scratch[src_slot][1][row * 8 + col]);
output[offset + row * uniforms.N + col + 16] = output_element_t(scratch[src_slot][2][row * 8 + col]);
output[offset + row * uniforms.N + col + 24] = output_element_t(scratch[src_slot][3][row * 8 + col]);
let col2 = col + 1;
output[offset + row * uniforms.N + col2] = output_element_t(scratch[src_slot][0][row * 8 + col2]);
output[offset + row * uniforms.N + col2 + 8] = output_element_t(scratch[src_slot][1][row * 8 + col2]);
output[offset + row * uniforms.N + col2 + 16] = output_element_t(scratch[src_slot][2][row * 8 + col2]);
output[offset + row * uniforms.N + col2 + 24] = output_element_t(scratch[src_slot][3][row * 8 + col2]);
}
}
)ADDNL_FN";

// Main body: loop over K in tile_k steps, cooperatively load/dequantize the
// A and B tiles, multiply-accumulate 8x8 subgroup matrices, then spill the
// accumulators through workgroup scratch to the output buffer.
shader.MainFunctionBody() << R"MAIN_FN(
let a_global_base = workgroup_id.y * tile_rows;
let b_global_base = workgroup_id.x * tile_cols;

let subtile_id = u32(local_idx / sg_size);
let subtile_idx = u32(subtile_id / 2);
let subtile_idy = subtile_id % 2;
let base_A = subtile_idy * subtile_rows;
let base_B = subtile_idx * subtile_cols;

var matC00: subgroup_matrix_result<compute_precision, 8, 8>;
var matC01: subgroup_matrix_result<compute_precision, 8, 8>;
var matC02: subgroup_matrix_result<compute_precision, 8, 8>;
var matC03: subgroup_matrix_result<compute_precision, 8, 8>;
var matC10: subgroup_matrix_result<compute_precision, 8, 8>;
var matC11: subgroup_matrix_result<compute_precision, 8, 8>;
var matC12: subgroup_matrix_result<compute_precision, 8, 8>;
var matC13: subgroup_matrix_result<compute_precision, 8, 8>;
for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) {
// Load Phase
loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4);
loadSHMB(b_global_base, kidx, local_idx/2, local_idx%2);
workgroupBarrier();

for (var step: u32 = 0; step < tile_k; step+=8)
{
// Load to local memory phase
let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step;
// Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride
var matA0: subgroup_matrix_left<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_left<compute_precision, 8, 8>>(&tile_A, matrix_a_offset, false, tile_k);
var matA1: subgroup_matrix_left<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_left<compute_precision, 8, 8>>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k);

// tile_B is stored as column major.
// [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32]
var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step;
var matB0: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset, true, tile_k);
var matB1: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k);
var matB2: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k);
var matB3: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k);

// Compute Phase
// Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate
matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00);
matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01);
matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02);
matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03);

matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10);
matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11);
matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12);
matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13);
}

workgroupBarrier();
}

// Write out
// Write out top block
subgroupMatrixStore(&scratch[subtile_id][0], 0, matC00, false, 8);
subgroupMatrixStore(&scratch[subtile_id][1], 0, matC01, false, 8);
subgroupMatrixStore(&scratch[subtile_id][2], 0, matC02, false, 8);
subgroupMatrixStore(&scratch[subtile_id][3], 0, matC03, false, 8);
workgroupBarrier();
let row = u32(sg_id / 4);
var col = u32(sg_id % 4) * 2;
var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B;
var row_limit:i32 = i32(uniforms.M) - i32(a_global_base + base_A);
storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
workgroupBarrier();

// Write out bottom block
subgroupMatrixStore(&scratch[subtile_id][0], 0, matC10, false, 8);
subgroupMatrixStore(&scratch[subtile_id][1], 0, matC11, false, 8);
subgroupMatrixStore(&scratch[subtile_id][2], 0, matC12, false, 8);
subgroupMatrixStore(&scratch[subtile_id][3], 0, matC13, false, 8);
workgroupBarrier();
matrix_c_offset = matrix_c_offset + 8 * uniforms.N;
row_limit = i32(uniforms.M) - i32(a_global_base + base_A + 8);
storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
)MAIN_FN";

return Status::OK();
}
172+
173+
// Dispatches the subgroup-matrix MatMulNBits program for y = a * b (b 4-bit
// quantized, dequantized with `scales` in the shader).
// M/N/K are the GEMM dimensions; `y` receives shape {1, M, N}.
// Callers must first check CanApplySubgroupMatrixMatMulNBits.
Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
                                      uint32_t M,
                                      uint32_t N,
                                      uint32_t K,
                                      onnxruntime::webgpu::ComputeContext& context,
                                      Tensor* y) {
  // Tile sizes must match the WGSL constants tile_rows (A) and tile_cols (B).
  constexpr uint32_t kTileSizeA = 32;
  constexpr uint32_t kTileSizeB = 64;
  // b is consumed as vec4<u32>: 4 u32 components per shader element.
  constexpr int kU32Components = 4;
  TensorShape y_shape{1, M, N};
  SubgroupMatrixMatMulNBitsProgram mul_program;
  mul_program.SetWorkgroupSize(128);
  // One workgroup per 32x64 output tile; round up to cover the edges.
  mul_program.SetDispatchGroupSize(
      (N + kTileSizeB - 1) / kTileSizeB,
      (M + kTileSizeA - 1) / kTileSizeA, 1);
  // M/N/K are already uint32_t, so no casts are needed for the uniforms;
  // component counts are small compile-time constants, so plain ints suffice.
  mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1},
                         {b, ProgramTensorMetadataDependency::TypeAndRank, kU32Components},
                         {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
      .AddUniformVariables({{M},
                            {N},
                            {K}})
      .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1});
  return context.RunProgram(mul_program);
}
197+
198+
// Gate for the subgroup-matrix MatMulNBits path: the device must expose the
// experimental subgroup-matrix feature on a Metal backend, and the problem
// must match the shapes the shader was written for.
bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
                                       uint64_t accuracy_level,
                                       uint32_t block_size,
                                       uint32_t batch_count,
                                       uint32_t N,
                                       uint32_t K,
                                       bool has_zero_points) {
#if !defined(__wasm__)
  const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
#else
  const bool has_subgroup_matrix = false;
#endif
  // Hardware/backend gate: Metal only, and only when the experimental
  // subgroup-matrix feature is available.
  if (context.AdapterInfo().backendType != wgpu::BackendType::Metal || !has_subgroup_matrix) {
    return false;
  }
  // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are
  // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy
  // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s,
  // FP32 is around 7s.
  if (accuracy_level != 4) {
    return false;
  }
  // Shape gate: the shader assumes block_size 32, a single batch, K a multiple
  // of the K-tile (32), N a multiple of the B-tile (64), and no zero points.
  const bool shape_ok = block_size == 32 &&
                        batch_count == 1 &&
                        K % 32 == 0 &&
                        N % 64 == 0;
  return shape_ok && !has_zero_points;
}
223+
} // namespace webgpu
224+
} // namespace contrib
225+
} // namespace onnxruntime
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/program.h"
7+
#include "core/providers/webgpu/compute_context.h"
8+
#include "core/providers/webgpu/program.h"
9+
#include "core/providers/webgpu/shader_helper.h"
10+
#include "core/providers/webgpu/webgpu_kernel.h"
11+
12+
namespace onnxruntime {
13+
namespace contrib {
14+
namespace webgpu {
15+
16+
using namespace onnxruntime::webgpu;
17+
18+
// WebGPU program that computes MatMulNBits using the experimental
// chromium_experimental_subgroup_matrix (SIMD-group matrix) WGSL feature.
// Uniforms M/N/K carry the GEMM dimensions consumed by the generated shader.
class SubgroupMatrixMatMulNBitsProgram final : public Program<SubgroupMatrixMatMulNBitsProgram> {
 public:
  SubgroupMatrixMatMulNBitsProgram() : Program{"SubgroupMatrixMatMulNBits"} {}
  // Emits the WGSL helper functions and main body for the tiled kernel.
  Status GenerateShaderCode(ShaderHelper& sh) const override;
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"M", ProgramUniformVariableDataType::Uint32},
      {"N", ProgramUniformVariableDataType::Uint32},
      {"K", ProgramUniformVariableDataType::Uint32});
};
27+
28+
// Runs the subgroup-matrix MatMulNBits program: y = a * b with b 4-bit
// quantized and dequantized via `scales`. M/N/K are the GEMM dimensions.
// Only valid when CanApplySubgroupMatrixMatMulNBits returned true.
Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
                                      uint32_t M,
                                      uint32_t N,
                                      uint32_t K,
                                      onnxruntime::webgpu::ComputeContext& context,
                                      Tensor* y);

// Returns true when the device (Metal backend with the experimental
// subgroup-matrix feature) and the problem shape (accuracy level 4,
// block_size 32, batch_count 1, K % 32 == 0, N % 64 == 0, no zero points)
// allow the subgroup-matrix kernel to be used.
bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
                                       uint64_t accuracy_level,
                                       uint32_t block_size,
                                       uint32_t batch_count,
                                       uint32_t N,
                                       uint32_t K,
                                       bool has_zero_points);
42+
43+
} // namespace webgpu
44+
} // namespace contrib
45+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/shader_helper.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,11 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& sha
352352
if (device_.HasFeature(wgpu::FeatureName::Subgroups)) {
353353
ss << "enable subgroups;\n";
354354
}
355+
#if !defined(__wasm__)
356+
if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
357+
ss << "enable chromium_experimental_subgroup_matrix;\n";
358+
}
359+
#endif
355360

356361
//
357362
// Section constants

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ std::vector<wgpu::FeatureName> WebGpuContext::GetAvailableRequiredFeatures(const
486486
constexpr wgpu::FeatureName features[]{
487487
#if !defined(__wasm__)
488488
wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses,
489+
wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix,
489490
#endif
490491
wgpu::FeatureName::TimestampQuery,
491492
wgpu::FeatureName::ShaderF16,
@@ -738,7 +739,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
738739
// Step.2 - Create wgpu::Instance
739740
#if !defined(__wasm__)
740741
wgpu::InstanceDescriptor instance_desc{};
741-
instance_desc.features.timedWaitAnyEnable = true;
742+
instance_desc.capabilities.timedWaitAnyEnable = true;
742743
default_instance_ = wgpu::CreateInstance(&instance_desc);
743744
#else
744745
default_instance_ = wgpu::CreateInstance(nullptr);

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ class WebGpuContext final {
9494
wgpu::ComputePassDescriptor compute_pass_desc{};
9595

9696
if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) {
97-
wgpu::ComputePassTimestampWrites timestampWrites = {
98-
query_set_, num_pending_dispatches_ * 2, num_pending_dispatches_ * 2 + 1};
97+
wgpu::PassTimestampWrites timestampWrites = {
98+
nullptr,
99+
query_set_,
100+
num_pending_dispatches_ * 2,
101+
num_pending_dispatches_ * 2 + 1};
99102
compute_pass_desc.timestampWrites = &timestampWrites;
100103
}
101104

0 commit comments

Comments
 (0)