Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X)
hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y));
hi = simplifyTraceSum(hop, hi, i); //e.g. , trace(A+B)->trace(A)+trace(B);
hi = simplifyTraceTranspose(hop, hi, i); //e.g. , trace(t(A))->trace(A)
hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
hi = simplifyListIndexing(hi); //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
hi = simplifyScalarIndexing(hop, hi, i); //e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output
Expand All @@ -201,7 +203,6 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B)
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))


//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
if( !descendFirst )
rule_AlgebraicSimplification(hi, descendFirst);
Expand Down Expand Up @@ -1603,6 +1604,54 @@ private static Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos)
return hi;
}

private static Hop simplifyTraceSum(Hop parent, Hop hi, int pos) {
if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE) {
Hop hi2 = hi.getInput().get(0);
if (HopRewriteUtils.isBinary(hi2, OpOp2.PLUS) && hi2.getParent().size() == 1) {
Hop left = hi2.getInput().get(0);
Hop right = hi2.getInput().get(1);

// Create trace nodes
AggUnaryOp traceLeft = HopRewriteUtils.createAggUnaryOp(left, AggOp.TRACE, Direction.RowCol);
AggUnaryOp traceRight = HopRewriteUtils.createAggUnaryOp(right, AggOp.TRACE, Direction.RowCol);

// Add them
BinaryOp sum = HopRewriteUtils.createBinary(traceLeft, traceRight, OpOp2.PLUS);

// Replace in DAG
HopRewriteUtils.replaceChildReference(parent, hi, sum, pos);
HopRewriteUtils.cleanupUnreferenced(hi, hi2);

LOG.debug("Applied simplifyTraceSum rewrite");
return sum;
}
}
return hi;
}

private static Hop simplifyTraceTranspose(Hop parent, Hop hi, int pos) {
// Check if the current Hop is a trace operation
if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE) {
Hop input = hi.getInput().get(0);

// Check if input is a transpose and it is only consumer
if (input instanceof ReorgOp && ((ReorgOp) input).getOp() == ReOrgOp.TRANS && input.getParent().size() == 1) {
Hop A = input.getInput().get(0);

// Create a new trace operation directly on A
AggUnaryOp newTrace = HopRewriteUtils.createAggUnaryOp(A, AggOp.TRACE, Direction.RowCol);

// Replace in DAG
HopRewriteUtils.replaceChildReference(parent, hi, newTrace, pos);
HopRewriteUtils.cleanupUnreferenced(hi, input);

LOG.debug("Applied simplifyTraceTranspose rewrite");
return newTrace;
}
}
return hi;
}

private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos)
{
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package org.apache.sysds.test.functions.rewrite;

import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;

public class RewriteBooleanSimplificationTest extends AutomatedTestBase {

private static final String TEST_NAME_AND = "RewriteBooleanSimplificationTestAnd";
private static final String TEST_NAME_OR = "RewriteBooleanSimplificationTestOr";
private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteBooleanSimplificationTest.class.getSimpleName() + "/";

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME_AND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_AND));
addTestConfiguration(TEST_NAME_OR, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_OR));
}

@Test
public void testBooleanRewriteAnd() {
testRewriteBooleanSimplification(TEST_NAME_AND, ExecType.CP, 0.0);
}

@Test
public void testBooleanRewriteOr() {
testRewriteBooleanSimplification(TEST_NAME_OR, ExecType.CP, 1.0);
}

private void testRewriteBooleanSimplification(String testname, ExecType et, double expected) {
ExecMode platformOld = rtplatform;
rtplatform = (et == ExecType.SPARK) ? ExecMode.SPARK : ExecMode.HYBRID;

boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
if (rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID) {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}

try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{};

runTest(true, false, null, -1);

Assert.assertEquals("Expected boolean simplification result does not match", expected, getRewriteBooleanSimplificationResult(testname), 0.0001);
} finally {
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}

private double getRewriteBooleanSimplificationResult(String testname) {

if (testname.equals(TEST_NAME_AND)) {
// a & !a simplifies to false (0.0)
return 0.0;
} else if (testname.equals(TEST_NAME_OR)) {
// a | !a simplifies to true (1.0)
return 1.0;
} else {
// In case of an unknown operation, we return a default value (e.g., 0.0).
return 0.0;
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.test.functions.rewrite;

import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;
import java.util.HashMap;

public class RewriteSimplifyTraceSumTest extends AutomatedTestBase {
private static final String TEST_NAME = "RewriteSimplifyTraceSum";
private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTraceSumTest.class.getSimpleName() + "/";

private static final int rows = 500;
private static final int cols = 500;
private static final double eps = 1e-10;

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
}

@Test
public void testSimplifyTraceSumRewrite() {
runTraceRewriteTest(TEST_NAME, true);
}

@Test
public void testSimplifyTraceSumNoRewrite() {
runTraceRewriteTest(TEST_NAME, false);
}

private void runTraceRewriteTest(String testname, boolean rewrites) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
fullRScriptName = HOME + testname + ".R";

programArgs = new String[]{"-explain", "-stats", "-args", input("A"), input("B"), output("R")};
rCmd = getRCmd(inputDir(), expectedDir());
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7);
double[][] B = getRandomMatrix(cols, rows, -1, 1, 0.70d, 6);
writeInputMatrixWithMTD("A", A, true);
writeInputMatrixWithMTD("B", B, true);
// Run SystemDS and R scripts
runTest(true, false, null, -1);
runRScript(true);

// Compare DML and R outputs
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLScalarFromOutputDir("R");
HashMap<MatrixValue.CellIndex, Double> rfile = readRScalarFromExpectedDir("R");

// Ensure they're equal (within tolerance)
TestUtils.compareMatrices(dmlfile, rfile, eps, "DMLResult", "RResult");
} finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
}
}
}


Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.test.functions.rewrite;

import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;

public class RewriteSimplifyTraceTransposeTest extends AutomatedTestBase {
private static final String TEST_NAME = "RewriteSimplifyTraceTranspose";
private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTraceTransposeTest.class.getSimpleName() + "/";

private static final int rows = 100;
private static final int cols = 100;
private static final double eps = 1e-10;

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
}

@Test
public void testRewriteEnabled() {
runRewriteTest(true);
}

@Test
public void testRewriteDisabled() {
runRewriteTest(false);
}

private void runRewriteTest(boolean rewriteEnabled) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
try {
TestConfiguration config = getTestConfiguration(TEST_NAME);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
fullRScriptName = HOME + TEST_NAME + ".R";
programArgs = new String[]{"-stats", "-args", input("A"), output("R")};
rCmd = getRCmd(inputDir(), expectedDir());

OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled;
double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7);
writeInputMatrixWithMTD("A", A, true);
runTest(true, false, null, -1);
runRScript(true);

// Read DML scalar output
HashMap<MatrixValue.CellIndex, Double> dmlMap = readDMLScalarFromOutputDir("R");
double dmlTrace = dmlMap.get(new MatrixValue.CellIndex(1, 1));

// Read R scalar output
HashMap<MatrixValue.CellIndex, Double> rMap = readRScalarFromExpectedDir("R");
double rTrace = rMap.get(new MatrixValue.CellIndex(1, 1));

double tolerance = 1e-6;
// Compare the scalar values within the given tolerance
Assert.assertEquals("Trace result mismatch", rTrace, dmlTrace, tolerance);

} finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
}
}
}


39 changes: 39 additions & 0 deletions src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
args <- commandArgs(TRUE)

# Set options for numeric precision
options(digits=22)

library("Matrix")
library("matrixStats")

A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))

# Perform the matrix operation
R = sum(diag(A))+sum(diag(B))

# Write the result scalar R
write(R, paste(args[2], "R" ,sep=""))



Loading