From 77c398ac2356e00ac70bf18b102e5f1913e15c80 Mon Sep 17 00:00:00 2001 From: aarna Date: Tue, 10 Jun 2025 15:49:00 +0530 Subject: [PATCH 1/5] implemented simplification of scalar matrix scalar addition and subtraction operations --- .../RewriteAlgebraicSimplificationStatic.java | 58 ++++++++++ ...teSimplifyScalarMatrixPMOperationTest.java | 101 ++++++++++++++++++ .../RewriteScalarMinusMatrixMinusScalar.R | 30 ++++++ .../RewriteScalarMinusMatrixMinusScalar.dml | 28 +++++ .../RewriteScalarPlusMatrixMinusScalar.R | 30 ++++++ .../RewriteScalarPlusMatrixMinusScalar.dml | 28 +++++ 6 files changed, 275 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R create mode 100644 src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml create mode 100644 src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R create mode 100644 src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 65c8805c7ce..c1461469251 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -202,6 +202,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hi = simplifyNegatedSubtraction(hop, hi, i); //e.g., -(B-A)->A-B hi = simplifyTransposeAddition(hop, hi, i); //e.g., t(A+s1)+s2 -> t(A)+(s1+s2) + potential constant folding hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B) + hi = simplifyMatrixScalarPMOperation(hop, hi, i); //e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) @@ -212,6 +213,63 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hop.setVisited(); } + private Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) { + if (!(hi instanceof BinaryOp)) + return hi; + + BinaryOp outer = (BinaryOp) hi; + Hop left = outer.getInput().get(0); + Hop right = outer.getInput().get(1); + OpOp2 outerOp = outer.getOp(); + + if ((outerOp != OpOp2.PLUS && outerOp != OpOp2.MINUS) || !(left instanceof BinaryOp)) + return hi; + + BinaryOp inner = (BinaryOp) left; + Hop a = inner.getInput().get(0); + Hop A = inner.getInput().get(1); + Hop b = right; + OpOp2 innerOp = inner.getOp(); + + // Only consider flat expressions: a op1 A op2 b where a, b are scalars, A is matrix + java.util.function.Predicate isScalar = h -> h.getDataType().isScalar(); + if (!isScalar.test(a) || !isScalar.test(b) || A.getDataType() != DataType.MATRIX) + return hi; + + BinaryOp scalarCombined = null; + BinaryOp result = null; + + // Rewrite cases + if (innerOp == OpOp2.MINUS && outerOp == OpOp2.MINUS) { + // a - A - b => (a - b) - A + scalarCombined = HopRewriteUtils.createBinary(a, b, OpOp2.MINUS); + result = HopRewriteUtils.createBinary(scalarCombined, A, OpOp2.MINUS); + } + else if (innerOp == OpOp2.PLUS && outerOp == OpOp2.MINUS) { + // a + A - b => (a - b) + A + scalarCombined = HopRewriteUtils.createBinary(a, b, OpOp2.MINUS); + result = HopRewriteUtils.createBinary(scalarCombined, A, OpOp2.PLUS); + } + else if (innerOp == OpOp2.MINUS && outerOp == OpOp2.PLUS) { + // a - A + b => (a + b) - A + scalarCombined = HopRewriteUtils.createBinary(a, b, OpOp2.PLUS); + result = HopRewriteUtils.createBinary(scalarCombined, A, OpOp2.MINUS); + } + else if (innerOp == OpOp2.PLUS && outerOp == OpOp2.PLUS) { + // a + A + b => (a + b) + A + scalarCombined = HopRewriteUtils.createBinary(a, b, OpOp2.PLUS); + result = HopRewriteUtils.createBinary(scalarCombined, A, OpOp2.PLUS); + } + + if (result != null) { + HopRewriteUtils.replaceChildReference(parent, hi, result, pos); + LOG.debug("Applied simplifyMatrixScalarPMOperation"); + return result; + } + + return hi; + } + private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) { //pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant folding if (HopRewriteUtils.isBinary(hi, OpOp2.PLUS) diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java new file mode 100644 index 00000000000..4474cc4a8d8 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyScalarMatrixPMOperationTest extends AutomatedTestBase { + private static final String TEST_NAME1 = "RewriteScalarMinusMatrixMinusScalar"; + private static final String TEST_NAME2 = "RewriteScalarPlusMatrixMinusScalar"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyScalarMatrixPMOperationTest.class.getSimpleName() + "/"; + private static final int rows = 10; + private static final int cols = 10; + private static final double eps = 1e-6; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"A", "a", "b", "R"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"A", "a", "b", "R"})); + } + + @Test + public void testScalarMinusMatrixMinusScalarRewriteEnabled() { + runRewriteTest(TEST_NAME1, true); + } + + @Test + public void testScalarMinusMatrixMinusScalarRewriteDisabled() { + runRewriteTest(TEST_NAME1, false); + } + + @Test + public void testScalarPlusMatrixMinusScalarRewriteEnabled() { + runRewriteTest(TEST_NAME2, true); + } + + @Test + public void testScalarPlusMatrixMinusScalarRewriteDisabled() { + runRewriteTest(TEST_NAME2, false); + } + + private void runRewriteTest(String testName, boolean rewriteEnabled) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(testName); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testName + ".dml"; + fullRScriptName = HOME + testName + ".R"; + programArgs = new String[]{"-stats", "-args", input("A"), input("a"), input("b"), output("R")}; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled; + + double[][] A = getRandomMatrix(rows, cols, -10, 10, 0.9, 3); + double[][] a = getRandomMatrix(1, 1, -10, 10, 1.0, 7); + double[][] b = getRandomMatrix(1, 1, -10, 10, 1.0, 5); + + writeInputMatrixWithMTD("A", A, true); + writeInputMatrixWithMTD("a", a, true); + writeInputMatrixWithMTD("b", b, true); + + runTest(true, false, null, -1); + runRScript(true); + + HashMap dml = readDMLMatrixFromOutputDir("R"); + HashMap r = readRMatrixFromExpectedDir("R"); + + Assert.assertEquals("DML and R outputs do not match", r, dml); + } finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} + diff --git a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R new file mode 100644 index 00000000000..1b56d8b9696 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) +library("Matrix") + +A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +a <- as.numeric(readMM(paste(args[1], "a.mtx", sep=""))) +b <- as.numeric(readMM(paste(args[1], "b.mtx", sep=""))) + +R <- a - A - b + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml new file mode 100644 index 00000000000..28cdb61dec0 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); +a = read($2); +b = read($3); + +R = a - A - b; + +write(R, $4); + diff --git a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R new file mode 100644 index 00000000000..18593c7321e --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) +library("Matrix") + +A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +a <- as.numeric(readMM(paste(args[1], "a.mtx", sep=""))) +b <- as.numeric(readMM(paste(args[1], "b.mtx", sep=""))) + +R <- a + A - b + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml new file mode 100644 index 00000000000..5ba04566efc --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); +a = as.scalar(read($2)); +b = as.scalar(read($3)); + +# Original form: a + A - b +R = a + A - b; + +write(R, $4); From c2d6540a12d14936683d048a94b650401f008160 Mon Sep 17 00:00:00 2001 From: aarna Date: Tue, 10 Jun 2025 15:58:53 +0530 Subject: [PATCH 2/5] implemented simplification of scalar matrix scalar addition and subtraction operations --- .../RewriteAlgebraicSimplificationStatic.java | 45 ++++++++++--------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index c1461469251..651a26c6053 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -231,45 +231,46 @@ private Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) { Hop b = right; OpOp2 innerOp = inner.getOp(); - // Only consider flat expressions: a op1 A op2 b where a, b are scalars, A is matrix + // Check for valid types: a and b must be scalar, A must be matrix java.util.function.Predicate isScalar = h -> h.getDataType().isScalar(); if (!isScalar.test(a) || !isScalar.test(b) || A.getDataType() != DataType.MATRIX) return hi; - BinaryOp scalarCombined = null; - BinaryOp result = null; + // Determine the scalarOp (between a and b) and matrixOp (with A) + OpOp2 scalarOp = null; + OpOp2 matrixOp = null; - // Rewrite cases if (innerOp == OpOp2.MINUS && outerOp == OpOp2.MINUS) { - // a - A - b => (a - b) - A - scalarCombined = HopRewriteUtils.createBinary(a, b, OpOp2.MINUS); - result = HopRewriteUtils.createBinary(scalarCombined, A, OpOp2.MINUS); + scalarOp = OpOp2.MINUS; + matrixOp = OpOp2.MINUS; } else if (innerOp == OpOp2.PLUS && outerOp == OpOp2.MINUS) { - // a + A - b => (a - b) + A - scalarCombined = HopRewriteUtils.createBinary(a, b, OpOp2.MINUS); - result = HopRewriteUtils.createBinary(scalarCombined, A, OpOp2.PLUS); + scalarOp = OpOp2.MINUS; + matrixOp = OpOp2.PLUS; } else if (innerOp == OpOp2.MINUS && outerOp == OpOp2.PLUS) { - // a - A + b => (a + b) - A - scalarCombined = HopRewriteUtils.createBinary(a, b, OpOp2.PLUS); - result = HopRewriteUtils.createBinary(scalarCombined, A, OpOp2.MINUS); + scalarOp = OpOp2.PLUS; + matrixOp = OpOp2.MINUS; } else if (innerOp == OpOp2.PLUS && outerOp == OpOp2.PLUS) { - // a + A + b => (a + b) + A - scalarCombined = HopRewriteUtils.createBinary(a, b, OpOp2.PLUS); - result = HopRewriteUtils.createBinary(scalarCombined, A, OpOp2.PLUS); + scalarOp = OpOp2.PLUS; + matrixOp = OpOp2.PLUS; } - - if (result != null) { - HopRewriteUtils.replaceChildReference(parent, hi, result, pos); - LOG.debug("Applied simplifyMatrixScalarPMOperation"); - return result; + else { + // No valid pattern + return hi; } - return hi; + // Create and replace the rewritten expression + Hop scalarCombined = HopRewriteUtils.createBinary(a, b, scalarOp); + Hop result = HopRewriteUtils.createBinary(scalarCombined, A, matrixOp); + + HopRewriteUtils.replaceChildReference(parent, hi, result, pos); + LOG.debug("Applied simplifyMatrixScalarPMOperation"); + return result; } + private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) { //pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant folding if (HopRewriteUtils.isBinary(hi, OpOp2.PLUS) From 2d6964e2d778248d88302b8456245acd261d639c Mon Sep 17 00:00:00 2001 From: aarna Date: Tue, 10 Jun 2025 15:59:59 +0530 Subject: [PATCH 3/5] implemented simplification of scalar matrix scalar addition and subtraction operations --- .../hops/rewrite/RewriteAlgebraicSimplificationStatic.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 651a26c6053..00e5ae2b748 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -231,7 +231,6 @@ private Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) { Hop b = right; OpOp2 innerOp = inner.getOp(); - // Check for valid types: a and b must be scalar, A must be matrix java.util.function.Predicate isScalar = h -> h.getDataType().isScalar(); if (!isScalar.test(a) || !isScalar.test(b) || A.getDataType() != DataType.MATRIX) return hi; @@ -257,11 +256,9 @@ else if (innerOp == OpOp2.PLUS && outerOp == OpOp2.PLUS) { matrixOp = OpOp2.PLUS; } else { - // No valid pattern return hi; } - // Create and replace the rewritten expression Hop scalarCombined = HopRewriteUtils.createBinary(a, b, scalarOp); Hop result = HopRewriteUtils.createBinary(scalarCombined, A, matrixOp); From 74df742a6ba58a29e2e6dd0035e8971fbe278ecb Mon Sep 17 00:00:00 2001 From: aarna Date: Wed, 11 Jun 2025 21:25:35 +0530 Subject: [PATCH 4/5] corrected lines that may have caused an error --- .../RewriteAlgebraicSimplificationStatic.java | 1 - ...teSimplifyScalarMatrixPMOperationTest.java | 22 +++++++++++++++---- .../RewriteScalarMinusMatrixMinusScalar.R | 2 +- .../RewriteScalarPlusMatrixMinusScalar.R | 2 +- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 00e5ae2b748..f07c4c0a147 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -267,7 +267,6 @@ else if (innerOp == OpOp2.PLUS && outerOp == OpOp2.PLUS) { return result; } - private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) { //pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant folding if (HopRewriteUtils.isBinary(hi, OpOp2.PLUS) diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java index 4474cc4a8d8..012b227b0eb 100644 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java @@ -18,6 +18,7 @@ */ package org.apache.sysds.test.functions.rewrite; +import org.apache.sysds.common.Types; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.matrix.data.MatrixValue; import org.apache.sysds.test.AutomatedTestBase; @@ -26,6 +27,9 @@ import org.junit.Assert; import org.junit.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.HashMap; public class RewriteSimplifyScalarMatrixPMOperationTest extends AutomatedTestBase { @@ -33,15 +37,15 @@ public class RewriteSimplifyScalarMatrixPMOperationTest extends AutomatedTestBas private static final String TEST_NAME2 = "RewriteScalarPlusMatrixMinusScalar"; private static final String TEST_DIR = "functions/rewrite/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyScalarMatrixPMOperationTest.class.getSimpleName() + "/"; - private static final int rows = 10; - private static final int cols = 10; + private static final int rows = 100; + private static final int cols = 100; private static final double eps = 1e-6; @Override public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"A", "a", "b", "R"})); - addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"A", "a", "b", "R"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"A", "a", "b", "R"})); } @Test @@ -78,13 +82,23 @@ private void runRewriteTest(String testName, boolean rewriteEnabled) { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled; - double[][] A = getRandomMatrix(rows, cols, -10, 10, 0.9, 3); + double[][] A = getRandomMatrix(rows, cols, -100, 100, 0.9, 3); double[][] a = getRandomMatrix(1, 1, -10, 10, 1.0, 7); double[][] b = getRandomMatrix(1, 1, -10, 10, 1.0, 5); writeInputMatrixWithMTD("A", A, true); writeInputMatrixWithMTD("a", a, true); writeInputMatrixWithMTD("b", b, true); + // Add this to your test to read and print the metadata content: + try { + String amtdContent = new String(Files.readAllBytes(Paths.get(inputDir() + "A.mtd"))); + System.out.println("A.mtd content: " + amtdContent); + + String bmtdContent = new String(Files.readAllBytes(Paths.get(inputDir() + "b.mtd"))); + System.out.println("b.mtd content: " + bmtdContent); + } catch (IOException e) { + e.printStackTrace(); + } runTest(true, false, null, -1); runRScript(true); diff --git a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R index 1b56d8b9696..bd9ab23ed2e 100644 --- a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R +++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R @@ -25,6 +25,6 @@ A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) a <- as.numeric(readMM(paste(args[1], "a.mtx", sep=""))) b <- as.numeric(readMM(paste(args[1], "b.mtx", sep=""))) -R <- a - A - b +R <- (a-b)-A writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R index 18593c7321e..ec2764bb282 100644 --- a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R +++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R @@ -25,6 +25,6 @@ A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) a <- as.numeric(readMM(paste(args[1], "a.mtx", sep=""))) b <- as.numeric(readMM(paste(args[1], "b.mtx", sep=""))) -R <- a + A - b +R <- (a-b)+A writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) From 30579f9db08625751e68f51099d3f7b958f16776 Mon Sep 17 00:00:00 2001 From: aarna Date: Wed, 11 Jun 2025 21:28:33 +0530 Subject: [PATCH 5/5] corrected lines that may have caused an error --- .../RewriteSimplifyScalarMatrixPMOperationTest.java | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java index 012b227b0eb..d4e8dbf7b2e 100644 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java @@ -89,16 +89,6 @@ private void runRewriteTest(String testName, boolean rewriteEnabled) { writeInputMatrixWithMTD("A", A, true); writeInputMatrixWithMTD("a", a, true); writeInputMatrixWithMTD("b", b, true); - // Add this to your test to read and print the metadata content: - try { - String amtdContent = new String(Files.readAllBytes(Paths.get(inputDir() + "A.mtd"))); - System.out.println("A.mtd content: " + amtdContent); - - String bmtdContent = new String(Files.readAllBytes(Paths.get(inputDir() + "b.mtd"))); - System.out.println("b.mtd content: " + bmtdContent); - } catch (IOException e) { - e.printStackTrace(); - } runTest(true, false, null, -1); runRScript(true);