Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class CNodeUnary extends CNode
public enum UnaryType {
LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific
ROW_SUMS, ROW_SUMSQS, ROW_COUNTNNZS, //codegen specific
ROW_MEANS, ROW_MINS, ROW_MAXS,
ROW_MEANS, ROW_MINS, ROW_MAXS, ROW_VARS,
VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG,
VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN,
VECT_SIN, VECT_COS, VECT_TAN, VECT_ASIN, VECT_ACOS, VECT_ATAN,
Expand Down Expand Up @@ -139,6 +139,7 @@ public String toString() {
case ROW_MINS: return "u(Rmin)";
case ROW_MAXS: return "u(Rmax)";
case ROW_MEANS: return "u(Rmean)";
case ROW_VARS: return "u(Rvar)";
case ROW_COUNTNNZS: return "u(Rnnz)";
case VECT_EXP:
case VECT_POW2:
Expand Down Expand Up @@ -210,6 +211,7 @@ public void setOutputDims() {
case ROW_MINS:
case ROW_MAXS:
case ROW_MEANS:
case ROW_VARS:
case ROW_COUNTNNZS:
case EXP:
case LOOKUP_R:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ public String getTemplate(UnaryType type, boolean sparse) {
case ROW_MINS:
case ROW_MAXS:
case ROW_MEANS:
case ROW_VARS:
case ROW_COUNTNNZS: {
String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
return sparse ? " double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
" double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
}

case VECT_EXP:
case VECT_POW2:
case VECT_MULT2:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

public class TemplateRow extends TemplateBase
{
private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN};
private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR};
private static final OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN,
OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS, OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2151,7 +2151,19 @@ public static double[] vectConv2dmmWrite(double[] a, double[] b, int ai, int bi,
new DenseBlockFP64(new int[]{K, PQ}, c), PQ, CRS, 0, K, 0, PQ);
return c;
}


public static double vectVar(double[] a, int ai, int len) {
double meanVal = Math.pow(vectMean(a, ai, len), 2);
double[] aSqr = vectPow2Write(a, ai, len);
return (vectSum(aSqr, 0, len)-len*meanVal)/(len-1);
}

public static double vectVar(double[] avals, int[] aix, int ai, int alen, int len) {
double meanVal = Math.pow(vectMean(avals, aix, ai, alen, len), 2);
double[] avalsSqr = vectPow2Write(avals, aix, ai, alen, len);
return (vectSum(avalsSqr, 0, len)-len*meanVal)/(len-1);
}

//complex builtin functions that are not directly generated
//(included here in order to reduce the number of imports)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ public class RowAggTmplTest extends AutomatedTestBase
private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - mean(X)) + 7;
private static final String TEST_NAME45 = TEST_NAME+"45"; //vector allocation;
private static final String TEST_NAME46 = TEST_NAME+"46"; //conv2d(X - mean(X), F1) + conv2d(X - mean(X), F2);

private static final String TEST_NAME47 = TEST_NAME+"47"; //sum(X + rowVars(X))
private static final String TEST_NAME48 = TEST_NAME+"48"; //sum(rowVars(X))

private static final String TEST_DIR = "functions/codegen/";
private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/";
private final static String TEST_CONF = "SystemDS-config-codegen.xml";
Expand All @@ -98,7 +100,7 @@ public class RowAggTmplTest extends AutomatedTestBase
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
for(int i=1; i<=46; i++)
for(int i=1; i<=48; i++)
addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) );
}

Expand Down Expand Up @@ -795,6 +797,36 @@ public void testCodegenRowAgg46SP() {
testCodegenIntegration( TEST_NAME46, false, ExecType.SPARK );
}

@Test
public void testCodegenRowAggRewrite47CP() {
testCodegenIntegration( TEST_NAME47, true, ExecType.CP );
}

@Test
public void testCodegenRowAgg47CP() {
testCodegenIntegration( TEST_NAME47, false, ExecType.CP );
}

@Test
public void testCodegenRowAgg47SP() {
testCodegenIntegration( TEST_NAME47, false, ExecType.SPARK );
}

@Test
public void testCodegenRowAggRewrite48CP() {
testCodegenIntegration( TEST_NAME48, true, ExecType.CP );
}

@Test
public void testCodegenRowAgg48CP() {
testCodegenIntegration( TEST_NAME48, false, ExecType.CP );
}

@Test
public void testCodegenRowAgg48SP() {
testCodegenIntegration( TEST_NAME48, false, ExecType.SPARK );
}

private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
{
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
Expand All @@ -807,7 +839,7 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{"-stats", "-args", output("S") };
programArgs = new String[]{"-explain", "codegen", "-stats", "-args", output("S") };

fullRScriptName = HOME + testname + ".R";
rCmd = getRCmd(inputDir(), expectedDir());
Expand Down
36 changes: 36 additions & 0 deletions src/test/scripts/functions/codegen/rowAggPattern47.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

args<-commandArgs(TRUE)
options(digits=22)
library("Matrix")
library("matrixStats")

# rowVars <- function(X) {
# apply(X, 1, function(x) sum((x - mean(x))^2) / length(x))
# }

X = matrix(seq(7, 50*10+6), 50, 10, byrow=TRUE);
z = seq(1,50)

R = as.matrix(sum(X + rowVars(X)));

writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
29 changes: 29 additions & 0 deletions src/test/scripts/functions/codegen/rowAggPattern47.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

X = matrix(seq(7, 50*10+6), 50, 10);
z = seq(1,50)

while(FALSE){}

R = as.matrix(sum(X + rowVars(X)));

write(R, $1)
36 changes: 36 additions & 0 deletions src/test/scripts/functions/codegen/rowAggPattern48.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
args<-commandArgs(TRUE)
options(digits=22)
library("Matrix")
library("matrixStats")

# rowVars <- function(X) {
# apply(X, 1, function(x) sum((x - mean(x))^2) / length(x))
# }

Z = matrix(seq(1,10), 1, 10)
Y = matrix(0, 10, 10)
X = rbind(Y, Z, Y)

R = as.matrix(sum(rowVars(X)));

writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
30 changes: 30 additions & 0 deletions src/test/scripts/functions/codegen/rowAggPattern48.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

Z = matrix(seq(1,10), 1, 10)
Y = matrix(0, 10, 10)
X = rbind(Y, Z, Y)

while(FALSE){}

R = as.matrix(sum(rowVars(X)));

write(R, $1)