Skip to content

Commit d8a8afe

Browse files
authored
[webgpu] Add support for scatterND (#5643)
FEATURE
1 parent 89363dd commit d8a8afe

File tree

4 files changed

+88
-2
lines changed

4 files changed

+88
-2
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {backend_util, KernelConfig, KernelFunc, ScatterNd, ScatterNdAttrs, ScatterNdInputs, TensorInfo} from '@tensorflow/tfjs-core';
19+
20+
import {WebGPUBackend} from '../backend_webgpu';
21+
22+
import {reshape} from './Reshape';
23+
import {ScatterProgram} from './scatter_webgpu';
24+
25+
export function scatterNd(args: {
26+
inputs: ScatterNdInputs,
27+
backend: WebGPUBackend,
28+
attrs: ScatterNdAttrs
29+
}): TensorInfo {
30+
const {inputs, backend, attrs} = args;
31+
const {indices, updates} = inputs;
32+
const {shape} = attrs;
33+
34+
const {sliceRank, numUpdates, sliceSize, strides, outputSize} =
35+
backend_util.calculateShapes(updates, indices, shape);
36+
37+
const flattenShape = [outputSize / sliceSize, sliceSize];
38+
39+
if (outputSize === 0) {
40+
return backend.makeTensorInfo(shape, indices.dtype);
41+
}
42+
43+
const flattenIndices = reshape(
44+
{inputs: {x: indices}, backend, attrs: {shape: [numUpdates, sliceRank]}});
45+
const flattenX = reshape(
46+
{inputs: {x: updates}, backend, attrs: {shape: [numUpdates, sliceSize]}});
47+
48+
const defaultValue = backend.makeTensorInfo(
49+
[], 'float32', new Float32Array([0])); // scalar(0)
50+
const uniformData = [
51+
{type: 'int32', data: [numUpdates]},
52+
{type: 'int32', data: [sliceRank]},
53+
{type: 'int32', data: strides},
54+
];
55+
const program = new ScatterProgram(
56+
numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length,
57+
strides, flattenShape);
58+
const res = backend.runWebGPUProgram(
59+
program, [flattenX, flattenIndices, defaultValue], flattenX.dtype,
60+
uniformData);
61+
62+
const reshaped = reshape({inputs: {x: res}, backend, attrs: {shape}});
63+
64+
backend.disposeData(flattenIndices.dataId);
65+
backend.disposeData(flattenX.dataId);
66+
backend.disposeData(res.dataId);
67+
backend.disposeData(defaultValue.dataId);
68+
69+
return reshaped;
70+
}
71+
72+
export const scatterNdConfig: KernelConfig = {
73+
kernelName: ScatterNd,
74+
backendName: 'webgpu',
75+
kernelFunc: scatterNd as {} as KernelFunc
76+
};

tfjs-backend-webgpu/src/kernels/scatter_webgpu.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ export class ScatterProgram implements WebGPUProgram {
4242
this.dispatchLayout = flatDispatchLayout(this.outputShape);
4343
this.dispatch = computeDispatch(
4444
this.dispatchLayout, this.outputShape, this.workGroupSize);
45-
this.shaderKey = `scatter_${indicesRank}_${updatesRank}`;
45+
const sliceDimGreaterThanOne = sliceDim > 1;
46+
this.shaderKey =
47+
`scatter_${indicesRank}_${updatesRank}_${sliceDimGreaterThanOne}`;
4648
this.size = util.sizeFromShape(this.outputShape);
4749
const stridesType = getCoordsDataType(strides.length);
4850
this.uniforms =
@@ -64,7 +66,7 @@ export class ScatterProgram implements WebGPUProgram {
6466
this.updatesSnippet = `getUpdates(${updatesString})`;
6567

6668
this.strideString =
67-
sliceDim > 1 ? 'uniforms.strides[j]' : 'uniforms.strides';
69+
sliceDimGreaterThanOne ? 'uniforms.strides[j]' : 'uniforms.strides';
6870
}
6971

7072
getUserCode(): string {

tfjs-backend-webgpu/src/register_all_kernels.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ import {resizeBilinearConfig} from './kernels/ResizeBilinear';
9090
import {resizeNearestNeighborConfig} from './kernels/ResizeNearestNeighbor';
9191
import {rotateWithOffsetConfig} from './kernels/RotateWithOffset';
9292
import {rsqrtConfig} from './kernels/Rsqrt';
93+
import {scatterNdConfig} from './kernels/ScatterNd';
9394
import {selectConfig} from './kernels/Select';
9495
import {sigmoidConfig} from './kernels/Sigmoid';
9596
import {sinConfig} from './kernels/Sin';
@@ -189,6 +190,7 @@ const kernelConfigs: KernelConfig[] = [
189190
resizeNearestNeighborConfig,
190191
rotateWithOffsetConfig,
191192
rsqrtConfig,
193+
scatterNdConfig,
192194
selectConfig,
193195
sigmoidConfig,
194196
sinConfig,

tfjs-backend-webgpu/src/setup_test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,12 @@ const TEST_FILTERS: TestFilter[] = [
584584
'gradient' // gradient function not found.
585585
]
586586
},
587+
{
588+
include: 'scatterND',
589+
excludes: [
590+
'gradient' // gradient function not found.
591+
]
592+
},
587593
{
588594
startsWith: 'logicalAnd ',
589595
},

0 commit comments

Comments
 (0)