Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 8dd16e0

Browse files
authored
added support for GatherNd op (#1298)
* added support for gather_nd op * addressed the review comments * make sure empty tensor has the same type as input * fixed lint error * address review comments * fix for review comments * fixed the index boundary check
1 parent e48803c commit 8dd16e0

File tree

8 files changed

+400
-0
lines changed

8 files changed

+400
-0
lines changed

src/kernels/backend.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ export interface KernelBackend extends TensorStorage, BackendTimer {
230230

231231
gather<T extends Tensor>(x: T, indices: Tensor1D, axis: number): T;
232232

233+
gatherND(x: Tensor, indices: Tensor): Tensor;
234+
233235
scatterND<R extends Rank>(
234236
indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor<R>;
235237

src/kernels/backend_cpu.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import * as broadcast_util from '../ops/broadcast_util';
2525
import * as concat_util from '../ops/concat_util';
2626
import {Conv2DInfo} from '../ops/conv_util';
2727
import * as erf_util from '../ops/erf_util';
28+
import * as gather_nd_util from '../ops/gather_nd_util';
2829
import * as ops from '../ops/ops';
2930
import {buffer, tensor, tensor3d, tensor4d} from '../ops/ops';
3031
import * as scatter_nd_util from '../ops/scatter_nd_util';
@@ -2848,6 +2849,40 @@ export class MathBackendCPU implements KernelBackend {
28482849
return output.toTensor();
28492850
}
28502851

2852+
gatherND(x: Tensor, indices: Tensor): Tensor<Rank> {
2853+
const indicesShape = indices.shape;
2854+
const sliceRank = indicesShape[indicesShape.length - 1];
2855+
2856+
const [resultShape, numSlices, sliceSize, strides] =
2857+
gather_nd_util.prepareAndValidate(x, indices);
2858+
if (numSlices === 0) {
2859+
return tensor([], resultShape, x.dtype);
2860+
}
2861+
2862+
const buffer = new TensorBuffer([numSlices, sliceSize], x.dtype);
2863+
const indicesData = indices.dataSync();
2864+
const xData = x.dataSync();
2865+
2866+
for (let i = 0; i < numSlices; i++) {
2867+
const index = [];
2868+
let flattenIndex = 0;
2869+
for (let j = 0; j < sliceRank; j++) {
2870+
const dim = indicesData[i * sliceRank + j];
2871+
flattenIndex += dim * strides[j];
2872+
index.push(dim);
2873+
}
2874+
if (flattenIndex < 0 || flattenIndex >= x.size / sliceSize) {
2875+
throw new Error(
2876+
`Invalid indices: ${index} does not index into ${x.shape}`);
2877+
}
2878+
2879+
for (let k = 0; k < sliceSize; k++) {
2880+
buffer.values[i * sliceSize + k] = xData[flattenIndex * sliceSize + k];
2881+
}
2882+
}
2883+
return buffer.toTensor().reshape(resultShape);
2884+
}
2885+
28512886
scatterND<R extends Rank>(
28522887
indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor<R> {
28532888
const [sliceRank, numUpdates, sliceSize, strides, outputSize] =

src/kernels/backend_webgl.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import * as array_ops_util from '../ops/array_ops_util';
2323
import * as axis_util from '../ops/axis_util';
2424
import {computeOutShape} from '../ops/concat_util';
2525
import {Conv2DInfo} from '../ops/conv_util';
26+
import * as gather_nd_util from '../ops/gather_nd_util';
2627
import * as reduce_util from '../ops/reduce_util';
2728
import * as scatter_nd_util from '../ops/scatter_nd_util';
2829
import * as segment_util from '../ops/segment_util';
@@ -61,6 +62,7 @@ import * as fft_gpu from './webgl/fft_gpu';
6162
import {FFTProgram} from './webgl/fft_gpu';
6263
import {FromPixelsProgram} from './webgl/from_pixels_gpu';
6364
import {GatherProgram} from './webgl/gather_gpu';
65+
import {GatherNDProgram} from './webgl/gather_nd_gpu';
6466
import {GPGPUContext} from './webgl/gpgpu_context';
6567
import * as gpgpu_math from './webgl/gpgpu_math';
6668
import {GPGPUBinary, GPGPUProgram, TensorData} from './webgl/gpgpu_math';
@@ -1623,6 +1625,21 @@ export class MathBackendWebGL implements KernelBackend {
16231625
return complex;
16241626
}
16251627

1628+
gatherND(x: Tensor, indices: Tensor): Tensor<Rank> {
1629+
const indicesShape = indices.shape;
1630+
const sliceRank = indicesShape[indicesShape.length - 1];
1631+
1632+
const [resultShape, numSlices, sliceSize, strides] =
1633+
gather_nd_util.prepareAndValidate(x, indices);
1634+
1635+
const flattenIndices = indices.reshape([numSlices, sliceRank]);
1636+
const flattenX = x.reshape([x.size / sliceSize, sliceSize]);
1637+
const program =
1638+
new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize]);
1639+
return (this.compileAndRun(program, [flattenX, flattenIndices]) as Tensor)
1640+
.reshape(resultShape);
1641+
}
1642+
16261643
private makeOutputArray<T extends Tensor>(shape: number[], dtype: DataType):
16271644
T {
16281645
return Tensor.make(shape, {}, dtype) as T;

src/kernels/webgl/gather_nd_gpu.ts

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
import {GPGPUProgram} from './gpgpu_math';
18+
import {getCoordsDataType} from './shader_compiler';
19+
20+
export class GatherNDProgram implements GPGPUProgram {
21+
variableNames = ['x', 'indices'];
22+
outputShape: number[];
23+
userCode: string;
24+
constructor(
25+
private sliceDim: number, private strides: number[], shape: number[]) {
26+
this.outputShape = shape;
27+
const stridesType = getCoordsDataType(strides.length);
28+
const dtype = getCoordsDataType(shape.length);
29+
const strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides';
30+
this.userCode = `
31+
${stridesType} strides = ${stridesType}(${this.strides});
32+
void main() {
33+
${dtype} coords = getOutputCoords();
34+
int flattenIndex = 0;
35+
for (int j = 0; j < ${this.sliceDim}; j++) {
36+
int index = round(getIndices(coords[0], j));
37+
flattenIndex += index * ${strideString};
38+
}
39+
setOutput(getX(flattenIndex, coords[1]));
40+
}
41+
`;
42+
}
43+
}

src/ops/gather_nd.ts

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
import {ENV} from '../environment';
18+
import {Tensor} from '../tensor';
19+
import {convertToTensor} from '../tensor_util_env';
20+
import {Rank, TensorLike} from '../types';
21+
import {op} from './operation';
22+
23+
/**
24+
* Gather slices from input tensor into a Tensor with shape specified by
25+
* `indices`.
26+
*
27+
* `indices` is an K-dimensional integer tensor, best thought of as a
28+
* (K-1)-dimensional tensor of indices into input, where each element defines a
29+
* slice of input:
30+
* output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]]
31+
*
32+
* Whereas in `gather` `indices` defines slices into the first dimension of
33+
* input, in `gatherND`, `indices` defines slices into the first N dimensions
34+
* of input, where N = indices.shape[-1].
35+
*
36+
* The last dimension of indices can be at most the rank of input:
37+
* indices.shape[-1] <= input.rank
38+
*
39+
* The last dimension of `indices` corresponds to elements
40+
* (if indices.shape[-1] == input.rank) or slices
41+
* (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of
42+
* input.
43+
* The output tensor has shape
44+
* indices.shape[:-1] + input.shape[indices.shape[-1]:]
45+
*
46+
* Note that on CPU, if an out of bound index is found, an error is returned. On
47+
* GPU, if an out of bound index is found, a 0 is stored in the corresponding
48+
* output value.
49+
*
50+
* ```js
51+
* const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');
52+
* const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);
53+
* tf.gatherND(input, indices).print() //[10, 11]
54+
* ```
55+
*
56+
* @param x The tensor from which to gather values.
57+
* @param indices Index tensor, must be of type int32.
58+
*/
59+
/** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */
60+
function gatherND_(
61+
x: Tensor|TensorLike, indices: Tensor|TensorLike): Tensor<Rank> {
62+
const $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32');
63+
const $x = convertToTensor(x, 'x', 'gatherND');
64+
return ENV.engine.runKernel(
65+
backend => backend.gatherND($x, $indices), {$x, $indices}) as
66+
Tensor<Rank>;
67+
}
68+
export const gatherND = op({gatherND_});

src/ops/gather_nd_test.ts

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {describeWithFlags} from '../jasmine_util';
19+
import {ALL_ENVS, CPU_ENVS, expectArraysClose} from '../test_util';
20+
21+
import {gatherND} from './gather_nd';
22+
import {scalar, tensor1d, tensor2d, tensor3d} from './tensor_ops';
23+
24+
describeWithFlags('gatherND', ALL_ENVS, () => {
25+
it('should work for simple slice', () => {
26+
const indices = tensor2d([0, 4, 8], [3, 1], 'int32');
27+
const input =
28+
tensor1d([100, 101, 102, 777, 778, 779, 1000, 1001, 1002], 'int32');
29+
const shape = [3];
30+
const result = gatherND(input, indices);
31+
expect(result.shape).toEqual(shape);
32+
expect(result.dtype).toEqual(input.dtype);
33+
expectArraysClose(result, [100, 778, 1002]);
34+
});
35+
36+
it('should work for indexing 2d', () => {
37+
const indices = tensor2d([0, 2], [2, 1], 'int32');
38+
const input = tensor2d(
39+
[
40+
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
41+
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8
42+
],
43+
[8, 4], 'float32');
44+
const shape = [2, 4];
45+
const result = gatherND(input, indices);
46+
expect(result.shape).toEqual(shape);
47+
expect(result.dtype).toEqual(input.dtype);
48+
expectArraysClose(result, [5, 5, 5, 5, 7, 7, 7, 7]);
49+
});
50+
51+
it('should work for indexing 3d', () => {
52+
const indices = tensor2d([0, 2, 1, 1], [2, 2], 'int32');
53+
const input = tensor3d(
54+
[
55+
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
56+
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8
57+
],
58+
[2, 4, 4], 'float32');
59+
const shape = [2, 4];
60+
const result = gatherND(input, indices);
61+
expect(result.shape).toEqual(shape);
62+
expect(result.dtype).toEqual(input.dtype);
63+
expectArraysClose(result, [7, 7, 7, 7, 6, 6, 6, 6]);
64+
});
65+
66+
it('should work for batch slice', () => {
67+
const indices = tensor3d([0, 4, 2], [3, 1, 1], 'int32');
68+
const input =
69+
tensor1d([100, 101, 102, 777, 778, 779, 10000, 10001, 10002], 'int32');
70+
const shape = [3, 1];
71+
const result = gatherND(input, indices);
72+
expect(result.shape).toEqual(shape);
73+
expect(result.dtype).toEqual(input.dtype);
74+
expectArraysClose(result, [100, 778, 102]);
75+
});
76+
77+
it('should work for batch indexing 2d', () => {
78+
const indices = tensor3d([0, 2], [2, 1, 1], 'int32');
79+
const input = tensor2d(
80+
[
81+
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
82+
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8
83+
],
84+
[8, 4], 'float32');
85+
const shape = [2, 1, 4];
86+
const result = gatherND(input, indices);
87+
expect(result.shape).toEqual(shape);
88+
expect(result.dtype).toEqual(input.dtype);
89+
expectArraysClose(result, [5, 5, 5, 5, 7, 7, 7, 7]);
90+
});
91+
92+
it('should work for batch indexing 3d', () => {
93+
const indices = tensor3d([0, 2, 1, 1], [2, 1, 2], 'int32');
94+
const input = tensor3d(
95+
[
96+
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
97+
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8
98+
],
99+
[2, 4, 4], 'float32');
100+
const shape = [2, 1, 4];
101+
const result = gatherND(input, indices);
102+
expect(result.shape).toEqual(shape);
103+
expect(result.dtype).toEqual(input.dtype);
104+
expectArraysClose(result, [7, 7, 7, 7, 6, 6, 6, 6]);
105+
});
106+
107+
it('should work for TensorLike inputs', () => {
108+
const indices = [[0], [4], [8]];
109+
const input = [100, 101, 102, 777, 778, 779, 1000, 1001, 1002];
110+
const shape = [3];
111+
const result = gatherND(input, indices);
112+
expect(result.shape).toEqual(shape);
113+
expect(result.dtype).toEqual('float32');
114+
expectArraysClose(result, [100, 778, 1002]);
115+
});
116+
117+
it('should throw error when indices are not int32', () => {
118+
const indices = tensor1d([1], 'float32');
119+
const input = tensor2d(
120+
[100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004],
121+
[3, 4], 'float32');
122+
expect(() => gatherND(input, indices)).toThrow();
123+
});
124+
it('should throw error when indices are scalar', () => {
125+
const indices = scalar(1, 'int32');
126+
const input = tensor2d(
127+
[100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004],
128+
[3, 4], 'float32');
129+
expect(() => gatherND(input, indices)).toThrow();
130+
});
131+
it('should throw error when x is scalar', () => {
132+
const indices = tensor2d([0, 4, 2], [3, 1], 'int32');
133+
const input = scalar(1.0, 'float32');
134+
expect(() => gatherND(input, indices)).toThrow();
135+
});
136+
it('should throw error when indices inner dim > x shape length', () => {
137+
const indices = tensor2d([0, 4, 2], [1, 3], 'int32');
138+
const input =
139+
tensor2d([100, 101, 102, 10000, 10001, 10002], [3, 2], 'float32');
140+
expect(() => gatherND(input, indices)).toThrow();
141+
});
142+
});
143+
describeWithFlags('gatherND CPU', CPU_ENVS, () => {
144+
it('should throw error when index out of range', () => {
145+
const indices = tensor2d([0, 2, 99], [3, 1], 'int32');
146+
const input = tensor2d(
147+
[100, 101, 102, 777, 778, 779, 10000, 10001, 10002], [3, 3], 'float32');
148+
expect(() => gatherND(input, indices)).toThrow();
149+
});
150+
});

0 commit comments

Comments
 (0)