|
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
  shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  shader.AddInput("input_b", ShaderUsage::UseUniform);
  shader.AddInput("scales_b", ShaderUsage::UseUniform);
  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

  // Tile/subtile sizes and work distribution are inspired by the Metal shaders in llama.cpp (kernel_mul_mm).
  // https://github.com/ggml-org/llama.cpp/blob/d04e7163c85a847bc61d58c22f2c503596db7aa8/ggml/src/ggml-metal/ggml-metal.metal#L6066
  shader.AdditionalImplementation() << R"ADDNL_FN(
        const tile_cols = 64;
        const tile_rows = 32;
        const tile_k = 32;
        const subtile_cols = 32;
        const subtile_rows = 16;
        const quantization_block_size = 32;
        alias compute_precision = output_element_t;
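        // compute_precision follows the output element type (f16 or f32); see the note in
        // CanApplySubgroupMatrixMatMulNBits below for the accuracy/performance trade-off.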

        var<workgroup> tile_A: array<compute_precision, tile_rows * tile_k>; // 32 x 32 - RxC
        var<workgroup> tile_B: array<compute_precision, tile_cols * tile_k>; // 64 x 32 - RxC
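        // Staging area for the output: each subgroup stores the four 8x8 accumulators of one
        // 8-row block here via subgroupMatrixStore, and storeOutput then copies them to the
        // output tensor with bounds checks.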
        var<workgroup> scratch: array<array<array<compute_precision, 64>, 4>, 4>; // 64 * 4 * 4

        fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) {
            let a_global = tile_base + row;
            if (a_global >= uniforms.M) {
                return;
            }
            // Each call loads 8 columns, starting at col.
            var col = c_idx * 8;
            // 128 threads need to load 32 x 32. 4 threads per row or 8 col per thread.
            for (var col_offset:u32 = 0; col_offset < 8; col_offset++)
            {
                tile_A[row * tile_k + col + col_offset] = compute_precision(input_a[a_global*uniforms.K + k_idx + col + col_offset]);
            }
        }

        fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) {
            let b_global = tile_base + row;
            if (b_global >= uniforms.N) {
                return;
            }
            // Each call loads 16 columns, starting at col.
            var col = c_idx * 16;
            // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread.
            // Stored in column major fashion.
            let b_idx = u32((b_global*uniforms.K + k_idx + col)/8);
            let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]);
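            // The 16 columns handled per call start at a 16-aligned offset (K is a multiple of 32),
            // so they fall within a single 32-wide quantization block and share one scale.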
            for (var step:u32 = 0; step < 2; step++)
            {
                var b_value = input_b[b_idx+step];
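                // Each u32 packs 8 4-bit weights, two per byte with the low nibble first.
                // Unpack both nibbles of every byte, subtract the implicit zero point of 8,
                // apply the block scale, and write the 8 values interleaved (low, high, ...).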
                var b_value_lower = (vec4<compute_precision>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
                var b_value_upper = (vec4<compute_precision>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
                let tile_b_base = row * tile_k + col + step * 8;
                tile_B[tile_b_base] = b_value_lower[0];
                tile_B[tile_b_base + 1] = b_value_upper[0];
                tile_B[tile_b_base + 2] = b_value_lower[1];
                tile_B[tile_b_base + 3] = b_value_upper[1];
                tile_B[tile_b_base + 4] = b_value_lower[2];
                tile_B[tile_b_base + 5] = b_value_upper[2];
                tile_B[tile_b_base + 6] = b_value_lower[3];
                tile_B[tile_b_base + 7] = b_value_upper[3];
            }
        }

        // storeOutput: one invocation copies its two adjacent columns of one row from the four
        // staged 8x8 blocks in scratch to the output tensor (the blocks are 8 columns apart in N).
        // row_limit guards rows that fall beyond M when M is not a multiple of the tile height.
        fn storeOutput(offset:u32, row: u32, col:u32, src_slot:u32, row_limit:i32) {
            if (row_limit > 0 && row < u32(row_limit))
            {
                output[offset + row * uniforms.N + col] = output_element_t(scratch[src_slot][0][row * 8 + col]);
                output[offset + row * uniforms.N + col + 8] = output_element_t(scratch[src_slot][1][row * 8 + col]);
                output[offset + row * uniforms.N + col + 16] = output_element_t(scratch[src_slot][2][row * 8 + col]);
                output[offset + row * uniforms.N + col + 24] = output_element_t(scratch[src_slot][3][row * 8 + col]);
                let col2 = col + 1;
                output[offset + row * uniforms.N + col2] = output_element_t(scratch[src_slot][0][row * 8 + col2]);
                output[offset + row * uniforms.N + col2 + 8] = output_element_t(scratch[src_slot][1][row * 8 + col2]);
                output[offset + row * uniforms.N + col2 + 16] = output_element_t(scratch[src_slot][2][row * 8 + col2]);
                output[offset + row * uniforms.N + col2 + 24] = output_element_t(scratch[src_slot][3][row * 8 + col2]);
            }
        }
    )ADDNL_FN";

  shader.MainFunctionBody() << R"MAIN_FN(
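      // Each workgroup computes one 32 (M) x 64 (N) tile of the output;
      // workgroup_id.y selects the row tile and workgroup_id.x the column tile.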
      let a_global_base = workgroup_id.y * tile_rows;
      let b_global_base = workgroup_id.x * tile_cols;

      let subtile_id = u32(local_idx / sg_size);
      let subtile_idx = u32(subtile_id / 2);
      let subtile_idy = subtile_id % 2;
      let base_A = subtile_idy * subtile_rows;
      let base_B = subtile_idx * subtile_cols;
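      // The 128 threads form subgroups; each subgroup owns a 16x32 (rows x cols) subtile of the
      // 32x64 workgroup tile (with a subgroup size of 32 that is a 2x2 grid of subtiles), and
      // accumulates it as a 2x4 grid of 8x8 subgroup matrices.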

      var matC00: subgroup_matrix_result<compute_precision, 8, 8>;
      var matC01: subgroup_matrix_result<compute_precision, 8, 8>;
      var matC02: subgroup_matrix_result<compute_precision, 8, 8>;
      var matC03: subgroup_matrix_result<compute_precision, 8, 8>;
      var matC10: subgroup_matrix_result<compute_precision, 8, 8>;
      var matC11: subgroup_matrix_result<compute_precision, 8, 8>;
      var matC12: subgroup_matrix_result<compute_precision, 8, 8>;
      var matC13: subgroup_matrix_result<compute_precision, 8, 8>;
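      // March over K in tile_k chunks: cooperatively stage the A and B tiles in workgroup memory,
      // then accumulate 8x8 subgroup matrix products into the matC accumulators.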
      for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) {
        // Load Phase
        loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4);
        loadSHMB(b_global_base, kidx, local_idx/2, local_idx%2);
        workgroupBarrier();

        for (var step: u32 = 0; step < tile_k; step+=8)
        {
          // Load to local memory phase
          let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step;
          // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride
          var matA0: subgroup_matrix_left<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_left<compute_precision, 8, 8>>(&tile_A, matrix_a_offset, false, tile_k);
          var matA1: subgroup_matrix_left<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_left<compute_precision, 8, 8>>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k);

          // tile_B is stored as column major.
          // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32]
          var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step;
          var matB0: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset, true, tile_k);
          var matB1: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k);
          var matB2: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k);
          var matB3: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k);

          // Compute Phase
          // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate
          matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00);
          matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01);
          matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02);
          matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03);

          matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10);
          matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11);
          matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12);
          matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13);
        }

        workgroupBarrier();
      }

      // Write out: each subgroup stages its accumulators in workgroup scratch via
      // subgroupMatrixStore, then storeOutput copies them to the output tensor with bounds checks.
      // Top 8-row block (matC0x) first.
      subgroupMatrixStore(&scratch[subtile_id][0], 0, matC00, false, 8);
      subgroupMatrixStore(&scratch[subtile_id][1], 0, matC01, false, 8);
      subgroupMatrixStore(&scratch[subtile_id][2], 0, matC02, false, 8);
      subgroupMatrixStore(&scratch[subtile_id][3], 0, matC03, false, 8);
      workgroupBarrier();
      let row = u32(sg_id / 4);
      var col = u32(sg_id % 4) * 2;
      var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B;
      var row_limit:i32 = i32(uniforms.M) - i32(a_global_base + base_A);
      storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
      workgroupBarrier();

      // Write out bottom block
      subgroupMatrixStore(&scratch[subtile_id][0], 0, matC10, false, 8);
      subgroupMatrixStore(&scratch[subtile_id][1], 0, matC11, false, 8);
      subgroupMatrixStore(&scratch[subtile_id][2], 0, matC12, false, 8);
      subgroupMatrixStore(&scratch[subtile_id][3], 0, matC13, false, 8);
      workgroupBarrier();
      matrix_c_offset = matrix_c_offset + 8 * uniforms.N;
      row_limit = i32(uniforms.M) - i32(a_global_base + base_A + 8);
      storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
    )MAIN_FN";

  return Status::OK();
}

Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
                                      uint32_t M,
                                      uint32_t N,
                                      uint32_t K,
                                      onnxruntime::webgpu::ComputeContext& context,
                                      Tensor* y) {
  constexpr uint32_t kTileSizeA = 32;
  constexpr uint32_t kTileSizeB = 64;
  constexpr uint32_t kU32Components = 4;
  TensorShape y_shape{1, M, N};
  SubgroupMatrixMatMulNBitsProgram mul_program;
  mul_program.SetWorkgroupSize(128);
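  // One 128-thread workgroup produces one kTileSizeA (32, M) x kTileSizeB (64, N) output tile,
  // matching the tile_rows/tile_cols constants in the shader.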
  mul_program.SetDispatchGroupSize(
      (N + kTileSizeB - 1) / kTileSizeB,
      (M + kTileSizeA - 1) / kTileSizeA, 1);
  mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
                         {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
                         {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
      .AddUniformVariables({{static_cast<uint32_t>(M)},
                            {static_cast<uint32_t>(N)},
                            {static_cast<uint32_t>(K)}})
      .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, gsl::narrow<int>(1)});
  return context.RunProgram(mul_program);
}

bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
                                       uint64_t accuracy_level,
                                       uint32_t block_size,
                                       uint32_t batch_count,
                                       uint32_t N,
                                       uint32_t K,
                                       bool has_zero_points) {
#if !defined(__wasm__)
  const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
#else
  const bool has_subgroup_matrix = false;
#endif
  // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16
  // there are some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support
  // higher accuracy by setting compute_precision to Fp32, but that will be slower. For a 1K-token
  // prefill of Phi 3.5, FP16 takes around 5s and FP32 around 7s.
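  // The shape constraints below mirror the shader tiling: block_size must equal the shader's
  // quantization_block_size (32), K must be a multiple of tile_k (32) and N a multiple of
  // tile_cols (64). M needs no constraint because storeOutput bounds-checks rows against M.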
  return context.AdapterInfo().backendType == wgpu::BackendType::Metal &&
         has_subgroup_matrix &&
         accuracy_level == 4 &&
         block_size == 32 &&
         batch_count == 1 &&
         K % 32 == 0 &&
         N % 64 == 0 &&
         !has_zero_points;
}
}  // namespace webgpu
}  // namespace contrib
}  // namespace onnxruntime