@@ -852,6 +852,105 @@ describeWithFlags('Reduction: sum', ALL_ENVS, () => {
852852 } ) ;
853853} ) ;
854854
855+ describeWithFlags ( 'Reduction: prod' , ALL_ENVS , ( ) => {
856+ it ( 'basic' , ( ) => {
857+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 0 , 0 , 1 ] , [ 3 , 2 ] ) ;
858+ const result = tf . prod ( a ) ;
859+ expectNumbersClose ( result . get ( ) , 0 ) ;
860+ } ) ;
861+
862+ it ( 'propagates NaNs' , ( ) => {
863+ const a = tf . tensor2d ( [ 1 , 2 , 3 , NaN , 0 , 1 ] , [ 3 , 2 ] ) ;
864+ expect ( tf . prod ( a ) . get ( ) ) . toEqual ( NaN ) ;
865+ } ) ;
866+
867+ it ( 'prod over dtype int32' , ( ) => {
868+ const a = tf . tensor1d ( [ 1 , 5 , 7 , 3 ] , 'int32' ) ;
869+ const prod = tf . prod ( a ) ;
870+ expect ( prod . get ( ) ) . toBe ( 105 ) ;
871+ } ) ;
872+
873+ it ( 'prod over dtype bool' , ( ) => {
874+ const a = tf . tensor1d ( [ true , false , false , true , true ] , 'bool' ) ;
875+ const prod = tf . prod ( a ) ;
876+ expect ( prod . get ( ) ) . toBe ( 0 ) ;
877+ } ) ;
878+
879+ it ( 'prods all values in 2D array with keep dim' , ( ) => {
880+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 1 , 0 , 1 ] , [ 3 , 2 ] ) ;
881+ const res = tf . prod ( a , null , true /* keepDims */ ) ;
882+
883+ expect ( res . shape ) . toEqual ( [ 1 , 1 ] ) ;
884+ expectArraysClose ( res , [ 0 ] ) ;
885+ } ) ;
886+
887+ it ( 'prods across axis=0 in 2D array' , ( ) => {
888+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 1 , 0 , 1 ] , [ 3 , 2 ] ) ;
889+ const res = tf . prod ( a , [ 0 ] ) ;
890+
891+ expect ( res . shape ) . toEqual ( [ 2 ] ) ;
892+ expectArraysClose ( res , [ 0 , 2 ] ) ;
893+ } ) ;
894+
895+ it ( 'prods across axis=0 in 2D array, keepDims' , ( ) => {
896+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 1 , 0 , 1 ] , [ 3 , 2 ] ) ;
897+ const res = tf . prod ( a , [ 0 ] , true /* keepDims */ ) ;
898+
899+ expect ( res . shape ) . toEqual ( [ 1 , 2 ] ) ;
900+ expectArraysClose ( res , [ 0 , 2 ] ) ;
901+ } ) ;
902+
903+ it ( 'prods across axis=1 in 2D array' , ( ) => {
904+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 1 , 1 , 1 ] , [ 3 , 2 ] ) ;
905+ const res = tf . prod ( a , [ 1 ] ) ;
906+
907+ expect ( res . shape ) . toEqual ( [ 3 ] ) ;
908+ expectArraysClose ( res , [ 2 , 3 , 1 ] ) ;
909+ } ) ;
910+
911+ it ( '2D, axis=1 provided as number' , ( ) => {
912+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 1 , 1 , 1 ] , [ 2 , 3 ] ) ;
913+ const res = tf . prod ( a , 1 ) ;
914+
915+ expect ( res . shape ) . toEqual ( [ 2 ] ) ;
916+ expectArraysClose ( res , [ 6 , 1 ] ) ;
917+ } ) ;
918+
919+ it ( '2D, axis = -1 provided as number' , ( ) => {
920+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 1 , 1 , 1 ] , [ 2 , 3 ] ) ;
921+ const res = tf . prod ( a , - 1 ) ;
922+
923+ expect ( res . shape ) . toEqual ( [ 2 ] ) ;
924+ expectArraysClose ( res , [ 6 , 1 ] ) ;
925+ } ) ;
926+
927+ it ( 'prods across axis=0,1 in 2D array' , ( ) => {
928+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 1 , 1 , 1 ] , [ 3 , 2 ] ) ;
929+ const res = tf . prod ( a , [ 0 , 1 ] ) ;
930+
931+ expect ( res . shape ) . toEqual ( [ ] ) ;
932+ expectArraysClose ( res , [ 6 ] ) ;
933+ } ) ;
934+
935+ it ( '2D, axis=[-1,-2] in 2D array' , ( ) => {
936+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 1 , 1 , 1 ] , [ 3 , 2 ] ) ;
937+ const res = tf . prod ( a , [ - 1 , - 2 ] ) ;
938+
939+ expect ( res . shape ) . toEqual ( [ ] ) ;
940+ expectArraysClose ( res , [ 6 ] ) ;
941+ } ) ;
942+
943+ it ( 'throws when passed a non-tensor' , ( ) => {
944+ expect ( ( ) => tf . prod ( { } as tf . Tensor ) )
945+ . toThrowError ( / A r g u m e n t ' x ' p a s s e d t o ' p r o d ' m u s t b e a T e n s o r / ) ;
946+ } ) ;
947+
948+ it ( 'accepts a tensor-like object' , ( ) => {
949+ const result = tf . prod ( [ [ 1 , 2 ] , [ 3 , 1 ] , [ 1 , 1 ] ] ) ;
950+ expectNumbersClose ( result . get ( ) , 6 ) ;
951+ } ) ;
952+ } ) ;
953+
855954describeWithFlags ( 'Reduction: mean' , ALL_ENVS , ( ) => {
856955 it ( 'basic' , ( ) => {
857956 const a = tf . tensor2d ( [ 1 , 2 , 3 , 0 , 0 , 1 ] , [ 3 , 2 ] ) ;
0 commit comments