diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 65c8805c7ce..f07c4c0a147 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -202,6 +202,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hi = simplifyNegatedSubtraction(hop, hi, i); //e.g., -(B-A)->A-B hi = simplifyTransposeAddition(hop, hi, i); //e.g., t(A+s1)+s2 -> t(A)+(s1+s2) + potential constant folding hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B) + hi = simplifyMatrixScalarPMOperation(hop, hi, i); //e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) @@ -212,6 +213,60 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hop.setVisited(); } + private Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) { + if (!(hi instanceof BinaryOp)) + return hi; + + BinaryOp outer = (BinaryOp) hi; + Hop left = outer.getInput().get(0); + Hop right = outer.getInput().get(1); + OpOp2 outerOp = outer.getOp(); + + if ((outerOp != OpOp2.PLUS && outerOp != OpOp2.MINUS) || !(left instanceof BinaryOp)) + return hi; + + BinaryOp inner = (BinaryOp) left; + Hop a = inner.getInput().get(0); + Hop A = inner.getInput().get(1); + Hop b = right; + OpOp2 innerOp = inner.getOp(); + + java.util.function.Predicate isScalar = h -> h.getDataType().isScalar(); + if (!isScalar.test(a) || !isScalar.test(b) || A.getDataType() != DataType.MATRIX) + return hi; + + // Determine the scalarOp (between a and b) and matrixOp (with A) + OpOp2 scalarOp = null; + OpOp2 matrixOp = null; + + if (innerOp == OpOp2.MINUS && outerOp == OpOp2.MINUS) { + scalarOp = OpOp2.MINUS; + matrixOp = OpOp2.MINUS; + } + else if (innerOp == OpOp2.PLUS && outerOp == OpOp2.MINUS) { + scalarOp = OpOp2.MINUS; + matrixOp = OpOp2.PLUS; + } + else if (innerOp == OpOp2.MINUS && outerOp == OpOp2.PLUS) { + scalarOp = OpOp2.PLUS; + matrixOp = OpOp2.MINUS; + } + else if (innerOp == OpOp2.PLUS && outerOp == OpOp2.PLUS) { + scalarOp = OpOp2.PLUS; + matrixOp = OpOp2.PLUS; + } + else { + return hi; + } + + Hop scalarCombined = HopRewriteUtils.createBinary(a, b, scalarOp); + Hop result = HopRewriteUtils.createBinary(scalarCombined, A, matrixOp); + + HopRewriteUtils.replaceChildReference(parent, hi, result, pos); + LOG.debug("Applied simplifyMatrixScalarPMOperation"); + return result; + } + private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) { //pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant folding if (HopRewriteUtils.isBinary(hi, OpOp2.PLUS) diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java new file mode 100644 index 00000000000..d4e8dbf7b2e --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.HashMap; + +public class RewriteSimplifyScalarMatrixPMOperationTest extends AutomatedTestBase { + private static final String TEST_NAME1 = "RewriteScalarMinusMatrixMinusScalar"; + private static final String TEST_NAME2 = "RewriteScalarPlusMatrixMinusScalar"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyScalarMatrixPMOperationTest.class.getSimpleName() + "/"; + private static final int rows = 100; + private static final int cols = 100; + private static final double eps = 1e-6; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"A", "a", "b", "R"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"A", "a", "b", "R"})); + } + + @Test + public void testScalarMinusMatrixMinusScalarRewriteEnabled() { + runRewriteTest(TEST_NAME1, true); + } + + @Test + public void testScalarMinusMatrixMinusScalarRewriteDisabled() { + runRewriteTest(TEST_NAME1, false); + } + + @Test + public void testScalarPlusMatrixMinusScalarRewriteEnabled() { + runRewriteTest(TEST_NAME2, true); + } + + @Test + public void testScalarPlusMatrixMinusScalarRewriteDisabled() { + runRewriteTest(TEST_NAME2, false); + } + + private void runRewriteTest(String testName, boolean rewriteEnabled) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(testName); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testName + ".dml"; + fullRScriptName = HOME + testName + ".R"; + programArgs = new String[]{"-stats", "-args", input("A"), input("a"), input("b"), output("R")}; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled; + + double[][] A = getRandomMatrix(rows, cols, -100, 100, 0.9, 3); + double[][] a = getRandomMatrix(1, 1, -10, 10, 1.0, 7); + double[][] b = getRandomMatrix(1, 1, -10, 10, 1.0, 5); + + writeInputMatrixWithMTD("A", A, true); + writeInputMatrixWithMTD("a", a, true); + writeInputMatrixWithMTD("b", b, true); + + runTest(true, false, null, -1); + runRScript(true); + + HashMap dml = readDMLMatrixFromOutputDir("R"); + HashMap r = readRMatrixFromExpectedDir("R"); + + Assert.assertEquals("DML and R outputs do not match", r, dml); + } finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} + diff --git a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R new file mode 100644 index 00000000000..bd9ab23ed2e --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) +library("Matrix") + +A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +a <- as.numeric(readMM(paste(args[1], "a.mtx", sep=""))) +b <- as.numeric(readMM(paste(args[1], "b.mtx", sep=""))) + +R <- (a-b)-A + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml new file mode 100644 index 00000000000..28cdb61dec0 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); +a = read($2); +b = read($3); + +R = a - A - b; + +write(R, $4); + diff --git a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R new file mode 100644 index 00000000000..ec2764bb282 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) +library("Matrix") + +A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +a <- as.numeric(readMM(paste(args[1], "a.mtx", sep=""))) +b <- as.numeric(readMM(paste(args[1], "b.mtx", sep=""))) + +R <- (a-b)+A + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml new file mode 100644 index 00000000000..5ba04566efc --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); +a = as.scalar(read($2)); +b = as.scalar(read($3)); + +# Original form: a + A - b +R = a + A - b; + +write(R, $4);