Batched GPU dispatch and object caching for WebGPU runtime #18871
tqchen merged 5 commits into apache:main from
Conversation
Reduce JS↔GPU transition overhead during LLM decode by batching compute dispatches into a single command encoder and caching frequently reused GPU objects.

- Accumulate compute passes in a shared GPUCommandEncoder; flush on sync/readback instead of submitting per dispatch
- Cache uniform buffers (FIFO/512) and bind groups (FIFO/256) to eliminate redundant GPU object creation
- Pool MAP_READ staging buffers for GPU→CPU copies
- Cache shape tuples to avoid repeated FFI calls
- Flush pending commands before CPU→GPU writes to preserve execution order
- Fix padding self-assignment bug in deviceCopyToGPU
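The lazy-submission idea above can be sketched with hypothetical mock types (`MockQueue` and `BatchedDispatcher` are illustrative stand-ins, not the runtime's actual classes): dispatches accumulate passes and a single submit happens at the sync point.

```typescript
// Sketch of lazy (batched) dispatch: compute passes accumulate and are
// submitted in one batch at a sync point. MockQueue and BatchedDispatcher
// are illustrative stand-ins, not the actual TVM WebGPU runtime classes.
interface Pass {
  pipeline: string;
}

class MockQueue {
  submitCount = 0;
  submit(batch: Pass[]): void {
    // A real GPUQueue.submit takes command buffers; here we just count.
    this.submitCount += 1;
  }
}

class BatchedDispatcher {
  private pending: Pass[] = [];
  private queue: MockQueue;

  constructor(queue: MockQueue) {
    this.queue = queue;
  }

  // Record a compute pass instead of submitting per dispatch.
  dispatch(pipeline: string): void {
    this.pending.push({ pipeline });
  }

  // Flush all pending passes in a single submit; called on sync/readback
  // and before CPU->GPU writes to preserve execution order.
  flushCommands(): number {
    if (this.pending.length === 0) return 0;
    const n = this.pending.length;
    this.queue.submit(this.pending);
    this.pending = [];
    return n;
  }
}

const queue = new MockQueue();
const dispatcher = new BatchedDispatcher(queue);
for (let i = 0; i < 100; ++i) dispatcher.dispatch("decode_step");
const flushed = dispatcher.flushCommands();
```

With eager submission the loop above would cost 100 submits; batched, it costs one at the flush point.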
Summary of Changes (Gemini Code Assist): This pull request introduces significant performance optimizations for the WebGPU runtime, primarily aimed at improving the efficiency of LLM decode operations. Batched GPU command submission and caching of frequently reused objects (uniform buffers, bind groups, and shape tuples) substantially reduce the overhead of frequent GPU interactions and object creation. A staging-buffer pooling mechanism further streamlines data transfer, and a padding bug in GPU memory operations has been fixed, leading to a more robust and performant WebGPU execution environment.
Code Review
This pull request introduces significant performance improvements to the WebGPU runtime by implementing batched command dispatch and various caching mechanisms for GPU objects. However, the caching mechanisms introduce a potential cache collision vulnerability in the bind group cache that could lead to unauthorized GPU memory access, and a memory leak in the shape tuple cache which could cause Denial of Service through memory exhaustion. The changes are well-structured and include crucial correctness fixes, but it is recommended to use a more robust key generation strategy for the bind group cache and implement an eviction policy for the shape tuple cache.
```typescript
let bgCacheKey = pipelineId;
for (let i = 0; i < bufferArgIndices.length; ++i) {
  bgCacheKey += ":" + args[bufferArgIndices[i]];
}
bgCacheKey += ":" + uniformKey;
```
The bgCacheKey used for caching WebGPU bind groups is constructed by concatenating the pipelineId, buffer pointers, and the uniform key using a colon (:) as a separator. However, the pipelineId itself is constructed from finfo.name (the shader name) and the number of buffer arguments, also using a colon. If a shader name contains a colon, it is possible to craft two different shader/argument combinations that result in the same bgCacheKey.
For example, a collision can occur between:

- Shader A: name `"a:2"`, 1 buffer argument (pointer `10`), uniform key `"100"` -> bgCacheKey = `"a:2:1:10:100"`
- Shader B: name `"a"`, 2 buffer arguments (pointers `1` and `10`), uniform key `"100"` -> bgCacheKey = `"a:2:1:10:100"`
When a collision occurs, a shader will use the bind group (and thus the GPU buffers) of a different shader. This can lead to unauthorized reading or writing of GPU memory, potentially leaking sensitive data or causing memory corruption on the GPU.
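The collision can be reproduced with a stand-alone sketch of the key construction (`makeBgCacheKey` is a hypothetical helper mirroring the snippet above, not the runtime's actual function; plain numbers stand in for buffer pointers):

```typescript
// Hypothetical reconstruction of the cache-key logic described in the
// review comment; numbers stand in for GPU buffer pointers.
function makeBgCacheKey(
  shaderName: string,
  numBufferArgs: number,
  bufferPtrs: number[],
  uniformKey: string
): string {
  // pipelineId = shader name + ":" + number of buffer args
  let key = shaderName + ":" + numBufferArgs;
  for (const ptr of bufferPtrs) {
    key += ":" + ptr;
  }
  return key + ":" + uniformKey;
}

// Shader A: name "a:2", one buffer (ptr 10), uniform key "100"
const keyA = makeBgCacheKey("a:2", 1, [10], "100");
// Shader B: name "a", two buffers (ptrs 1 and 10), uniform key "100"
const keyB = makeBgCacheKey("a", 2, [1, 10], "100");
// Both produce "a:2:1:10:100": distinct dispatches, same cache entry.
```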
This is largely theoretical nonsense. TVM shader names are generated by the compiler and don't contain colons. Even if they did, an attacker would need to control shader names at compile time, which means they already have full code execution. This isn't a real security issue; it's a cache-key collision edge case that can't happen in practice with TVM-generated code.
```diff
  const key = shape.toString();
  const cached = this.shapeTupleCache.get(key);
  if (cached !== undefined) {
    return cached;
  }
  const shapeArray = shape.map((value) => new Scalar(value, "int"));
- return this.ctx.makeShapeTuple(...shapeArray);
+ const tuple = this.ctx.makeShapeTuple(...shapeArray);
+ // Detach from scope so the cached object survives across scopes.
+ this.detachFromCurrentScope(tuple);
+ this.shapeTupleCache.set(key, tuple);
+ return tuple;
```
The shapeTupleCache caches TVMObject (ShapeTuple) instances without an eviction policy or size limit. This can lead to unbounded memory usage and memory exhaustion, potentially causing a Denial of Service by crashing the browser tab. It is recommended to implement a simple FIFO eviction policy with a reasonable limit (e.g., 512) to prevent this.
Suggested change:

```typescript
const key = shape.toString();
const cached = this.shapeTupleCache.get(key);
if (cached !== undefined) {
  return cached;
}
const shapeArray = shape.map((value) => new Scalar(value, "int"));
const tuple = this.ctx.makeShapeTuple(...shapeArray);
// Detach from scope so the cached object survives across scopes.
this.detachFromCurrentScope(tuple);
if (this.shapeTupleCache.size >= 512) { // Using 512 as a reasonable limit.
  const oldestKey = this.shapeTupleCache.keys().next().value;
  if (oldestKey !== undefined) {
    this.shapeTupleCache.get(oldestKey)?.dispose();
    this.shapeTupleCache.delete(oldestKey);
  }
}
this.shapeTupleCache.set(key, tuple);
return tuple;
```
The number of unique shapes in an LLM model is small and bounded (typically a few dozen at most), so this cache won't grow unboundedly in practice. Adding FIFO eviction here would be over-engineering for a problem that doesn't occur in real workloads.
Thanks for the PR. Given that the GPU is async, I think it is relatively OK and desirable to keep the runtime simple and not include overly clever caching in this case. That being said, lazy submission could make sense, although it would be good to see performance measurements, since eager submission also has its own merit: host compute can run in parallel with GPU compute.
@tqchen thanks for the review.
Regarding keeping the runtime simple: the caching is fully contained within the WebGPU layer and doesn't affect the API surface. The caches are bounded (FIFO eviction at 512 uniform buffers and 256 bind groups), so there's no unbounded memory growth.
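A bounded FIFO cache like the ones described can be sketched on top of a plain `Map`, which iterates keys in insertion order (a generic sketch, not the PR's actual implementation):

```typescript
// Generic sketch of a FIFO-bounded cache on top of Map; JS Maps iterate
// keys in insertion order, so the first key is always the oldest entry.
// Not the PR's actual implementation.
class FifoCache<K, V> {
  private map = new Map<K, V>();
  private maxSize: number;

  constructor(maxSize: number) {
    this.maxSize = maxSize;
  }

  get(key: K): V | undefined {
    return this.map.get(key);
  }

  set(key: K, value: V): void {
    if (!this.map.has(key) && this.map.size >= this.maxSize) {
      // Evict the oldest entry (first key in insertion order).
      const oldest = this.map.keys().next().value as K;
      this.map.delete(oldest);
    }
    this.map.set(key, value);
  }

  get size(): number {
    return this.map.size;
  }
}

const cache = new FifoCache<string, number>(2);
cache.set("a", 1);
cache.set("b", 2);
cache.set("c", 3); // evicts "a", the oldest entry
```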
Thanks for the note; it seems keeping it lazy is helpful here! In that case I think converting to lazy by default makes sense. The lazy launch and staging buffer changes are reasonably less controversial, and we could land them if we find benefit. Caching, and the strategy of caching, is something interesting that would be good to disentangle a bit. Do we know whether doing only lazy submission, without caching, would have a similar impact? It would also be interesting to learn which model was used, since 142 t/s sounds like a reasonably small model, just to understand the state.
I think we'd lose 10-15% of the performance without caching; I can get benchmarks if you think it's worth exploring. Meanwhile, here is a comparison of a few models vs. base:
Thanks @mitiskuma. Mainly from the design point of view:

From the operational point of view, splitting the two into two PRs would likely make sense. We can focus on landing A0 soon, then work on A1 if we find it useful.
@tqchen thank you, let me know how to proceed: whether you prefer handling it on your end, or whether you want me to remove the caching and keep lazy submission.
@tqchen OK, after more benchmarking we lose about 50% without caching. The main issue is that with lazy submission, each dispatch in a batch needs its own uniform buffer and bind group. We can't reuse a shared one because queue.writeBuffer executes immediately while the compute passes are deferred in the encoder.
I think starting with a uniform pool + lazy submission seems to be a good first step, as that strictly improves over what we had. My understanding is that we can always flush the queue when we run out of the uniform pool. If you can send a PR that starts with uniform pool + lazy, that would be great. I get how the uniform buffer cache helps reduce things further, and would be happy to explore that in a follow-up PR. Maybe one idea is to introduce a CacheState class that provides the relevant functionality (with respect to cache key computation, with comments on correctness and invalidation scenarios). Thanks @mitiskuma for looking into this. I think the two-phase approach could give us more confidence here and also let us think a bit more about making CacheState robust.
…fer and bind group FIFO caches with a simple per-dispatch uniform buffer pool. Each dispatch in a batched encoder gets a unique buffer indexed by its position; buffers grow as needed and are reused across flushes. Remove dead getPodArgsBuffer and add pool cleanup in dispose().
I agree with you. Let me know if you want to discuss CacheState elsewhere; I'm open to it.
Thanks @mitiskuma, I did a round of review through the current code, mainly on readability and clarity. I think we should be able to land it and then use it as the basis for the next one.
/gemini review
Code Review
This pull request introduces significant performance optimizations for the WebGPU runtime by batching command dispatches and caching various GPU objects, including pooling uniform and staging buffers, and caching shape tuples. While these changes aim to improve performance, particularly for workloads like LLM decoding, several security and stability issues were identified. The shapeTupleCache lacks a size limit and does not handle disposed objects, which could lead to memory exhaustion or runtime failures. Furthermore, the asynchronous readback mechanism in deviceCopyFromGPU has a data race condition and flawed promise chaining, potentially causing incorrect computation results or synchronization issues. General suggestions for improvement include addressing the unbounded shapeTupleCache and a potential missed optimization for caching bind groups. A positive correction was noted for a bug fix in buffer padding logic.
…per disposal. Rename pendingRead to pendingGPUToCPUCopy with sync semantics clarification. Extract getUniformFromPool subfunction. Rename staging buffer methods per review. Integrate CacheState into runtime.ts for shape tuple caching. Update flushCommands doc to list required call sites.
/gemini review
Code Review
This pull request introduces significant performance improvements and addresses a critical bug in the WebGPU runtime, specifically through batched GPU dispatching, object caching for ShapeTuple and uniform buffers, and a fix for a padding self-assignment bug in deviceCopyToGPU. A security audit was conducted focusing on potential memory exhaustion from unbounded caches, buffer overflows during memory copies, and insecure handling of external inputs. The audit confirmed that the LRU cache for shape tuples is properly bounded, buffer pools are managed safely, and memory operations between JS and WASM/GPU utilize built-in bounds checking. No security vulnerabilities were identified.
```typescript
 * @returns String key suitable for shapeCache lookup.
 */
static computeShapeKey(shape: Array<number>): string {
  return shape.toString();
```
The computeShapeKey method uses shape.toString() to generate the cache key. For arrays of plain numbers this is unambiguous: toString() joins elements with commas, so [1,23] becomes "1,23" while [12,3] becomes "12,3". Collisions only become possible if elements could be non-numeric and contain commas themselves (for example, ["1,2","3"] and ["1","2,3"] both stringify to "1,2,3"), in which case a more explicit serialization such as JSON.stringify(shape) would be safer. Given the Array&lt;number&gt; signature, toString() is likely sufficient, but it's a point to consider for robustness.
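A quick stand-alone check of the toString versus JSON.stringify distinction for cache keys, using only built-in Array and JSON behavior:

```typescript
// Array.prototype.toString joins elements with commas. For number arrays
// the result is unambiguous; for string arrays containing commas the
// element boundaries are lost, while JSON.stringify preserves them.
const numKeyA = [1, 23].toString();  // "1,23"
const numKeyB = [12, 3].toString();  // "12,3" (distinct: no collision)

const strKeyA = ["1,2", "3"].toString(); // "1,2,3"
const strKeyB = ["1", "2,3"].toString(); // "1,2,3" (collision!)

const jsonKeyA = JSON.stringify(["1,2", "3"]); // '["1,2","3"]'
const jsonKeyB = JSON.stringify(["1", "2,3"]); // '["1","2,3"]' (distinct)
```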
```typescript
private initProgressCallback: Array<InitProgressCallback> = [];
private rng: LinearCongruentialGenerator;
private deviceLostIsError = true; // whether device.lost is due to actual error or dispose()
private cacheState: CacheState = new CacheState();
```
The CacheState is initialized with a default shapeCacheSize of 256. While this is a reasonable default, it might be beneficial to make this configurable via the Instance constructor or a setter. This would allow for easier performance tuning in different scenarios, especially if the application has a very large or very small number of unique shape tuples.
Please fix lint; we can aim to land it. It would be good to cross-check e2e perf and correctness of this PR. cc @akaashrp, thanks @mitiskuma