Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions web/src/cache_state.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/**
* Caching utilities for the TVM web runtime.
*
* Provides a generic LRUCache and a CacheState container that manages
* domain-specific caches used by the WebGPU runtime.
*/
import { Disposable } from "./types";

/**
* A generic LRU (Least Recently Used) cache with bounded size.
*
* Entries are evicted in insertion order when the cache exceeds `maxSize`.
* Uses a Map to maintain insertion order for O(1) LRU eviction.
*
* @typeParam K - Cache key type.
* @typeParam V - Cache value type.
*/
export class LRUCache<K, V> {
  private cache: Map<K, V> = new Map();
  private readonly maxSize: number;
  /** Optional callback invoked when an entry is evicted. */
  private readonly onEvict?: (key: K, value: V) => void;

  /**
   * @param maxSize Maximum number of entries held before eviction kicks in.
   * @param onEvict Optional callback invoked with each evicted entry,
   *   called before the entry is removed from the map.
   */
  constructor(maxSize: number, onEvict?: (key: K, value: V) => void) {
    this.maxSize = maxSize;
    this.onEvict = onEvict;
  }

  /**
   * Get a value from the cache, constructing it via `constructor` on miss.
   *
   * On hit: moves the entry to most-recently-used position and returns it.
   * On miss: calls `constructor()` to create the value, inserts it, and
   * returns it. If the cache is full, the least-recently-used entry is
   * evicted first.
   *
   * @param key The cache key.
   * @param constructor Factory function called on cache miss to produce the value.
   * @returns The cached or newly constructed value.
   */
  get(key: K, constructor: () => V): V {
    // Use has() rather than comparing get() against undefined so that a
    // value type which legitimately includes `undefined` still registers
    // hits (an undefined-check would re-run the factory on every access
    // and never refresh the entry's LRU position).
    if (this.cache.has(key)) {
      const existing = this.cache.get(key)!;
      // Delete + re-insert moves the entry to most-recently-used position
      // (Map preserves insertion order).
      this.cache.delete(key);
      this.cache.set(key, existing);
      return existing;
    }
    // Evict the LRU entry (first key in Map iteration order) if at capacity.
    if (this.cache.size >= this.maxSize) {
      const head = this.cache.keys().next();
      // Check the iterator's `done` flag instead of comparing the key to
      // undefined, so an `undefined` key can still be evicted.
      if (!head.done) {
        const oldest = head.value;
        if (this.onEvict) {
          this.onEvict(oldest, this.cache.get(oldest)!);
        }
        this.cache.delete(oldest);
      }
    }
    const value = constructor();
    this.cache.set(key, value);
    return value;
  }

  /**
   * Check whether eviction would be needed for a new entry.
   *
   * Useful when the caller needs to perform side effects before eviction
   * (e.g. flushing pending GPU commands before destroying an evicted buffer).
   *
   * @param key The key to check.
   * @returns true if inserting `key` would trigger eviction of another entry.
   */
  needEviction(key: K): boolean {
    if (this.cache.has(key)) return false;
    return this.cache.size >= this.maxSize;
  }

  /**
   * Clear all cached entries.
   *
   * Does not dispose values — the caller is responsible for cleanup
   * (e.g. destroying GPU buffers) before calling invalidate.
   */
  invalidate(): void {
    this.cache.clear();
  }

  /** Number of entries currently in the cache. */
  get size(): number {
    return this.cache.size;
  }

  /** Iterate over all cached values (for disposal). */
  values(): IterableIterator<V> {
    return this.cache.values();
  }
}

/**
* CacheState manages domain-specific caches for the WebGPU runtime.
*
* Currently contains:
* - **shapeCache**: Caches TVM ShapeTuple objects keyed by dimension string.
* - Why: `makeShapeTuple()` is called on every tensor operation, crossing
* the JS→WASM FFI boundary each time. During LLM decode, the same shapes
* repeat every token (e.g. [1,32,128]), so caching avoids thousands of
* redundant FFI round-trips.
* - Invalidation: Never. Shape tuples are immutable value objects that
* remain valid for the lifetime of the TVM instance.
*
* Future additions (follow-up PR):
* - **uniformCache**: Caches GPU uniform buffers keyed by content hash.
* - Why: Many dispatches use identical scalar arguments (matrix dims, etc.).
* Reusing the buffer avoids `createBuffer` + `writeBuffer` overhead.
* - Invalidation: Must invalidate on any GPU buffer deallocation, since
* buffer pointers can be reused by the allocator, making cached entries
* that reference the old buffer stale.
*/
export class CacheState {
/**
* Cache for TVM ShapeTuple objects.
*
* Key: comma-separated dimension string, e.g. "1,32,128"
* Value: TVM ShapeTuple object (Disposable)
*
* Invalidation rule: None required — shape tuples are immutable.
*/
readonly shapeCache: LRUCache<string, Disposable>;

constructor(shapeCacheSize: number = 256) {
this.shapeCache = new LRUCache<string, Disposable>(
shapeCacheSize,
(_key, value) => value.dispose()
);
}

/**
* Compute the cache key for a shape tuple.
*
* @param shape Array of dimension values.
* @returns String key suitable for shapeCache lookup.
*/
static computeShapeKey(shape: Array<number>): string {
return shape.toString();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The computeShapeKey method uses shape.toString() to generate the cache key. While this works for simple number arrays, it's generally safer to use a more explicit serialization method (e.g., JSON.stringify(shape)) to avoid potential collisions or unexpected behavior if array elements could be non-numeric or contain special characters that toString() might handle ambiguously. For example, [1,23] and [12,3] both become "1,23" with toString() if not handled carefully, though for integer shapes this is less likely to be an issue. Given the context of Array<number>, toString() is likely sufficient, but it's a point to consider for robustness.

}

/**
* Dispose all cached objects and clear all caches.
*/
dispose(): void {
for (const obj of this.shapeCache.values()) {
obj.dispose();
}
this.shapeCache.invalidate();
}
}
1 change: 1 addition & 0 deletions web/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ export { Disposable, LibraryProvider } from "./types";
export { RPCServer } from "./rpc_server";
export { assert, wasmPath, LinearCongruentialGenerator } from "./support";
export { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu";
export { LRUCache, CacheState } from "./cache_state";
export { createPolyfillWASI } from "./compact";
14 changes: 12 additions & 2 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { assert, StringToUint8Array, LinearCongruentialGenerator } from "./suppo
import { Environment } from "./environment";
import { AsyncifyHandler } from "./asyncify";
import { FunctionInfo, WebGPUContext } from "./webgpu";
import { CacheState } from "./cache_state";
import {
ArtifactCache,
ArtifactCacheTemplate,
Expand Down Expand Up @@ -859,6 +860,7 @@ export class Instance implements Disposable {
private initProgressCallback: Array<InitProgressCallback> = [];
private rng: LinearCongruentialGenerator;
private deviceLostIsError = true; // whether device.lost is due to actual error or dispose()
private cacheState: CacheState = new CacheState();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The CacheState is initialized with a default shapeCacheSize of 256. While this is a reasonable default, it might be beneficial to make this configurable via the Instance constructor or a setter. This would allow for easier performance tuning in different scenarios, especially if the application has a very large or very small number of unique shape tuples.


/**
* Internal function(registered by the runtime)
Expand Down Expand Up @@ -954,6 +956,8 @@ export class Instance implements Disposable {
dispose(): void {
this.deviceLostIsError = false; // prevent dispose to trigger device.lost error
// order matters
// dispose caches before ctx
this.cacheState.dispose();
// ctx release goes back into lib.
this.ctx.dispose();
this.lib.dispose();
Expand Down Expand Up @@ -1674,8 +1678,14 @@ export class Instance implements Disposable {
* @returns The created shape tuple.
*/
makeShapeTuple(shape: Array<number>): TVMObject {
  // Memoize by dimension string: during autoregressive decode the same
  // shapes repeat every token, so caching avoids a JS->WASM FFI round
  // trip per call.
  const key = CacheState.computeShapeKey(shape);
  return this.cacheState.shapeCache.get(key, () => {
    const shapeArray = shape.map((value) => new Scalar(value, "int"));
    const tuple = this.ctx.makeShapeTuple(...shapeArray);
    // Detach from the current scope so the cached object survives across
    // scopes; the cache disposes it on eviction or instance dispose.
    this.detachFromCurrentScope(tuple);
    return tuple;
  }) as TVMObject;
}
Comment thread
mitiskuma marked this conversation as resolved.
Comment thread
mitiskuma marked this conversation as resolved.
/**
* Get type index from type key.
Expand Down
Loading
Loading