Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 13601f4

Browse files
authored
added support for prod reduce op (#1279)
* added support for the prod reduce op * addressed the review comments, and removed the gradient function for the prod op, since it requires the cumprod op
1 parent ebcb598 commit 13601f4

File tree

7 files changed

+204
-3
lines changed

7 files changed

+204
-3
lines changed

src/kernels/backend.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ export interface KernelBackend extends TensorStorage, BackendTimer {
108108
floorDiv(a: Tensor, b: Tensor): Tensor;
109109

110110
sum(x: Tensor, axes: number[]): Tensor;
111+
prod(x: Tensor, axes: number[]): Tensor;
111112

112113
unsortedSegmentSum<T extends Tensor>(
113114
x: T, segmentIds: Tensor1D, numSegments: number): Tensor;

src/kernels/backend_cpu.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,28 @@ export class MathBackendCPU implements KernelBackend {
489489
return result;
490490
}
491491

492+
/**
 * Computes the product of `x` along the given axes on the CPU backend.
 *
 * Assumes `axes` are the innermost dimensions (the op layer permutes the
 * tensor first), so each output element is the product of a contiguous run
 * of `reduceSize` input values.
 *
 * @param x The input tensor. Must not be complex.
 * @param axes The (innermost) axes to reduce over.
 * @returns A tensor of shape `outShape` with dtype `upcastType(x.dtype,
 *     'int32')`, so bool inputs produce an int32 result.
 */
prod(x: Tensor, axes: number[]): Tensor {
  // Fix: the op name reported on complex input was 'sum'; it must name
  // this kernel so the error message points at the right op.
  this.assertNotComplex(x, 'prod');

  const [outShape, reduceShape] =
      axis_util.computeOutAndReduceShapes(x.shape, axes);
  const resultDtype = upcastType(x.dtype, 'int32');
  const result = ops.zeros(outShape, resultDtype);
  const reduceSize = util.sizeFromShape(reduceShape);
  const vals = result.dataSync();

  const aVals = x.dataSync();
  for (let i = 0; i < vals.length; ++i) {
    // Each output element reduces a contiguous window of the flat input.
    const offset = i * reduceSize;
    let prod = 1;
    for (let j = 0; j < reduceSize; ++j) {
      prod *= aVals[offset + j];
    }
    vals[i] = prod;
  }
  return result;
}
513+
492514
unsortedSegmentSum<T extends Tensor>(
493515
x: T, segmentIds: Tensor1D, numSegments: number): Tensor {
494516
this.assertNotComplex(x, 'unsortedSegmentSum');

src/kernels/backend_webgl.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ export class MathBackendWebGL implements KernelBackend {
769769
}
770770

771771
private reduce(
772-
x: Tensor2D, reduceType: 'all'|'any'|'max'|'min'|'sum',
772+
x: Tensor2D, reduceType: 'all'|'any'|'max'|'min'|'sum'|'prod',
773773
dtype: DataType): Tensor2D {
774774
const batchSize = x.shape[0];
775775
const inSize = x.shape[1];
@@ -824,6 +824,15 @@ export class MathBackendWebGL implements KernelBackend {
824824
return this.reduce(a2D, 'sum', outputDType).reshape(outShape);
825825
}
826826

827+
/**
 * Computes the product of `x` over `axes` on the GPU by collapsing the
 * reduced dimensions into the inner axis of a 2D view and running the
 * shared reduce program with the 'prod' reduce type.
 *
 * @param x The input tensor.
 * @param axes The (innermost) axes to reduce over.
 * @returns A tensor of the out shape, dtype widened via `sumOutType`.
 */
prod(x: Tensor, axes: number[]): Tensor {
  const shapes = axis_util.computeOutAndReduceShapes(x.shape, axes);
  const outShape = shapes[0];
  const reducedSize = util.sizeFromShape(shapes[1]);
  // View the input as [batch, reducedSize]: one row per output element.
  const reshaped = x.as2D(-1, reducedSize);
  const dtype = sumOutType(x.dtype);
  return this.reduce(reshaped, 'prod', dtype).reshape(outShape);
}
835+
827836
unsortedSegmentSum<T extends Tensor>(
828837
x: T, segmentIds: Tensor1D, numSegments: number): Tensor {
829838
let axis = 0;

src/kernels/webgl/reduce_gpu.ts

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ export class ReduceProgram implements GPGPUProgram {
2424
userCode: string;
2525

2626
constructor(
27-
reduceInfo: ReduceInfo, reduceType: 'all'|'any'|'max'|'min'|'sum') {
27+
reduceInfo: ReduceInfo,
28+
reduceType: 'all'|'any'|'max'|'min'|'sum'|'prod') {
2829
const windowSize = reduceInfo.windowSize;
2930
const batchSize = reduceInfo.batchSize;
3031
const inSize = reduceInfo.inSize;
@@ -34,7 +35,9 @@ export class ReduceProgram implements GPGPUProgram {
3435
let initializationValue = '0.0';
3536
let compareOp = ``;
3637

37-
if (reduceType === 'min') {
38+
if (reduceType === 'prod') {
39+
initializationValue = '1.0';
40+
} else if (reduceType === 'min') {
3841
initializationValue = '1.0 / 0.0';
3942
compareOp = `min`;
4043
} else if (reduceType === 'max') {
@@ -47,6 +50,8 @@ export class ReduceProgram implements GPGPUProgram {
4750

4851
if (reduceType === 'sum') {
4952
returnValue = `sumValue`;
53+
} else if (reduceType === 'prod') {
54+
returnValue = `prodValue`;
5055
} else if (reduceType === 'all') {
5156
returnValue = `allValue`;
5257
} else if (reduceType === 'any') {
@@ -59,6 +64,9 @@ export class ReduceProgram implements GPGPUProgram {
5964
let updateSnippet = `
6065
if (${reduceType === 'sum'}) {
6166
sumValue += dot(values, ones);
67+
} else if (${reduceType === 'prod'}) {
68+
vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
69+
prodValue *= tmp[0] * tmp[1];
6270
} else {
6371
minMaxValue = ${compareOp}(values, minMaxValue);
6472
}
@@ -108,6 +116,7 @@ export class ReduceProgram implements GPGPUProgram {
108116
int inOffset = outIdx * ${windowSize};
109117
110118
vec4 minMaxValue = vec4(${initializationValue});
119+
float prodValue = 1.0;
111120
float sumValue = 0.0;
112121
float allValue = 1.0;
113122
float anyValue = 0.0;

src/ops/reduction_ops.ts

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,60 @@ function sum_<T extends Tensor>(
142142
return customOp($x) as T;
143143
}
144144

145+
/**
146+
* Computes the product of elements across dimensions of a `Tensor`.
147+
*
148+
* Reduces the input along the dimensions given in `axes`. Unless `keepDims`
149+
* is true, the rank of the `Tensor` is reduced by 1 for each entry in `axes`.
150+
* If `keepDims` is true, the reduced dimensions are retained with length 1.
151+
* If `axes` has no entries, all dimensions are reduced, and a `Tensor` with a
152+
* single element is returned.
153+
*
154+
* ```js
155+
* const x = tf.tensor1d([1, 2, 3]);
156+
*
157+
* x.prod().print(); // or tf.prod(x)
158+
* ```
159+
*
160+
* ```js
161+
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
162+
*
163+
* const axis = 1;
164+
* x.prod(axis).print(); // or tf.prod(x, axis)
165+
* ```
166+
*
167+
* @param x The input tensor to compute the product over. If the dtype is `bool`
168+
* it will be converted to `int32` and the output dtype will be `int32`.
169+
* @param axis The dimension(s) to reduce. By default it reduces
170+
* all dimensions.
171+
* @param keepDims If true, retains reduced dimensions with size 1.
172+
*/
173+
/** @doc {heading: 'Operations', subheading: 'Reduction'} */
174+
function prod_<T extends Tensor>(
175+
x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {
176+
let $x = convertToTensor(x, 'x', 'prod');
177+
178+
if ($x.dtype === 'bool') {
179+
$x = $x.toInt();
180+
}
181+
const axes = axis_util.parseAxisParam(axis, $x.shape);
182+
183+
const permutation = axis_util.getAxesPermutation(axes, $x.rank);
184+
let reductionAxes = axes;
185+
let permutedX = $x;
186+
if (permutation != null) {
187+
permutedX = $x.transpose(permutation);
188+
reductionAxes = axis_util.getInnerMostAxes(reductionAxes.length, $x.rank);
189+
}
190+
let value = ENV.engine.runKernel(
191+
backend => backend.prod(permutedX, reductionAxes), {permutedX});
192+
if (keepDims) {
193+
const newShape = axis_util.expandShapeToKeepDim(value.shape, axes);
194+
value = value.reshape(newShape);
195+
}
196+
197+
return value as T;
198+
}
145199
/**
146200
* Computes the mean of elements across dimensions of a `Tensor`.
147201
*
@@ -554,3 +608,4 @@ export const mean = op({mean_});
554608
export const min = op({min_});
555609
export const moments = op({moments_});
556610
export const sum = op({sum_});
611+
export const prod = op({prod_});

src/ops/reduction_ops_test.ts

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,105 @@ describeWithFlags('Reduction: sum', ALL_ENVS, () => {
852852
});
853853
});
854854

855+
// Unit tests for tf.prod: dtype handling (float32/int32/bool), NaN
// propagation, axis variants (default all-axes, single number, negative,
// arrays), keepDims, non-tensor rejection, and tensor-like inputs.
describeWithFlags('Reduction: prod', ALL_ENVS, () => {
  it('basic', () => {
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    const result = tf.prod(a);
    expectNumbersClose(result.get(), 0);
  });

  it('propagates NaNs', () => {
    const a = tf.tensor2d([1, 2, 3, NaN, 0, 1], [3, 2]);
    expect(tf.prod(a).get()).toEqual(NaN);
  });

  it('prod over dtype int32', () => {
    const a = tf.tensor1d([1, 5, 7, 3], 'int32');
    const prod = tf.prod(a);
    expect(prod.get()).toBe(105);
  });

  it('prod over dtype bool', () => {
    // bool is upcast to int32, so any false multiplies the result to 0.
    const a = tf.tensor1d([true, false, false, true, true], 'bool');
    const prod = tf.prod(a);
    expect(prod.get()).toBe(0);
  });

  it('prods all values in 2D array with keep dim', () => {
    const a = tf.tensor2d([1, 2, 3, 1, 0, 1], [3, 2]);
    const res = tf.prod(a, null, true /* keepDims */);

    expect(res.shape).toEqual([1, 1]);
    expectArraysClose(res, [0]);
  });

  it('prods across axis=0 in 2D array', () => {
    const a = tf.tensor2d([1, 2, 3, 1, 0, 1], [3, 2]);
    const res = tf.prod(a, [0]);

    expect(res.shape).toEqual([2]);
    expectArraysClose(res, [0, 2]);
  });

  it('prods across axis=0 in 2D array, keepDims', () => {
    const a = tf.tensor2d([1, 2, 3, 1, 0, 1], [3, 2]);
    const res = tf.prod(a, [0], true /* keepDims */);

    expect(res.shape).toEqual([1, 2]);
    expectArraysClose(res, [0, 2]);
  });

  it('prods across axis=1 in 2D array', () => {
    const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [3, 2]);
    const res = tf.prod(a, [1]);

    expect(res.shape).toEqual([3]);
    expectArraysClose(res, [2, 3, 1]);
  });

  it('2D, axis=1 provided as number', () => {
    const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [2, 3]);
    const res = tf.prod(a, 1);

    expect(res.shape).toEqual([2]);
    expectArraysClose(res, [6, 1]);
  });

  it('2D, axis = -1 provided as number', () => {
    // Negative axis counts from the last dimension.
    const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [2, 3]);
    const res = tf.prod(a, -1);

    expect(res.shape).toEqual([2]);
    expectArraysClose(res, [6, 1]);
  });

  it('prods across axis=0,1 in 2D array', () => {
    const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [3, 2]);
    const res = tf.prod(a, [0, 1]);

    expect(res.shape).toEqual([]);
    expectArraysClose(res, [6]);
  });

  it('2D, axis=[-1,-2] in 2D array', () => {
    const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [3, 2]);
    const res = tf.prod(a, [-1, -2]);

    expect(res.shape).toEqual([]);
    expectArraysClose(res, [6]);
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.prod({} as tf.Tensor))
        .toThrowError(/Argument 'x' passed to 'prod' must be a Tensor/);
  });

  it('accepts a tensor-like object', () => {
    const result = tf.prod([[1, 2], [3, 1], [1, 1]]);
    expectNumbersClose(result.get(), 6);
  });
});
953+
855954
describeWithFlags('Reduction: mean', ALL_ENVS, () => {
856955
it('basic', () => {
857956
const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);

src/tensor.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ export interface OpHandler {
190190
logSumExp<T extends Tensor>(
191191
x: Tensor, axis: number|number[], keepDims: boolean): T;
192192
sum<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
193+
prod<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean):
194+
T;
193195
mean<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean):
194196
T;
195197
min<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
@@ -769,6 +771,10 @@ export class Tensor<R extends Rank = Rank> {
769771
this.throwIfDisposed();
770772
return opHandler.sum(this, axis, keepDims);
771773
}
774+
// Chainable Tensor method for the product reduction; delegates to the
// registered op handler (tf.prod) after checking the tensor is live.
prod<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
  this.throwIfDisposed();
  return opHandler.prod(this, axis, keepDims);
}
772778
mean<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
773779
this.throwIfDisposed();
774780
return opHandler.mean(this, axis, keepDims);

0 commit comments

Comments
 (0)