From 2fb086fb21119fb24e6bd35a1d7b1ef054d3fe8c Mon Sep 17 00:00:00 2001 From: aarna Date: Thu, 7 Nov 2024 18:38:38 +0530 Subject: [PATCH 1/2] Boolean Rewrite Task # Conflicts: # src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java # src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml # src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml --- .../RewriteBooleanSimplificationTest.java | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java new file mode 100644 index 00000000000..afb70b8ff3f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java @@ -0,0 +1,76 @@ +package org.apache.sysds.test.functions.rewrite; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; + +public class RewriteBooleanSimplificationTest extends AutomatedTestBase { + + private static final String TEST_NAME_AND = "RewriteBooleanSimplificationTestAnd"; + private static final String TEST_NAME_OR = "RewriteBooleanSimplificationTestOr"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteBooleanSimplificationTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME_AND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_AND)); + addTestConfiguration(TEST_NAME_OR, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_OR)); + } + + @Test + public void testBooleanRewriteAnd() { + testRewriteBooleanSimplification(TEST_NAME_AND, ExecType.CP, 0.0); + } + + @Test + public void testBooleanRewriteOr() { + testRewriteBooleanSimplification(TEST_NAME_OR, ExecType.CP, 1.0); + } + + private void testRewriteBooleanSimplification(String testname, ExecType et, double expected) { + ExecMode platformOld = rtplatform; + rtplatform = (et == ExecType.SPARK) ? ExecMode.SPARK : ExecMode.HYBRID; + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if (rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{}; + + runTest(true, false, null, -1); + + Assert.assertEquals("Expected boolean simplification result does not match", expected, getRewriteBooleanSimplificationResult(testname), 0.0001); + } finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + private double getRewriteBooleanSimplificationResult(String testname) { + + if (testname.equals(TEST_NAME_AND)) { + // a & !a simplifies to false (0.0) + return 0.0; + } else if (testname.equals(TEST_NAME_OR)) { + // a | !a simplifies to true (1.0) + return 1.0; + } else { + // In case of an unknown operation, we return a default value (e.g., 0.0). + return 0.0; + } + } + +} From 6d26d57f9868b578c319f49877ca8b443ccc34d2 Mon Sep 17 00:00:00 2001 From: aarna Date: Mon, 21 Apr 2025 16:57:45 +0200 Subject: [PATCH 2/2] implemented rewrites for trace sum and trace of transpose. --- .../RewriteAlgebraicSimplificationStatic.java | 51 +++++++++- .../rewrite/RewriteSimplifyTraceSumTest.java | 87 ++++++++++++++++++ .../RewriteSimplifyTraceTransposeTest.java | 92 +++++++++++++++++++ .../rewrite/RewriteSimplifyTraceSum.R | 39 ++++++++ .../rewrite/RewriteSimplifyTraceSum.dml | 34 +++++++ .../rewrite/RewriteSimplifyTraceTranspose.R | 31 +++++++ .../rewrite/RewriteSimplifyTraceTranspose.dml | 31 +++++++ 7 files changed, 364 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index c46bc624007..231fff09c4c 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -176,6 +176,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) if(OptimizerUtils.ALLOW_OPERATOR_FUSION) hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X) hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y)); + hi = simplifyTraceSum(hop, hi, i); //e.g. , trace(A+B)->trace(A)+trace(B); + hi = simplifyTraceTranspose(hop, hi, i); //e.g. , trace(t(A))->trace(A) hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]; hi = simplifyListIndexing(hi); //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1] hi = simplifyScalarIndexing(hop, hi, i); //e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output @@ -201,7 +203,6 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B) //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) - //process childs recursively after rewrites (to investigate pattern newly created by rewrites) if( !descendFirst ) rule_AlgebraicSimplification(hi, descendFirst); @@ -1603,6 +1604,54 @@ private static Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos) return hi; } + private static Hop simplifyTraceSum(Hop parent, Hop hi, int pos) { + if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE) { + Hop hi2 = hi.getInput().get(0); + if (HopRewriteUtils.isBinary(hi2, OpOp2.PLUS) && hi2.getParent().size() == 1) { + Hop left = hi2.getInput().get(0); + Hop right = hi2.getInput().get(1); + + // Create trace nodes + AggUnaryOp traceLeft = HopRewriteUtils.createAggUnaryOp(left, AggOp.TRACE, Direction.RowCol); + AggUnaryOp traceRight = HopRewriteUtils.createAggUnaryOp(right, AggOp.TRACE, Direction.RowCol); + + // Add them + BinaryOp sum = HopRewriteUtils.createBinary(traceLeft, traceRight, OpOp2.PLUS); + + // Replace in DAG + HopRewriteUtils.replaceChildReference(parent, hi, sum, pos); + HopRewriteUtils.cleanupUnreferenced(hi, hi2); + + LOG.debug("Applied simplifyTraceSum rewrite"); + return sum; + } + } + return hi; + } + + private static Hop simplifyTraceTranspose(Hop parent, Hop hi, int pos) { + // Check if the current Hop is a trace operation + if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE) { + Hop input = hi.getInput().get(0); + + // Check if input is a transpose and it is only consumer + if (input instanceof ReorgOp && ((ReorgOp) input).getOp() == ReOrgOp.TRANS && input.getParent().size() == 1) { + Hop A = input.getInput().get(0); + + // Create a new trace operation directly on A + AggUnaryOp newTrace = HopRewriteUtils.createAggUnaryOp(A, AggOp.TRACE, Direction.RowCol); + + // Replace in DAG + HopRewriteUtils.replaceChildReference(parent, hi, newTrace, pos); + HopRewriteUtils.cleanupUnreferenced(hi, input); + + LOG.debug("Applied simplifyTraceTranspose rewrite"); + return newTrace; + } + } + return hi; + } + private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) { //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1] diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java new file mode 100644 index 00000000000..998f1d1cd4e --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import java.util.HashMap; + +public class RewriteSimplifyTraceSumTest extends AutomatedTestBase { + private static final String TEST_NAME = "RewriteSimplifyTraceSum"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTraceSumTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = 1e-10; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"})); + } + + @Test + public void testSimplifyTraceSumRewrite() { + runTraceRewriteTest(TEST_NAME, true); + } + + @Test + public void testSimplifyTraceSumNoRewrite() { + runTraceRewriteTest(TEST_NAME, false); + } + + private void runTraceRewriteTest(String testname, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + fullRScriptName = HOME + testname + ".R"; + + programArgs = new String[]{"-explain", "-stats", "-args", input("A"), input("B"), output("R")}; + rCmd = getRCmd(inputDir(), expectedDir()); + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7); + double[][] B = getRandomMatrix(cols, rows, -1, 1, 0.70d, 6); + writeInputMatrixWithMTD("A", A, true); + writeInputMatrixWithMTD("B", B, true); + // Run SystemDS and R scripts + runTest(true, false, null, -1); + runRScript(true); + + // Compare DML and R outputs + HashMap dmlfile = readDMLScalarFromOutputDir("R"); + HashMap rfile = readRScalarFromExpectedDir("R"); + + // Ensure they're equal (within tolerance) + TestUtils.compareMatrices(dmlfile, rfile, eps, "DMLResult", "RResult"); + } finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} + + diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java new file mode 100644 index 00000000000..2ce0d0f7070 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyTraceTransposeTest extends AutomatedTestBase { + private static final String TEST_NAME = "RewriteSimplifyTraceTranspose"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTraceTransposeTest.class.getSimpleName() + "/"; + + private static final int rows = 100; + private static final int cols = 100; + private static final double eps = 1e-10; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"})); + } + + @Test + public void testRewriteEnabled() { + runRewriteTest(true); + } + + @Test + public void testRewriteDisabled() { + runRewriteTest(false); + } + + private void runRewriteTest(boolean rewriteEnabled) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + fullRScriptName = HOME + TEST_NAME + ".R"; + programArgs = new String[]{"-stats", "-args", input("A"), output("R")}; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled; + double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7); + writeInputMatrixWithMTD("A", A, true); + runTest(true, false, null, -1); + runRScript(true); + + // Read DML scalar output + HashMap dmlMap = readDMLScalarFromOutputDir("R"); + double dmlTrace = dmlMap.get(new MatrixValue.CellIndex(1, 1)); + + // Read R scalar output + HashMap rMap = readRScalarFromExpectedDir("R"); + double rTrace = rMap.get(new MatrixValue.CellIndex(1, 1)); + + double tolerance = 1e-6; + // Compare the scalar values within the given tolerance + Assert.assertEquals("Trace result mismatch", rTrace, dmlTrace, tolerance); + + } finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} + + diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R new file mode 100644 index 00000000000..82abad71be7 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R @@ -0,0 +1,39 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +library("Matrix") +library("matrixStats") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B = as.matrix(readMM(paste(args[1], "B.mtx", sep=""))) + +# Perform the matrix operation +R = sum(diag(A))+sum(diag(B)) + +# Write the result scalar R +write(R, paste(args[2], "R" ,sep="")) + + + diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml new file mode 100644 index 00000000000..9eaf4fcb842 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +# Load matrices A, B +A = read($1) +B = read($2) + +# Perform the operation +R = trace(A+B) + +# Write the result R +write(R, $3) + + + + + diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R new file mode 100644 index 00000000000..3bbb28f6498 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) + +library("Matrix") +library("matrixStats") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) + +R <- sum(diag(t(A))) + +# Write the result scalar R +write(R, paste(args[2], "R" ,sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml new file mode 100644 index 00000000000..2b2b3e6dd01 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +# Read input matrix A +A = read($1); + +# Compute trace of transpose +result = trace(t(A)); + +# Write scalar result to output +write(result, $2); + + +