From 90727eb147a2fc5388eb68f79045af6efc893ade Mon Sep 17 00:00:00 2001 From: aarna Date: Thu, 7 Nov 2024 18:38:38 +0530 Subject: [PATCH 1/3] Boolean Rewrite Task # Conflicts: # src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java # src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml # src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml --- .../RewriteBooleanSimplificationTest.java | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java new file mode 100644 index 00000000000..afb70b8ff3f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java @@ -0,0 +1,76 @@ +package org.apache.sysds.test.functions.rewrite; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; + +public class RewriteBooleanSimplificationTest extends AutomatedTestBase { + + private static final String TEST_NAME_AND = "RewriteBooleanSimplificationTestAnd"; + private static final String TEST_NAME_OR = "RewriteBooleanSimplificationTestOr"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteBooleanSimplificationTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME_AND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_AND)); + addTestConfiguration(TEST_NAME_OR, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_OR)); + } + + @Test + public void testBooleanRewriteAnd() { + testRewriteBooleanSimplification(TEST_NAME_AND, ExecType.CP, 0.0); + } + + @Test + public void testBooleanRewriteOr() { + testRewriteBooleanSimplification(TEST_NAME_OR, ExecType.CP, 1.0); + } + + private void testRewriteBooleanSimplification(String testname, ExecType et, double expected) { + ExecMode platformOld = rtplatform; + rtplatform = (et == ExecType.SPARK) ? ExecMode.SPARK : ExecMode.HYBRID; + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if (rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{}; + + runTest(true, false, null, -1); + + Assert.assertEquals("Expected boolean simplification result does not match", expected, getRewriteBooleanSimplificationResult(testname), 0.0001); + } finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + private double getRewriteBooleanSimplificationResult(String testname) { + + if (testname.equals(TEST_NAME_AND)) { + // a & !a simplifies to false (0.0) + return 0.0; + } else if (testname.equals(TEST_NAME_OR)) { + // a | !a simplifies to true (1.0) + return 1.0; + } else { + // In case of an unknown operation, we return a default value (e.g., 0.0). + return 0.0; + } + } + +} From df0ceb297946972d45388641e81437786c893355 Mon Sep 17 00:00:00 2001 From: aarna Date: Tue, 1 Apr 2025 12:40:01 +0200 Subject: [PATCH 2/3] included dimension handling and logic for redundant transposes. --- .../org/apache/sysds/hops/AggBinaryOp.java | 487 +++++++++--------- .../rewrite/RewriteTransposeTest.java | 82 +++ .../functions/rewrite/RewriteTransposeCase1.R | 32 ++ .../rewrite/RewriteTransposeCase1.dml | 27 + .../functions/rewrite/RewriteTransposeCase2.R | 32 ++ .../rewrite/RewriteTransposeCase2.dml | 28 + .../functions/rewrite/RewriteTransposeCase3.R | 33 ++ .../rewrite/RewriteTransposeCase3.dml | 28 + 8 files changed, 496 insertions(+), 253 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteTransposeCase1.R create mode 100644 src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml create mode 100644 src/test/scripts/functions/rewrite/RewriteTransposeCase2.R create mode 100644 src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml create mode 100644 src/test/scripts/functions/rewrite/RewriteTransposeCase3.R create mode 100644 src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java index 2cf651f1894..5f9c6b41b3a 100644 --- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java @@ -43,6 +43,7 @@ import org.apache.sysds.lops.PMMJ; import org.apache.sysds.lops.PMapMult; import org.apache.sysds.lops.Transform; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -65,7 +66,7 @@ public class AggBinaryOp extends MultiThreadedHop { public static final double MAPMULT_MEM_MULTIPLIER = 1.0; public static MMultMethod FORCED_MMULT_METHOD = null; - public enum MMultMethod { + public enum MMultMethod { CPMM, //cross-product matrix multiplication (mr) RMM, //replication matrix multiplication (mr) MAPMM_L, //map-side matrix-matrix multiplication using distributed cache (mr/sp) @@ -78,27 +79,27 @@ public enum MMultMethod { ZIPMM, //zip matrix multiplication (sp) MM //in-memory matrix multiplication (cp) } - - public enum SparkAggType{ + + public enum SparkAggType { NONE, SINGLE_BLOCK, MULTI_BLOCK, } - + private OpOp2 innerOp; private AggOp outerOp; private MMultMethod _method = null; - + //hints set by previous to operator selection private boolean _hasLeftPMInput = false; //left input is permutation matrix - + private AggBinaryOp() { //default constructor for clone } - + public AggBinaryOp(String l, DataType dt, ValueType vt, OpOp2 innOp, - AggOp outOp, Hop in1, Hop in2) { + AggOp outOp, Hop in1, Hop in2) { super(l, dt, vt); innerOp = innOp; outerOp = outOp; @@ -106,7 +107,7 @@ public AggBinaryOp(String l, DataType dt, ValueType vt, OpOp2 innOp, getInput().add(1, in2); in1.getParent().add(this); in2.getParent().add(this); - + //compute unknown dims and nnz refreshSizeInformation(); } @@ -114,30 +115,30 @@ public AggBinaryOp(String l, DataType dt, ValueType vt, OpOp2 innOp, public void setHasLeftPMInput(boolean flag) { _hasLeftPMInput = flag; } - - public boolean hasLeftPMInput(){ + + public boolean hasLeftPMInput() { return _hasLeftPMInput; } - public MMultMethod getMMultMethod(){ + public MMultMethod getMMultMethod() { return _method; } - + @Override public boolean isGPUEnabled() { - if(!DMLScript.USE_ACCELERATOR) + if (!DMLScript.USE_ACCELERATOR) return false; - + Hop input1 = getInput().get(0); Hop input2 = getInput().get(1); //matrix mult operation selection part 2 (specific pattern) MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern ChainType chain = checkMapMultChain(); //determine mmchain pattern - - _method = optFindMMultMethodCP ( input1.getDim1(), input1.getDim2(), - input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput ); - switch( _method ){ - case TSMM: + + _method = optFindMMultMethodCP(input1.getDim1(), input1.getDim2(), + input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput); + switch (_method) { + case TSMM: //return false; // TODO: Disabling any fused transa optimization in 1.0 release. return true; case MAPMM_CHAIN: @@ -150,50 +151,47 @@ public boolean isGPUEnabled() { throw new RuntimeException("Unsupported method:" + _method); } } - + /** * NOTE: overestimated mem in case of transpose-identity matmult, but 3/2 at worst - * and existing mem estimate advantageous in terms of consistency hops/lops, - * and some special cases internally materialize the transpose for better cache locality + * and existing mem estimate advantageous in terms of consistency hops/lops, + * and some special cases internally materialize the transpose for better cache locality */ @Override - public Lop constructLops() - { + public Lop constructLops() { //return already created lops - if( getLops() != null ) + if (getLops() != null) return getLops(); - + //construct matrix mult lops (currently only supported aggbinary) - if ( isMatrixMultiply() ) - { + if (isMatrixMultiply()) { Hop input1 = getInput().get(0); Hop input2 = getInput().get(1); - + //matrix mult operation selection part 1 (CP vs MR vs Spark) ExecType et = optFindExecType(); - + //matrix mult operation selection part 2 (specific pattern) MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern ChainType chain = checkMapMultChain(); //determine mmchain pattern - if(mmtsj == MMTSJType.LEFT && input2.isCompressedOutput()){ + if (mmtsj == MMTSJType.LEFT && input2.isCompressedOutput()) { // if tsmm and input is compressed. (using input2, since input1 is transposed and therefore not compressed.) et = ExecType.CP; } - if( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED ) - { + if (et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED) { //matrix mult operation selection part 3 (CP type) - _method = optFindMMultMethodCP ( input1.getDim1(), input1.getDim2(), - input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput ); - + _method = optFindMMultMethodCP(input1.getDim1(), input1.getDim2(), + input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput); + //dispatch CP lops construction - switch( _method ){ - case TSMM: - constructCPLopsTSMM( mmtsj, et ); + switch (_method) { + case TSMM: + constructCPLopsTSMM(mmtsj, et); break; case MAPMM_CHAIN: - constructCPLopsMMChain( chain ); + constructCPLopsMMChain(chain); break; case PMM: constructCPLopsPMM(); @@ -204,53 +202,49 @@ public Lop constructLops() default: throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing CP lops."); } - } - else if( et == ExecType.SPARK ) - { + } else if (et == ExecType.SPARK) { //matrix mult operation selection part 3 (SPARK type) boolean tmmRewrite = HopRewriteUtils.isTransposeOperation(input1); - _method = optFindMMultMethodSpark ( + _method = optFindMMultMethodSpark( input1.getDim1(), input1.getDim2(), input1.getBlocksize(), input1.getNnz(), input2.getDim1(), input2.getDim2(), input2.getBlocksize(), input2.getNnz(), - mmtsj, chain, _hasLeftPMInput, tmmRewrite ); + mmtsj, chain, _hasLeftPMInput, tmmRewrite); //dispatch SPARK lops construction - switch( _method ) - { + switch (_method) { case TSMM: - case TSMM2: - constructSparkLopsTSMM( mmtsj, _method==MMultMethod.TSMM2 ); + case TSMM2: + constructSparkLopsTSMM(mmtsj, _method == MMultMethod.TSMM2); break; case MAPMM_L: case MAPMM_R: - constructSparkLopsMapMM( _method ); + constructSparkLopsMapMM(_method); break; case MAPMM_CHAIN: - constructSparkLopsMapMMChain( chain ); + constructSparkLopsMapMMChain(chain); break; case PMAPMM: constructSparkLopsPMapMM(); break; - case CPMM: + case CPMM: constructSparkLopsCPMM(); break; - case RMM: + case RMM: constructSparkLopsRMM(); break; case PMM: - constructSparkLopsPMM(); + constructSparkLopsPMM(); break; case ZIPMM: constructSparkLopsZIPMM(); break; - + default: - throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops."); + throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops."); } } - } - else + } else throw new HopsException(this.printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops."); - + //add reblock/checkpoint lops if necessary constructAndSetLopsDataFlowProperties(); @@ -260,30 +254,28 @@ else if( et == ExecType.SPARK ) @Override public String getOpString() { //ba - binary aggregate, for consistency with runtime - return "ba(" + outerOp.toString() + innerOp.toString()+")"; + return "ba(" + outerOp.toString() + innerOp.toString() + ")"; } - + @Override - public void computeMemEstimate(MemoTable memo) - { + public void computeMemEstimate(MemoTable memo) { //extension of default compute memory estimate in order to //account for smaller tsmm memory requirements. super.computeMemEstimate(memo); - + //tsmm left is guaranteed to require only X but not t(X), while //tsmm right might have additional requirements to transpose X if sparse //NOTE: as a heuristic this correction is only applied if not a column vector because //most other vector operations require memory for at least two vectors (we aim for //consistency in order to prevent anomalies in parfor opt leading to small degree of par) MMTSJType mmtsj = checkTransposeSelf(); - if( mmtsj.isLeft() && getInput().get(1).dimsKnown() && getInput().get(1).getDim2()>1 ) { + if (mmtsj.isLeft() && getInput().get(1).dimsKnown() && getInput().get(1).getDim2() > 1) { _memEstimate = _memEstimate - getInput().get(0)._outputMemEstimate; } } @Override - protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) - { + protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { //NOTES: // * The estimate for transpose-self is the same as for normal matrix multiplications // because (1) this decouples the decision of TSMM over default MM and (2) some cases @@ -314,10 +306,9 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) return ret; } - + @Override - protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz ) - { + protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { double ret = 0; if (isGPUEnabled()) { @@ -327,277 +318,254 @@ protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz double in2Sparsity = OptimizerUtils.getSparsity(in2.getDim1(), in2.getDim2(), in2.getNnz()); boolean in1Sparse = in1Sparsity < MatrixBlock.SPARSITY_TURN_POINT; boolean in2Sparse = in2Sparsity < MatrixBlock.SPARSITY_TURN_POINT; - if(in1Sparse && !in2Sparse) { + if (in1Sparse && !in2Sparse) { // Only in sparse-dense cases, we need additional memory budget for GPU ret += OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0); } } //account for potential final dense-sparse transformation (worst-case sparse representation) - if( dim2 >= 2 && nnz != 0 ) //vectors always dense + if (dim2 >= 2 && nnz != 0) //vectors always dense ret += MatrixBlock.estimateSizeSparseInMemory(dim1, dim2, - MatrixBlock.SPARSITY_TURN_POINT - UtilFunctions.DOUBLE_EPS); - + MatrixBlock.SPARSITY_TURN_POINT - UtilFunctions.DOUBLE_EPS); + return ret; } - + @Override - protected DataCharacteristics inferOutputCharacteristics( MemoTable memo ) - { + protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) { DataCharacteristics[] dc = memo.getAllInputStats(getInput()); DataCharacteristics ret = null; - if( dc[0].rowsKnown() && dc[1].colsKnown() ) { + if (dc[0].rowsKnown() && dc[1].colsKnown()) { ret = new MatrixCharacteristics(dc[0].getRows(), dc[1].getCols()); - double sp1 = (dc[0].getNonZeros()>0) ? OptimizerUtils.getSparsity(dc[0].getRows(), dc[0].getCols(), dc[0].getNonZeros()) : 1.0; - double sp2 = (dc[1].getNonZeros()>0) ? OptimizerUtils.getSparsity(dc[1].getRows(), dc[1].getCols(), dc[1].getNonZeros()) : 1.0; - ret.setNonZeros((long)(ret.getLength() * OptimizerUtils.getMatMultSparsity(sp1, sp2, ret.getRows(), dc[0].getCols(), ret.getCols(), true))); + double sp1 = (dc[0].getNonZeros() > 0) ? OptimizerUtils.getSparsity(dc[0].getRows(), dc[0].getCols(), dc[0].getNonZeros()) : 1.0; + double sp2 = (dc[1].getNonZeros() > 0) ? OptimizerUtils.getSparsity(dc[1].getRows(), dc[1].getCols(), dc[1].getNonZeros()) : 1.0; + ret.setNonZeros((long) (ret.getLength() * OptimizerUtils.getMatMultSparsity(sp1, sp2, ret.getRows(), dc[0].getCols(), ret.getCols(), true))); } return ret; } - + public boolean isMatrixMultiply() { - return ( this.innerOp == OpOp2.MULT && this.outerOp == AggOp.SUM ); + return (this.innerOp == OpOp2.MULT && this.outerOp == AggOp.SUM); } - + private boolean isOuterProduct() { - return ( getInput().get(0).isVector() && getInput().get(1).isVector() ) - && ( getInput().get(0).getDim1() == 1 && getInput().get(0).getDim1() > 1 - && getInput().get(1).getDim1() > 1 && getInput().get(1).getDim2() == 1 ); + return (getInput().get(0).isVector() && getInput().get(1).isVector()) + && (getInput().get(0).getDim1() == 1 && getInput().get(0).getDim1() > 1 + && getInput().get(1).getDim1() > 1 && getInput().get(1).getDim2() == 1); } - + @Override public boolean isMultiThreadedOpType() { return isMatrixMultiply(); } - + @Override - public boolean allowsAllExecTypes() - { + public boolean allowsAllExecTypes() { return true; } - + @Override - protected ExecType optFindExecType(boolean transitive) - { + protected ExecType optFindExecType(boolean transitive) { checkAndSetForcedPlatform(); - - if( _etypeForced != null ) { + + if (_etypeForced != null) { setExecType(_etypeForced); - } - else - { - if ( OptimizerUtils.isMemoryBasedOptLevel() ) { + } else { + if (OptimizerUtils.isMemoryBasedOptLevel()) { setExecType(findExecTypeByMemEstimate()); } // choose CP if the dimensions of both inputs are below Hops.CPThreshold // OR if it is vector-vector inner product - else if ( (getInput().get(0).areDimsBelowThreshold() && getInput().get(1).areDimsBelowThreshold()) - || (getInput().get(0).isVector() && getInput().get(1).isVector() && !isOuterProduct()) ) - { + else if ((getInput().get(0).areDimsBelowThreshold() && getInput().get(1).areDimsBelowThreshold()) + || (getInput().get(0).isVector() && getInput().get(1).isVector() && !isOuterProduct())) { setExecType(ExecType.CP); - } - else - { + } else { setExecType(ExecType.SPARK); } - + //check for valid CP mmchain, send invalid memory requirements to remote - if( _etype == ExecType.CP - && checkMapMultChain() != ChainType.NONE - && OptimizerUtils.getLocalMemBudget() < - getInput().get(0).getInput().get(0).getOutputMemEstimate() ) { + if (_etype == ExecType.CP + && checkMapMultChain() != ChainType.NONE + && OptimizerUtils.getLocalMemBudget() < + getInput().get(0).getInput().get(0).getOutputMemEstimate()) { setExecType(ExecType.SPARK); } - + //check for valid CP dimensions and matrix size checkAndSetInvalidCPDimsAndSize(); } - + //spark-specific decision refinement (execute binary aggregate w/ left or right spark input and //single parent also in spark because it's likely cheap and reduces data transfer) MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern - if( transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP - && ((!mmtsj.isLeft() && isApplicableForTransitiveSparkExecType(true)) - || ( !mmtsj.isRight() && isApplicableForTransitiveSparkExecType(false))) ) - { + if (transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP + && ((!mmtsj.isLeft() && isApplicableForTransitiveSparkExecType(true)) + || (!mmtsj.isRight() && isApplicableForTransitiveSparkExecType(false)))) { //pull binary aggregate into spark setExecType(ExecType.SPARK); } //mark for recompile (forever) setRequiresRecompileIfNecessary(); - + return _etype; } - - private boolean isApplicableForTransitiveSparkExecType(boolean left) - { + + private boolean isApplicableForTransitiveSparkExecType(boolean left) { int index = left ? 0 : 1; - return !(getInput(index) instanceof DataOp && ((DataOp)getInput(index)).requiresCheckpoint()) - && (!HopRewriteUtils.isTransposeOperation(getInput(index)) + return !(getInput(index) instanceof DataOp && ((DataOp) getInput(index)).requiresCheckpoint()) + && (!HopRewriteUtils.isTransposeOperation(getInput(index)) || (left && !isLeftTransposeRewriteApplicable(true))) - && getInput(index).getParent().size()==1 //bagg is only parent - && !getInput(index).areDimsBelowThreshold() - && (getInput(index).optFindExecType() == ExecType.SPARK - || (getInput(index) instanceof DataOp && ((DataOp)getInput(index)).hasOnlyRDD())) - && getInput(index).getOutputMemEstimate()>getOutputMemEstimate(); + && getInput(index).getParent().size() == 1 //bagg is only parent + && !getInput(index).areDimsBelowThreshold() + && (getInput(index).optFindExecType() == ExecType.SPARK + || (getInput(index) instanceof DataOp && ((DataOp) getInput(index)).hasOnlyRDD())) + && getInput(index).getOutputMemEstimate() > getOutputMemEstimate(); } - + /** * TSMM: Determine if XtX pattern applies for this aggbinary and if yes - * which type. - * + * which type. + * * @return MMTSJType */ - public MMTSJType checkTransposeSelf() - { + public MMTSJType checkTransposeSelf() { MMTSJType ret = MMTSJType.NONE; - + Hop in1 = getInput().get(0); Hop in2 = getInput().get(1); - - if( HopRewriteUtils.isTransposeOperation(in1) - && in1.getInput().get(0) == in2 ) - { + + if (HopRewriteUtils.isTransposeOperation(in1) + && in1.getInput().get(0) == in2) { ret = MMTSJType.LEFT; } - - if( HopRewriteUtils.isTransposeOperation(in2) - && in2.getInput().get(0) == in1 ) - { + + if (HopRewriteUtils.isTransposeOperation(in2) + && in2.getInput().get(0) == in1) { ret = MMTSJType.RIGHT; } - + return ret; } /** - * MapMultChain: Determine if XtwXv/XtXv pattern applies for this aggbinary - * and if yes which type. - * + * MapMultChain: Determine if XtwXv/XtXv pattern applies for this aggbinary + * and if yes which type. + * * @return ChainType */ - public ChainType checkMapMultChain() - { + public ChainType checkMapMultChain() { ChainType chainType = ChainType.NONE; - + Hop in1 = getInput().get(0); Hop in2 = getInput().get(1); - + //check for transpose left input (both chain types) - if( HopRewriteUtils.isTransposeOperation(in1) ) - { + if (HopRewriteUtils.isTransposeOperation(in1)) { Hop X = in1.getInput().get(0); - + //check mapmultchain patterns //t(X)%*%(w*(X%*%v)) - if( in2 instanceof BinaryOp && ((BinaryOp)in2).getOp()==OpOp2.MULT ) - { + if (in2 instanceof BinaryOp && ((BinaryOp) in2).getOp() == OpOp2.MULT) { Hop in3b = in2.getInput().get(1); - if( in3b instanceof AggBinaryOp ) - { + if (in3b instanceof AggBinaryOp) { Hop in4 = in3b.getInput().get(0); - if( X == in4 ) //common input + if (X == in4) //common input chainType = ChainType.XtwXv; } } //t(X)%*%((X%*%v)-y) - else if( in2 instanceof BinaryOp && ((BinaryOp)in2).getOp()==OpOp2.MINUS ) - { + else if (in2 instanceof BinaryOp && ((BinaryOp) in2).getOp() == OpOp2.MINUS) { Hop in3a = in2.getInput().get(0); - Hop in3b = in2.getInput().get(1); - if( in3a instanceof AggBinaryOp && in3b.getDataType()==DataType.MATRIX ) - { + Hop in3b = in2.getInput().get(1); + if (in3a instanceof AggBinaryOp && in3b.getDataType() == DataType.MATRIX) { Hop in4 = in3a.getInput().get(0); - if( X == in4 ) //common input + if (X == in4) //common input chainType = ChainType.XtXvy; } } //t(X)%*%(X%*%v) - else if( in2 instanceof AggBinaryOp ) - { + else if (in2 instanceof AggBinaryOp) { Hop in3 = in2.getInput().get(0); - if( X == in3 ) //common input + if (X == in3) //common input chainType = ChainType.XtXv; } } - + return chainType; } - + ////////////////////////// // CP Lops generation - ///////////////////////// - - private void constructCPLopsTSMM( MMTSJType mmtsj, ExecType et ) { + + /// ////////////////////// + + private void constructCPLopsTSMM(MMTSJType mmtsj, ExecType et) { int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); - Lop matmultCP = new MMTSJ(getInput().get(mmtsj.isLeft()?1:0).constructLops(), - getDataType(), getValueType(), et, mmtsj, false, k); + Lop matmultCP = new MMTSJ(getInput().get(mmtsj.isLeft() ? 1 : 0).constructLops(), + getDataType(), getValueType(), et, mmtsj, false, k); matmultCP.getOutputParameters().setDimensions(getDim1(), getDim2(), getBlocksize(), getNnz()); - setLineNumbers( matmultCP ); + setLineNumbers(matmultCP); setLops(matmultCP); } - private void constructCPLopsMMChain( ChainType chain ) - { + private void constructCPLopsMMChain(ChainType chain) { MapMultChain mapmmchain = null; - if( chain == ChainType.XtXv ) { + if (chain == ChainType.XtXv) { Hop hX = getInput().get(0).getInput().get(0); Hop hv = getInput().get(1).getInput().get(1); - mapmmchain = new MapMultChain( hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), ExecType.CP); - } - else { //ChainType.XtwXv / ChainType.XtwXvy + mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), ExecType.CP); + } else { //ChainType.XtwXv / ChainType.XtwXvy int wix = (chain == ChainType.XtwXv) ? 0 : 1; int vix = (chain == ChainType.XtwXv) ? 1 : 0; Hop hX = getInput().get(0).getInput().get(0); Hop hw = getInput().get(1).getInput().get(wix); Hop hv = getInput().get(1).getInput().get(vix).getInput().get(1); - mapmmchain = new MapMultChain( hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), ExecType.CP); + mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), ExecType.CP); } - + //set degree of parallelism int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); - mapmmchain.setNumThreads( k ); - + mapmmchain.setNumThreads(k); + //set basic lop properties setOutputDimensions(mapmmchain); setLineNumbers(mapmmchain); setLops(mapmmchain); } - + /** * NOTE: exists for consistency since removeEmtpy might be scheduled to MR - * but matrix mult on small output might be scheduled to CP. Hence, we + * but matrix mult on small output might be scheduled to CP. Hence, we * need to handle directly passed selection vectors in CP as well. */ - private void constructCPLopsPMM() - { + private void constructCPLopsPMM() { Hop pmInput = getInput().get(0); Hop rightInput = getInput().get(1); - + Hop nrow = HopRewriteUtils.createValueHop(pmInput, true); //NROW nrow.setBlocksize(0); nrow.setForcedExecType(ExecType.CP); HopRewriteUtils.copyLineNumbers(this, nrow); Lop lnrow = nrow.constructLops(); - + PMMJ pmm = new PMMJ(pmInput.constructLops(), rightInput.constructLops(), lnrow, getDataType(), getValueType(), false, false, ExecType.CP); - + //set degree of parallelism int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); pmm.setNumThreads(k); - + pmm.getOutputParameters().setDimensions(getDim1(), getDim2(), getBlocksize(), getNnz()); setLineNumbers(pmm); - + setLops(pmm); - + HopRewriteUtils.removeChildReference(pmInput, nrow); } - private void constructCPLopsMM(ExecType et) - { + private void constructCPLopsMM(ExecType et) { Lop matmultCP = null; String cla = ConfigurationManager.getDMLConfig().getTextValue("sysds.compressed.linalg"); if (et == ExecType.GPU) { @@ -610,72 +578,85 @@ private void constructCPLopsMM(ExecType et) boolean leftTrans = false; // HopRewriteUtils.isTransposeOperation(h1); boolean rightTrans = false; // HopRewriteUtils.isTransposeOperation(h2); Lop left = !leftTrans ? h1.constructLops() : - h1.getInput().get(0).constructLops(); + h1.getInput().get(0).constructLops(); Lop right = !rightTrans ? h2.constructLops() : - h2.getInput().get(0).constructLops(); + h2.getInput().get(0).constructLops(); matmultCP = new MatMultCP(left, right, getDataType(), getValueType(), et, leftTrans, rightTrans); setOutputDimensions(matmultCP); - } - else if (cla.equals("true") || cla.equals("cost")){ + } else if (cla.equals("true") || cla.equals("cost")) { Hop h1 = getInput().get(0); Hop h2 = getInput().get(1); int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); boolean leftTrans = HopRewriteUtils.isTransposeOperation(h1); - boolean rightTrans = HopRewriteUtils.isTransposeOperation(h2); + boolean rightTrans = HopRewriteUtils.isTransposeOperation(h2); Lop left = !leftTrans ? h1.constructLops() : - h1.getInput().get(0).constructLops(); + h1.getInput().get(0).constructLops(); Lop right = !rightTrans ? h2.constructLops() : - h2.getInput().get(0).constructLops(); + h2.getInput().get(0).constructLops(); matmultCP = new MatMultCP(left, right, getDataType(), getValueType(), et, k, leftTrans, rightTrans); - } - else { - if( isLeftTransposeRewriteApplicable(true) ) { + } else { + if (isLeftTransposeRewriteApplicable(true)) { matmultCP = constructCPLopsMMWithLeftTransposeRewrite(et); - } - else { + } else { int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); matmultCP = new MatMultCP(getInput().get(0).constructLops(), - getInput().get(1).constructLops(), getDataType(), getValueType(), et, k); + getInput().get(1).constructLops(), getDataType(), getValueType(), et, k); updateLopFedOut(matmultCP); } setOutputDimensions(matmultCP); } - + setLineNumbers(matmultCP); setLops(matmultCP); } - private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et) - { - Hop X = getInput().get(0).getInput().get(0); //guaranteed to exists + private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et) { + Hop X = getInput().get(0).getInput().get(0); // guaranteed to exist Hop Y = getInput().get(1); int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); - + + //Check if X is already a transpose operation + boolean isXTransposed = X instanceof ReorgOp && ((ReorgOp)X).getOp() == ReOrgOp.TRANS; + Hop actualX = isXTransposed ? X.getInput().get(0) : X; + + //Check if Y is a transpose operation + boolean isYTransposed = Y instanceof ReorgOp && ((ReorgOp)Y).getOp() == ReOrgOp.TRANS; + Hop actualY = isYTransposed ? Y.getInput().get(0) : Y; + + //Handle Y or actualY for transpose + Lop yLop = isYTransposed ? actualY.constructLops() : Y.constructLops(); + ExecType inputReorgExecType = (Y.hasFederatedOutput()) ? ExecType.FED : ExecType.CP; + //right vector transpose - Lop lY = Y.constructLops(); - ExecType inputReorgExecType = ( Y.hasFederatedOutput() ) ? ExecType.FED : ExecType.CP; - Lop tY = (lY instanceof Transform && ((Transform)lY).getOp()==ReOrgOp.TRANS ) ? - lY.getInputs().get(0) : //if input is already a transpose, avoid redundant transpose ops - new Transform(lY, ReOrgOp.TRANS, getDataType(), getValueType(), inputReorgExecType, k); - tY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getBlocksize(), Y.getNnz()); + Lop tY = (yLop instanceof Transform && ((Transform)yLop).getOp() == ReOrgOp.TRANS) ? + yLop.getInputs().get(0) : //if input is already a transpose, avoid redundant transpose ops + new Transform(yLop, ReOrgOp.TRANS, getDataType(), getValueType(), inputReorgExecType, k); + + //Set dimensions for tY + long tYRows = isYTransposed ? actualY.getDim1() : Y.getDim2(); + long tYCols = isYTransposed ? actualY.getDim2() : Y.getDim1(); + tY.getOutputParameters().setDimensions(tYRows, tYCols, getBlocksize(), Y.getNnz()); setLineNumbers(tY); if (Y.hasFederatedOutput()) updateLopFedOut(tY); - + + //Construct X lops for matrix multiplication + Lop xLop = isXTransposed ? actualX.constructLops() : X.constructLops(); + //matrix mult - Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(), getValueType(), et, k); //CP or FED - mult.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getBlocksize(), getNnz()); + Lop mult = new MatMultCP(tY, xLop, getDataType(), getValueType(), et, k); + mult.getOutputParameters().setDimensions(tYRows, isXTransposed ? actualX.getDim1() : X.getDim2(), getBlocksize(), getNnz()); mult.setFederatedOutput(_federatedOutput); setLineNumbers(mult); //result transpose (dimensions set outside) - ExecType outTransposeExecType = ( _federatedOutput == FederatedOutput.FOUT ) ? - ExecType.FED : ExecType.CP; + ExecType outTransposeExecType = (_federatedOutput == FederatedOutput.FOUT) ? + ExecType.FED : ExecType.CP; Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(), getValueType(), outTransposeExecType, k); return out; } - + ////////////////////////// // Spark Lops generation ///////////////////////// @@ -718,25 +699,25 @@ private Lop constructSparkLopsMapMMWithLeftTransposeRewrite() { Hop X = getInput().get(0).getInput().get(0); //guaranteed to exists Hop Y = getInput().get(1); - + //right vector transpose Lop tY = new Transform(Y.constructLops(), ReOrgOp.TRANS, getDataType(), getValueType(), ExecType.CP); tY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getBlocksize(), Y.getNnz()); setLineNumbers(tY); - + //matrix mult spark - boolean needAgg = requiresAggregation(MMultMethod.MAPMM_R); + boolean needAgg = requiresAggregation(MMultMethod.MAPMM_R); SparkAggType aggtype = getSparkMMAggregationType(needAgg); - _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this); - - Lop mult = new MapMult( tY, X.constructLops(), getDataType(), getValueType(), - false, false, _outputEmptyBlocks, aggtype); + _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this); + + Lop mult = new MapMult( tY, X.constructLops(), getDataType(), getValueType(), + false, false, _outputEmptyBlocks, aggtype); mult.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getBlocksize(), getNnz()); setLineNumbers(mult); - + //result transpose (dimensions set outside) Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(), getValueType(), ExecType.CP); - + return out; } @@ -892,13 +873,13 @@ private void constructSparkLopsZIPMM() { setLineNumbers( zipmm ); setLops(zipmm); } - + /** * Determines if the rewrite t(X)%*%Y -> t(t(Y)%*%X) is applicable * and cost effective. Whenever X is a wide matrix and Y is a vector * this has huge impact, because the transpose of X would dominate * the entire operation costs. - * + * * @param CP true if CP * @return true if left transpose rewrite applicable */ @@ -910,38 +891,38 @@ private boolean isLeftTransposeRewriteApplicable(boolean CP) { return false; } - + boolean ret = false; Hop h1 = getInput().get(0); Hop h2 = getInput().get(1); - + //check for known dimensions and cost for t(X) vs t(v) + t(tvX) //(for both CP/MR, we explicitly check that new transposes fit in memory, //even a ba in CP does not imply that both transposes can be executed in CP) - if( CP ) //in-memory ba + if( CP ) //in-memory ba { if( HopRewriteUtils.isTransposeOperation(h1) ) { long m = h1.getDim1(); long cd = h1.getDim2(); long n = h2.getDim2(); - + //check for known dimensions (necessary condition for subsequent checks) - ret = (m>0 && cd>0 && n>0); - - //check operation memory with changed transpose (this is important if we have + ret = (m>0 && cd>0 && n>0); + + //check operation memory with changed transpose (this is important if we have //e.g., t(X) %*% v, where X is sparse and tX fits in memory but X does not double memX = h1.getInput().get(0).getOutputMemEstimate(); double memtv = OptimizerUtils.estimateSizeExactSparsity(n, cd, 1.0); double memtXv = OptimizerUtils.estimateSizeExactSparsity(n, m, 1.0); double newMemEstimate = memtv + memX + memtXv; ret &= ( newMemEstimate < OptimizerUtils.getLocalMemBudget() ); - + //check for cost benefit of t(X) vs t(v) + t(tvX) and memory of additional transpose ops ret &= ( m*cd > (cd*n + m*n) && - 2 * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < OptimizerUtils.getLocalMemBudget() && - 2 * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < OptimizerUtils.getLocalMemBudget() ); - + 2 * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < OptimizerUtils.getLocalMemBudget() && + 2 * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < OptimizerUtils.getLocalMemBudget() ); + //update operation memory estimate (e.g., for parfor optimizer) if( ret ) _memEstimate = newMemEstimate; @@ -955,14 +936,14 @@ private boolean isLeftTransposeRewriteApplicable(boolean CP) long n = h2.getDim2(); //note: output size constraint for mapmult already checked by optfindmmultmethod if( m>0 && cd>0 && n>0 && (m*cd > (cd*n + m*n)) && - 2 * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < OptimizerUtils.getLocalMemBudget() && - 2 * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < OptimizerUtils.getLocalMemBudget() ) + 2 * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < OptimizerUtils.getLocalMemBudget() && + 2 * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < OptimizerUtils.getLocalMemBudget() ) { ret = true; } } } - + return ret; } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java new file mode 100644 index 00000000000..9c6bd0f7df1 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import java.util.HashMap; + +public class RewriteTransposeTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "RewriteTransposeCase1"; // t(X)%*%Y + private final static String TEST_NAME2 = "RewriteTransposeCase2"; // X=t(A); t(X)%*%Y + private final static String TEST_NAME3 = "RewriteTransposeCase3"; // Y=t(A); t(X)%*%Y + + private final static String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteTransposeTest.class.getSimpleName() + "/"; + + private static final double eps = 1e-9; + + @Override + public void setUp() { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION=false; + + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{"R"})); + } + + @Test + public void testTransposeRewrite1CP() { + runTransposeRewriteTest(TEST_NAME1, false); + } + + @Test + public void testTransposeRewrite2CP() { + runTransposeRewriteTest(TEST_NAME2, true); + } + + @Test + public void testTransposeRewrite3CP() { + runTransposeRewriteTest(TEST_NAME3, false); + } + + private void runTransposeRewriteTest(String testname, boolean expectedMerge) { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + + programArgs = new String[]{"-explain", "-stats", "-args", output("R")}; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(expectedDir()); + + runTest(true, false, null, -1); + runRScript(true); + + HashMap dmlOutput = readDMLMatrixFromOutputDir("R"); + HashMap rOutput = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlOutput, rOutput, eps, "Stat-DML", "Stat-R"); + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R new file mode 100644 index 00000000000..5b0e19dca23 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +library("Matrix") +library("matrixStats") + +X <- matrix(seq(1, 20), nrow=4, ncol=5, byrow=TRUE) +Y <- matrix(seq(1, 12), nrow=4, ncol=3, byrow=TRUE) + +R <- t(t(Y)%*%X) + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml new file mode 100644 index 00000000000..83cfb65dc6b --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1, 20), rows=4, cols=5); +Y = matrix(seq(1, 12), rows=4, cols=3); + +R = t(X)%*%Y; + +write(R, $1); \ No newline at end of file diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R new file mode 100644 index 00000000000..fea8c266693 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +library("Matrix") +library("matrixStats") +A = matrix(seq(1, 20), nrow=5, ncol=4, byrow=TRUE) +Y = matrix(seq(1, 12), nrow=4, ncol=3, byrow=TRUE) +X = t(A) + +R <- t(t(Y)%*%X) + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml new file mode 100644 index 00000000000..cb9332423bf --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = matrix(seq(1, 20), rows=5, cols=4); +Y = matrix(seq(1, 12), rows=4, cols=3); +X = t(A); + +R = t(X) %*% Y; + +write(R, $1); \ No newline at end of file diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R new file mode 100644 index 00000000000..2bdd22f674e --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +library("Matrix") +library("matrixStats") + +X <- matrix(seq(1, 20), nrow=4, ncol=5, byrow=TRUE) +A <- matrix(seq(1, 12), nrow=3, ncol=4, byrow=TRUE) +Y <- t(A) + +R <- t(t(Y)%*%X) + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml new file mode 100644 index 00000000000..2e26920aedc --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1, 20), rows=4, cols=5); +A = matrix(seq(1, 12), rows=3, cols=4); +Y = t(A); + +R = t(X) %*% Y; + +write(R, $1); \ No newline at end of file From 504938250f79a6c1898da90df09957ff5b598f23 Mon Sep 17 00:00:00 2001 From: aarna Date: Tue, 1 Apr 2025 13:05:38 +0200 Subject: [PATCH 3/3] Remove file from staging before PR --- .../RewriteBooleanSimplificationTest.java | 76 ------------------- 1 file changed, 76 deletions(-) delete mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java deleted file mode 100644 index afb70b8ff3f..00000000000 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java +++ /dev/null @@ -1,76 +0,0 @@ -package org.apache.sysds.test.functions.rewrite; - -import org.junit.Assert; -import org.junit.Test; -import org.apache.sysds.api.DMLScript; -import org.apache.sysds.common.Types.ExecMode; -import org.apache.sysds.common.Types.ExecType; -import org.apache.sysds.test.AutomatedTestBase; -import org.apache.sysds.test.TestConfiguration; -import org.apache.sysds.test.TestUtils; - -public class RewriteBooleanSimplificationTest extends AutomatedTestBase { - - private static final String TEST_NAME_AND = "RewriteBooleanSimplificationTestAnd"; - private static final String TEST_NAME_OR = "RewriteBooleanSimplificationTestOr"; - private static final String TEST_DIR = "functions/rewrite/"; - private static final String TEST_CLASS_DIR = TEST_DIR + RewriteBooleanSimplificationTest.class.getSimpleName() + "/"; - - @Override - public void setUp() { - TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME_AND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_AND)); - addTestConfiguration(TEST_NAME_OR, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_OR)); - } - - @Test - public void testBooleanRewriteAnd() { - testRewriteBooleanSimplification(TEST_NAME_AND, ExecType.CP, 0.0); - } - - @Test - public void testBooleanRewriteOr() { - testRewriteBooleanSimplification(TEST_NAME_OR, ExecType.CP, 1.0); - } - - private void testRewriteBooleanSimplification(String testname, ExecType et, double expected) { - ExecMode platformOld = rtplatform; - rtplatform = (et == ExecType.SPARK) ? ExecMode.SPARK : ExecMode.HYBRID; - - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - if (rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID) { - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - } - - try { - TestConfiguration config = getTestConfiguration(testname); - loadTestConfiguration(config); - - String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[]{}; - - runTest(true, false, null, -1); - - Assert.assertEquals("Expected boolean simplification result does not match", expected, getRewriteBooleanSimplificationResult(testname), 0.0001); - } finally { - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; - } - } - - private double getRewriteBooleanSimplificationResult(String testname) { - - if (testname.equals(TEST_NAME_AND)) { - // a & !a simplifies to false (0.0) - return 0.0; - } else if (testname.equals(TEST_NAME_OR)) { - // a | !a simplifies to true (1.0) - return 1.0; - } else { - // In case of an unknown operation, we return a default value (e.g., 0.0). - return 0.0; - } - } - -}