-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Batched GPU dispatch and object caching for WebGPU runtime #18871
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
298f42c
291fe23
d43e7d1
2715c33
11c33ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| /** | ||
| * Caching utilities for the TVM web runtime. | ||
| * | ||
| * Provides a generic LRUCache and a CacheState container that manages | ||
| * domain-specific caches used by the WebGPU runtime. | ||
| */ | ||
| import { Disposable } from "./types"; | ||
|
|
||
| /** | ||
| * A generic LRU (Least Recently Used) cache with bounded size. | ||
| * | ||
| * Entries are evicted in insertion order when the cache exceeds `maxSize`. | ||
| * Uses a Map to maintain insertion order for O(1) LRU eviction. | ||
| * | ||
| * @typeParam K - Cache key type. | ||
| * @typeParam V - Cache value type. | ||
| */ | ||
| export class LRUCache<K, V> { | ||
| private cache: Map<K, V> = new Map(); | ||
| private readonly maxSize: number; | ||
| /** Optional callback invoked when an entry is evicted. */ | ||
| private readonly onEvict?: (key: K, value: V) => void; | ||
|
|
||
| constructor(maxSize: number, onEvict?: (key: K, value: V) => void) { | ||
| this.maxSize = maxSize; | ||
| this.onEvict = onEvict; | ||
| } | ||
|
|
||
| /** | ||
| * Get a value from the cache, constructing it via `constructor` on miss. | ||
| * | ||
| * On hit: moves the entry to most-recently-used position and returns it. | ||
| * On miss: calls `constructor()` to create the value, inserts it, and | ||
| * returns it. If the cache is full, the least-recently-used entry is | ||
| * evicted first. | ||
| * | ||
| * @param key The cache key. | ||
| * @param constructor Factory function called on cache miss to produce the value. | ||
| * @returns The cached or newly constructed value. | ||
| */ | ||
| get(key: K, constructor: () => V): V { | ||
| const existing = this.cache.get(key); | ||
| if (existing !== undefined) { | ||
| // Move to most-recently-used position | ||
| this.cache.delete(key); | ||
| this.cache.set(key, existing); | ||
| return existing; | ||
| } | ||
| // Evict LRU entry if at capacity | ||
| if (this.cache.size >= this.maxSize) { | ||
| const oldest = this.cache.keys().next().value; | ||
| if (oldest !== undefined) { | ||
| if (this.onEvict) { | ||
| this.onEvict(oldest, this.cache.get(oldest)!); | ||
| } | ||
| this.cache.delete(oldest); | ||
| } | ||
| } | ||
| const value = constructor(); | ||
| this.cache.set(key, value); | ||
| return value; | ||
| } | ||
|
|
||
| /** | ||
| * Check whether eviction would be needed for a new entry. | ||
| * | ||
| * Useful when the caller needs to perform side effects before eviction | ||
| * (e.g. flushing pending GPU commands before destroying an evicted buffer). | ||
| * | ||
| * @param key The key to check. | ||
| * @returns true if inserting `key` would trigger eviction of another entry. | ||
| */ | ||
| needEviction(key: K): boolean { | ||
| if (this.cache.has(key)) return false; | ||
| return this.cache.size >= this.maxSize; | ||
| } | ||
|
|
||
| /** | ||
| * Clear all cached entries. | ||
| * | ||
| * Does not dispose values — the caller is responsible for cleanup | ||
| * (e.g. destroying GPU buffers) before calling invalidate. | ||
| */ | ||
| invalidate(): void { | ||
| this.cache.clear(); | ||
| } | ||
|
|
||
| /** Number of entries currently in the cache. */ | ||
| get size(): number { | ||
| return this.cache.size; | ||
| } | ||
|
|
||
| /** Iterate over all cached values (for disposal). */ | ||
| values(): IterableIterator<V> { | ||
| return this.cache.values(); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * CacheState manages domain-specific caches for the WebGPU runtime. | ||
| * | ||
| * Currently contains: | ||
| * - **shapeCache**: Caches TVM ShapeTuple objects keyed by dimension string. | ||
| * - Why: `makeShapeTuple()` is called on every tensor operation, crossing | ||
| * the JS→WASM FFI boundary each time. During LLM decode, the same shapes | ||
| * repeat every token (e.g. [1,32,128]), so caching avoids thousands of | ||
| * redundant FFI round-trips. | ||
| * - Invalidation: Never. Shape tuples are immutable value objects that | ||
| * remain valid for the lifetime of the TVM instance. | ||
| * | ||
| * Future additions (follow-up PR): | ||
| * - **uniformCache**: Caches GPU uniform buffers keyed by content hash. | ||
| * - Why: Many dispatches use identical scalar arguments (matrix dims, etc.). | ||
| * Reusing the buffer avoids `createBuffer` + `writeBuffer` overhead. | ||
| * - Invalidation: Must invalidate on any GPU buffer deallocation, since | ||
| * buffer pointers can be reused by the allocator, making cached entries | ||
| * that reference the old buffer stale. | ||
| */ | ||
| export class CacheState { | ||
| /** | ||
| * Cache for TVM ShapeTuple objects. | ||
| * | ||
| * Key: comma-separated dimension string, e.g. "1,32,128" | ||
| * Value: TVM ShapeTuple object (Disposable) | ||
| * | ||
| * Invalidation rule: None required — shape tuples are immutable. | ||
| */ | ||
| readonly shapeCache: LRUCache<string, Disposable>; | ||
|
|
||
| constructor(shapeCacheSize: number = 256) { | ||
| this.shapeCache = new LRUCache<string, Disposable>( | ||
| shapeCacheSize, | ||
| (_key, value) => value.dispose() | ||
| ); | ||
| } | ||
|
|
||
| /** | ||
| * Compute the cache key for a shape tuple. | ||
| * | ||
| * @param shape Array of dimension values. | ||
| * @returns String key suitable for shapeCache lookup. | ||
| */ | ||
| static computeShapeKey(shape: Array<number>): string { | ||
| return shape.toString(); | ||
| } | ||
|
|
||
| /** | ||
| * Dispose all cached objects and clear all caches. | ||
| */ | ||
| dispose(): void { | ||
| for (const obj of this.shapeCache.values()) { | ||
| obj.dispose(); | ||
| } | ||
| this.shapeCache.invalidate(); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,7 @@ import { assert, StringToUint8Array, LinearCongruentialGenerator } from "./suppo | |
| import { Environment } from "./environment"; | ||
| import { AsyncifyHandler } from "./asyncify"; | ||
| import { FunctionInfo, WebGPUContext } from "./webgpu"; | ||
| import { CacheState } from "./cache_state"; | ||
| import { | ||
| ArtifactCache, | ||
| ArtifactCacheTemplate, | ||
|
|
@@ -859,6 +860,7 @@ export class Instance implements Disposable { | |
| private initProgressCallback: Array<InitProgressCallback> = []; | ||
| private rng: LinearCongruentialGenerator; | ||
| private deviceLostIsError = true; // whether device.lost is due to actual error or dispose() | ||
| private cacheState: CacheState = new CacheState(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
||
|
|
||
| /** | ||
| * Internal function(registered by the runtime) | ||
|
|
@@ -954,6 +956,8 @@ export class Instance implements Disposable { | |
| dispose(): void { | ||
| this.deviceLostIsError = false; // prevent dispose to trigger device.lost error | ||
| // order matters | ||
| // dispose caches before ctx | ||
| this.cacheState.dispose(); | ||
| // ctx release goes back into lib. | ||
| this.ctx.dispose(); | ||
| this.lib.dispose(); | ||
|
|
@@ -1674,8 +1678,14 @@ export class Instance implements Disposable { | |
| * @returns The created shape tuple. | ||
| */ | ||
makeShapeTuple(shape: Array<number>): TVMObject {
  const key = CacheState.computeShapeKey(shape);
  // Hit: return the shared, previously constructed tuple. Miss: build it
  // once through the FFI and retain it in the LRU cache for reuse.
  return this.cacheState.shapeCache.get(key, () => {
    const shapeArray = shape.map((value) => new Scalar(value, "int"));
    const tuple = this.ctx.makeShapeTuple(...shapeArray);
    // Detach from scope so the cached object survives across scopes.
    this.detachFromCurrentScope(tuple);
    // NOTE(review): callers now receive a SHARED cached object rather than a
    // fresh per-call one. If any call site disposes the returned tuple (which
    // was harmless before), later cache hits would hand out a freed object —
    // confirm call sites do not take ownership.
    return tuple;
  }) as TVMObject;
}
|
mitiskuma marked this conversation as resolved.
mitiskuma marked this conversation as resolved.
|
||
| /** | ||
| * Get type index from type key. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `computeShapeKey` method uses `shape.toString()` to generate the cache key. While this works for simple number arrays, it is generally safer to use an explicit serialization (e.g., `JSON.stringify(shape)`) to avoid ambiguity if array elements could be non-numeric or could themselves contain commas, which `toString()` would flatten indistinguishably — for example, nested arrays `[[1,2],3]` and `[1,[2,3]]` both serialize to `"1,2,3"`. Given the declared `Array<number>` input, `toString()` (which joins integer dimensions with commas) is unambiguous and likely sufficient, but it is a point to consider for robustness.