From 70911c5798ffe5b140a85d6aa91ae89300b5da52 Mon Sep 17 00:00:00 2001 From: DiegoCao Date: Mon, 5 Feb 2024 12:58:23 -0500 Subject: [PATCH] Parallel Download, Move ArtifactCache to Interface to support future different cache types, fix README path typo, Support delete and batch delete Co-authored-by: DavidGOrtega --- web/README.md | 2 +- web/package-lock.json | 4 +-- web/src/artifact_cache.ts | 19 +++++++++++++++ web/src/index.ts | 2 +- web/src/runtime.ts | 51 +++++++++++++++++++++++++++++++++------ 5 files changed, 67 insertions(+), 11 deletions(-) create mode 100644 web/src/artifact_cache.ts diff --git a/web/README.md b/web/README.md index 64f507579e94..9b3cda1fb76c 100644 --- a/web/README.md +++ b/web/README.md @@ -94,4 +94,4 @@ Right now we use the SPIRV to generate shaders that can be accepted by Chrome an - Firefox should be close pending the support of Fence. - Download vulkan SDK (1.1 or higher) that supports SPIRV 1.3 - Start the WebSocket RPC -- run `python tests/node/webgpu_rpc_test.py` +- run `python tests/python/webgpu_rpc_test.py` diff --git a/web/package-lock.json b/web/package-lock.json index 37bac4493f81..74561324c90d 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "tvmjs", - "version": "0.15.0-dev0", + "version": "0.16.0-dev0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tvmjs", - "version": "0.15.0-dev0", + "version": "0.16.0-dev0", "license": "Apache-2.0", "devDependencies": { "@rollup/plugin-commonjs": "^20.0.0", diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts new file mode 100644 index 000000000000..394cda83bc43 --- /dev/null +++ b/web/src/artifact_cache.ts @@ -0,0 +1,19 @@ +/* + Common interface through which the runtime talks to an artifact cache (ArtifactCache in runtime.ts implements it). NOTE(review): the methods below declare no return type, so they are implicitly any - consider annotating them. +*/ +export interface ArtifactCacheTemplate { + /** + * Fetch the resource at url through the cache; callers in runtime.ts await the result and call arrayBuffer() or json() on it. + */ + fetchWithCache(url: string); + + /** + * Check whether every key in keys is present in the cache. + */ + hasAllKeys(keys: string[]); + + /** + * Delete the entry stored for url, if one 
exists. + */ + deleteInCache(url: string); +} diff --git a/web/src/index.ts b/web/src/index.ts index fd27fce9fd25..9099d8f37347 100644 --- a/web/src/index.ts +++ b/web/src/index.ts @@ -22,7 +22,7 @@ export { PackedFunc, Module, NDArray, TVMArray, TVMObject, VirtualMachine, InitProgressCallback, InitProgressReport, - ArtifactCache, Instance, instantiate, hasNDArrayInCache + ArtifactCache, Instance, instantiate, hasNDArrayInCache, deleteNDArrayCache } from "./runtime"; export { Disposable, LibraryProvider } from "./types"; export { RPCServer } from "./rpc_server"; diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 4c560052617d..7811d38d599f 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -26,6 +26,7 @@ import { Memory, CachedCallStack } from "./memory"; import { assert, StringToUint8Array } from "./support"; import { Environment } from "./environment"; import { FunctionInfo, WebGPUContext } from "./webgpu"; +import { ArtifactCacheTemplate } from "./artifact_cache"; import * as compact from "./compact"; import * as ctypes from "./ctypes"; @@ -985,7 +986,7 @@ export type InitProgressCallback = (report: InitProgressReport) => void; /** * Cache to store model related data. */ -export class ArtifactCache { +export class ArtifactCache implements ArtifactCacheTemplate { private scope: string; private cache?: Cache; @@ -1018,6 +1019,14 @@ export class ArtifactCache { .then(cacheKeys => keys.every(key => cacheKeys.indexOf(key) !== -1)) .catch(err => false); } + + async deleteInCache(url: string) { + if (this.cache === undefined) { /* open the cache named by this.scope on first use */ + this.cache = await caches.open(this.scope); + } + const result = await this.cache.delete(url); + return result; + } } /** @@ -1451,7 +1460,7 @@ export class Instance implements Disposable { } /** - * Fetch NDArray cache from url. + * Look up the items listed in ndarrayCacheUrl/ndarray-cache.json and fetch them into the NDArray cache. * * @param ndarrayCacheUrl The cache url. * @param device The device to be fetched to. 
@@ -1477,6 +1486,7 @@ export class Instance implements Disposable { this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record) }; } + /** * Fetch list of NDArray into the NDArrayCache. * @@ -1489,7 +1499,7 @@ export class Instance implements Disposable { ndarrayCacheUrl: string, list: Array, device: DLDevice, - artifactCache: ArtifactCache + artifactCache: ArtifactCacheTemplate ) { const perf = compact.getPerformance(); const tstart = perf.now(); @@ -1536,10 +1546,11 @@ export class Instance implements Disposable { }); } - for (let i = 0; i < list.length; ++i) { + const processShard = async (i: number) => { reportCallback(i); - fetchedBytes += list[i].nbytes; - const dataUrl = new URL(list[i].dataPath, ndarrayCacheUrl).href; + const shard = list[i]; + fetchedBytes += shard.nbytes; + const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href; let buffer; try { buffer = await (await artifactCache.fetchWithCache(dataUrl)).arrayBuffer(); @@ -1547,7 +1558,7 @@ export class Instance implements Disposable { this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); throw err; } - const shardRecords = list[i].records; + const shardRecords = shard.records; for (let j = 0; j < shardRecords.length; ++j) { const rec = shardRecords[j]; const cpu_arr = this.withNewScope(() => { @@ -1578,6 +1589,7 @@ export class Instance implements Disposable { } timeElapsed = Math.ceil((perf.now() - tstart) / 1000); } + await Promise.all(list.map((_, index) => processShard(index))); /* NOTE(review): every processShard call starts at once, so reportCallback(i) and the fetchedBytes update run before any shard has been fetched - confirm that progress reporting and the number of concurrent fetches are acceptable */ reportCallback(list.length); } @@ -2432,3 +2444,28 @@ export async function hasNDArrayInCache( list = list["records"] as Array; return await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href)); } + +/** + * Delete from the artifact cache the entry for each record listed in cacheUrl/ndarray-cache.json; the deletions are issued concurrently and awaited together. + * NOTE(review): the manifest is read through fetchWithCache and is not itself passed to deleteInCache; if that fetch does not yield a Response, list stays undefined and the records lookup on it throws. + * @param cacheUrl Base URL against which ndarray-cache.json and each record's dataPath are resolved. + * @param cacheScope Scope handed to the ArtifactCache constructor (default "tvmjs"). + */ +export async function deleteNDArrayCache( + cacheUrl: string, + cacheScope = "tvmjs" +) { + const artifactCache = new 
ArtifactCache(cacheScope); + const jsonUrl = new URL("ndarray-cache.json", cacheUrl).href; + const result = await artifactCache.fetchWithCache(jsonUrl); + let list; + if (result instanceof Response){ + list = await result.json(); + } + const arrayentry = list["records"] as Array; + const processShard = async (i: number) => { + const dataUrl = new URL(arrayentry[i].dataPath, cacheUrl).href; + await artifactCache.deleteInCache(dataUrl); + } + await Promise.all(arrayentry.map((_, index) => processShard(index))); +}