diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml index ba7da60c8e4..a1b3bf9728b 100644 --- a/.github/workflows/javaTests.yml +++ b/.github/workflows/javaTests.yml @@ -85,7 +85,8 @@ jobs: "**.functions.unary.scalar.**,**.functions.updateinplace.**,**.functions.vect.**", "**.functions.reorg.**,**.functions.rewrite.**,**.functions.ternary.**", "**.functions.transform.**","**.functions.unique.**", - "**.functions.unary.matrix.**,**.functions.linearization.**,**.functions.jmlc.**" + "**.functions.unary.matrix.**,**.functions.linearization.**,**.functions.jmlc.**", + "**.functions.einsum.**", ] name: ${{ matrix.tests }} steps: diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index fe75aec6a05..5fe1721cc20 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -403,6 +403,7 @@ public enum Builtins { UNDER_SAMPLING("underSampling", true), UNIQUE("unique", false, true), UPPER_TRI("upper.tri", false, true), + EINSUM("einsum", false, false), XDUMMY1("xdummy1", true), //error handling test XDUMMY2("xdummy2", true); //error handling test diff --git a/src/main/java/org/apache/sysds/common/InstructionType.java b/src/main/java/org/apache/sysds/common/InstructionType.java index 1980dd7984d..29148f03e92 100644 --- a/src/main/java/org/apache/sysds/common/InstructionType.java +++ b/src/main/java/org/apache/sysds/common/InstructionType.java @@ -62,6 +62,7 @@ public enum InstructionType { PMMJ, MMChain, Union, + EINSUM, //SP Types MAPMM, diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 64a6c7dd27e..6aeeb7e8f20 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -174,7 +174,7 @@ public enum Opcodes { RBIND("rbind", InstructionType.BuiltinNary), EVAL("eval", InstructionType.BuiltinNary), LIST("list", InstructionType.BuiltinNary), - + EINSUM("einsum", InstructionType.BuiltinNary), //Parametrized builtin functions AUTODIFF("autoDiff", InstructionType.ParameterizedBuiltin), CONTAINS("contains", InstructionType.ParameterizedBuiltin), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index cc7f6eb377a..09a0f8effd8 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -767,8 +767,8 @@ public String toString() { /** Operations that require a variable number of operands*/ public enum OpOpN { - PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST; - + PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST, EINSUM; + public boolean isCellOp() { return this == MIN || this == MAX || this == PLUS || this == MULT; } diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java index 1659b0dbc5e..6962beadcbc 100644 --- a/src/main/java/org/apache/sysds/hops/NaryOp.java +++ b/src/main/java/org/apache/sysds/hops/NaryOp.java @@ -26,6 +26,7 @@ import org.apache.sysds.lops.Lop; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.lops.Nary; +import org.apache.sysds.runtime.einsum.EinsumEquationValidator; import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.meta.MatrixCharacteristics; @@ -235,6 +236,14 @@ public void refreshSizeInformation() { setDim1(getInput().size()); setDim2(1); break; + case EINSUM: + String equationString = ((LiteralOp) _input.get(0)).getStringValue(); + var dims = EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString, this.getInput().subList(1, this.getInput().size())); + + setDim1(dims.getLeft()); + setDim2(dims.getMiddle()); + setDataType(dims.getRight()); + break; case PRINTF: case EVAL: //do nothing: diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java index 3d2c19ef4c8..2482ac77e22 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java @@ -31,7 +31,7 @@ public class CNodeCell extends CNodeTpl { - protected static final String JAVA_TEMPLATE = + public static final String JAVA_TEMPLATE = "package codegen;\n" + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" + "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n" diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java index 9292972874f..b90789df5eb 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java @@ -56,6 +56,14 @@ public CNodeData(CNodeData node, String newName) { _cols = node.getNumCols(); _dataType = node.getDataType(); } + + public CNodeData(String name, long hopID, long rows, long cols, DataType dataType) { + _name = name; + _hopID = hopID; + _rows = rows; + _cols = cols; + _dataType = dataType; + } @Override public String getVarname() { diff --git a/src/main/java/org/apache/sysds/lops/Nary.java b/src/main/java/org/apache/sysds/lops/Nary.java index e073bc68817..e5382ba0330 100644 --- a/src/main/java/org/apache/sysds/lops/Nary.java +++ b/src/main/java/org/apache/sysds/lops/Nary.java @@ -111,6 +111,7 @@ private String getOpcode() { case RBIND: case EVAL: case LIST: + case EINSUM: return operationType.name().toLowerCase(); case MIN: case MAX: diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 540b522a8bb..0a5f30b712a 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedList; import org.antlr.v4.runtime.ParserRuleContext; import org.apache.commons.lang3.ArrayUtils; @@ -35,6 +36,7 @@ import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.parser.LanguageException.LanguageErrorCodes; +import org.apache.sysds.runtime.einsum.EinsumEquationValidator; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.util.DnnUtils; import org.apache.sysds.runtime.util.UtilFunctions; @@ -751,7 +753,9 @@ else if(((ConstIdentifier) getThirdExpr().getOutput()) else raiseValidateError("Compress/DeCompress instruction not allowed in dml script"); break; - + case EINSUM: + validateEinsum((DataIdentifier) getOutputs()[0]); + break; default: //always unconditional raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false); } @@ -2063,7 +2067,9 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV output.setValueType(ValueType.INT64); output.setNnz(id.getDim2()); break; - + case EINSUM: + validateEinsum(output); + break; default: if( isMathFunction() ) { checkMathFunctionParam(); @@ -2096,6 +2102,49 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV } } + private void validateEinsum(DataIdentifier output){ + if(getSecondExpr() == null) + raiseValidateError("Einsum: at least one input matrix required", false, + LanguageErrorCodes.INVALID_PARAMETERS); + + if(!(getFirstExpr() instanceof StringIdentifier)) + raiseValidateError("Einsum: first argument has to be equation str", false, + LanguageErrorCodes.INVALID_PARAMETERS); + + String equationString = ((StringIdentifier)getFirstExpr()).getValue(); + + if (equationString.length() == 0) raiseValidateError("Einsum: equation str too short", false, LanguageErrorCodes.INVALID_PARAMETERS); + if (equationString.charAt(0) == '-' || equationString.charAt(0) == ',') raiseValidateError("Einsum: equation str invalid", false, LanguageErrorCodes.INVALID_PARAMETERS); + + Expression[] expressions = getAllExpr(); + boolean allDimsKnown = true; + + LinkedList matrixBlocks = new LinkedList(); + for (int i=1;i charToDimensionSize; + public String equationString; + public boolean[] diagonalInputs; + public HashSet summingChars; + public HashSet contractDimsSet; + public ContractDimensions[] contractDims; + public ArrayList newEquationStringInputsSplit; + public HashMap> characterAppearanceIndexes; // for each character, this tells in which inputs it appears + + private EinsumContext(){}; + public static EinsumContext getEinsumContext(String eqStr, ArrayList inputs){ + EinsumContext res = new EinsumContext(); + + res.equationString = eqStr; + res.charToDimensionSize = new HashMap(); + HashSet summingChars = new HashSet<>(); + ContractDimensions[] contractDims = new ContractDimensions[inputs.size()]; + boolean[] diagonalInputs = new boolean[inputs.size()]; // all false by default + HashSet contractDimsSet = new HashSet(); + HashMap> partsCharactersToIndices = new HashMap<>(); + ArrayList newEquationStringSplit = new ArrayList(); + + Iterator it = inputs.iterator(); + MatrixBlock curArr = it.next(); + int arrSizeIterator = 0; + int arrayIterator = 0; + int i; + // first iteration through string: collect information on character-size and what characters are summing characters + for (i = 0; true; i++) { + char c = eqStr.charAt(i); + if(c == '-'){ + i+=2; + break; + } + if(c == ','){ + arrayIterator++; + curArr = it.next(); + arrSizeIterator = 0; + } + else{ + if (res.charToDimensionSize.containsKey(c)) { // sanity check if dims match, this is already checked at validation + if(arrSizeIterator == 0 && res.charToDimensionSize.get(c) != curArr.getNumRows()) + throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); + else if(arrSizeIterator == 1 && res.charToDimensionSize.get(c) != curArr.getNumColumns()) + throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); + summingChars.add(c); + } else { + if(arrSizeIterator == 0) + res.charToDimensionSize.put(c, curArr.getNumRows()); + else if(arrSizeIterator == 1) + res.charToDimensionSize.put(c, curArr.getNumColumns()); + } + + arrSizeIterator++; + } + } + + int numOfRemainingChars = eqStr.length() - i; + + if (numOfRemainingChars > 2) + throw new RuntimeException("Einsum: dim > 2 not supported"); + + arrSizeIterator = 0; + + Character outChar1 = numOfRemainingChars > 0 ? eqStr.charAt(i) : null; + Character outChar2 = numOfRemainingChars > 1 ? eqStr.charAt(i+1) : null; + res.outRows=(numOfRemainingChars > 0 ? res.charToDimensionSize.get(outChar1) : 1); + res.outCols=(numOfRemainingChars > 1 ? res.charToDimensionSize.get(outChar2) : 1); + + arrayIterator=0; + // second iteration through string: collect remaining information + for (i = 0; true; i++) { + char c = eqStr.charAt(i); + if (c == '-') { + break; + } + if (c == ',') { + arrayIterator++; + arrSizeIterator = 0; + continue; + } + String s = ""; + + if(summingChars.contains(c)) { + s+=c; + if(!partsCharactersToIndices.containsKey(c)) + partsCharactersToIndices.put(c, new ArrayList<>()); + partsCharactersToIndices.get(c).add(arrayIterator); + } + else if((outChar1 != null && c == outChar1) || (outChar2 != null && c == outChar2)) { + s+=c; + } + else { + contractDimsSet.add(c); + contractDims[arrayIterator] = ContractDimensions.CONTRACT_LEFT; + } + + if(i + 1 < eqStr.length()) { // process next character together + char c2 = eqStr.charAt(i + 1); + i++; + if (c2 == '-') { newEquationStringSplit.add(s); break;} + if (c2 == ',') { arrayIterator++; newEquationStringSplit.add(s); continue; } + + if (c2 == c){ + diagonalInputs[arrayIterator] = true; + if (contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT) contractDims[arrayIterator] = ContractDimensions.CONTRACT_BOTH; + } + else{ + if(summingChars.contains(c2)) { + s+=c2; + if(!partsCharactersToIndices.containsKey(c2)) + partsCharactersToIndices.put(c2, new ArrayList<>()); + partsCharactersToIndices.get(c2).add(arrayIterator); + } + else if((outChar1 != null && c2 == outChar1) || (outChar2 != null && c2 == outChar2)) { + s+=c2; + } + else { + contractDimsSet.add(c2); + contractDims[arrayIterator] = contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT ? ContractDimensions.CONTRACT_BOTH : ContractDimensions.CONTRACT_RIGHT; + } + } + } + newEquationStringSplit.add(s); + arrSizeIterator++; + } + + res.contractDims = contractDims; + res.contractDimsSet = contractDimsSet; + res.diagonalInputs = diagonalInputs; + res.summingChars = summingChars; + res.outChar1 = outChar1; + res.outChar2 = outChar2; + res.newEquationStringInputsSplit = newEquationStringSplit; + res.characterAppearanceIndexes = partsCharactersToIndices; + return res; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java new file mode 100644 index 00000000000..5643159ef9a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.einsum; + +import org.apache.commons.lang3.tuple.Triple; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.parser.Identifier; +import org.apache.sysds.parser.LanguageException; +import org.apache.sysds.parser.ParseInfo; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; + +public class EinsumEquationValidator { + + public static Triple validateEinsumEquationAndReturnDimensions(String equationString, List expressionsOrIdentifiers) throws LanguageException { + String[] eqStringParts = equationString.split("->"); // length 2 if "...->..." , length 1 if "...->" + boolean isResultScalar = eqStringParts.length == 1; + + if(expressionsOrIdentifiers == null) + throw new RuntimeException("Einsum: called validateEinsumAndReturnDimensions with null list"); + + HashMap charToDimensionSize = new HashMap<>(); + Iterator it = expressionsOrIdentifiers.iterator(); + HopOrIdentifier currArr = it.next(); + int arrSizeIterator = 0; + int numberOfMatrices = 1; + for (int i = 0; i < eqStringParts[0].length(); i++) { + char c = equationString.charAt(i); + if(c==' ') continue; + if(c==','){ + if(!it.hasNext()) + throw new LanguageException("Einsum: Provided less operands than specified in equation str"); + currArr = it.next(); + arrSizeIterator = 0; + numberOfMatrices++; + } else{ + long thisCharDimension = getThisCharDimension(currArr, arrSizeIterator); + if (charToDimensionSize.containsKey(c)){ + if (charToDimensionSize.get(c) != thisCharDimension) + throw new LanguageException("Einsum: Character '" + c + "' expected to be dim " + charToDimensionSize.get(c) + ", but found " + thisCharDimension); + }else{ + charToDimensionSize.put(c, thisCharDimension); + } + arrSizeIterator++; + } + } + if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices) + throw new LanguageException("Einsum: Provided more operands than specified in equation str"); + + if (isResultScalar) + return Triple.of(-1l,-1l, Types.DataType.SCALAR); + + int numberOfOutDimensions = 0; + Character dim1Char = null; + long dim1 = 1; + long dim2 = 1; + for (int i = 0; i < eqStringParts[1].length(); i++) { + char c = eqStringParts[1].charAt(i); + if (c == ' ') continue; + if (numberOfOutDimensions == 0) { + dim1Char = c; + dim1 = charToDimensionSize.get(c); + } else { + if(c==dim1Char) throw new LanguageException("Einsum: output character "+c+" provided multiple times"); + dim2 = charToDimensionSize.get(c); + } + numberOfOutDimensions++; + } + if (numberOfOutDimensions > 2) { + throw new LanguageException("Einsum: output matrices with with no. dims > 2 not supported"); + } else { + return Triple.of(dim1, dim2, Types.DataType.MATRIX); + } + } + + public static Types.DataType validateEinsumEquationNoDimensions(String equationString, int numberOfMatrixInputs) throws LanguageException { + String[] eqStringParts = equationString.split("->"); // length 2 if "...->..." , length 1 if "...->" + boolean isResultScalar = eqStringParts.length == 1; + + int numberOfMatrices = 1; + for (int i = 0; i < eqStringParts[0].length(); i++) { + char c = eqStringParts[0].charAt(i); + if(c == ' ') continue; + if(c == ',') + numberOfMatrices++; + } + if(numberOfMatrixInputs != numberOfMatrices){ + throw new LanguageException("Einsum: Invalid number of parameters, given: " + numberOfMatrixInputs + ", expected: " + numberOfMatrices); + } + + if(isResultScalar){ + return Types.DataType.SCALAR; + }else { + int numberOfDimensions = 0; + Character dim1Char = null; + for (int i = 0; i < eqStringParts[1].length(); i++) { + char c = eqStringParts[i].charAt(i); + if(c == ' ') continue; + numberOfDimensions++; + if (numberOfDimensions == 1 && c == dim1Char) + throw new LanguageException("Einsum: output character "+c+" provided multiple times"); + dim1Char = c; + } + + if (numberOfDimensions > 2) { + throw new LanguageException("Einsum: output matrices with with no. dims > 2 not supported"); + } else { + return Types.DataType.MATRIX; + } + } + } + + private static long getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) { + long thisCharDimension; + if(currArr instanceof Hop){ + thisCharDimension = arrSizeIterator == 0 ? ((Hop) currArr).getDim1() : ((Hop) currArr).getDim2(); + } else if(currArr instanceof Identifier){ + thisCharDimension = arrSizeIterator == 0 ? ((Identifier) currArr).getDim1() : ((Identifier) currArr).getDim2(); + } else { + throw new RuntimeException("validateEinsumAndReturnDimensions called with expressions that are not Hop or Identifier: "+ currArr == null ? "null" : currArr.getClass().toString()); + } + return thisCharDimension; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java index 6d230a30f08..e7aa1b5fd76 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java @@ -91,10 +91,13 @@ else if( opcode.equals(Opcodes.NM.toString()) ) { return new MatrixBuiltinNaryCPInstruction( new SimpleOperator(Multiply.getMultiplyFnObject()), opcode, str, outputOperand, inputOperands); } + else if( opcode.equals(Opcodes.EINSUM.toString()) ) { + return new EinsumCPInstruction(null, opcode, str, outputOperand, inputOperands); + } else if (OpOpN.EVAL.name().equalsIgnoreCase(opcode)) { return new EvalNaryCPInstruction(null, opcode, str, outputOperand, inputOperands); } - + throw new DMLRuntimeException("Opcode (" + opcode + ") not recognized in BuiltinMultipleCPInstruction"); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java index c99039bb7f3..b35ca55dab6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java @@ -44,7 +44,7 @@ public enum CPType { Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, Local, MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, Compression, DeCompression, SpoofFused, StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote, - EvictLineageCache, + EvictLineageCache, EINSUM, NoOp, Union, QuantizeCompression diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java new file mode 100644 index 00000000000..87dcf3c6048 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -0,0 +1,838 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.codegen.SpoofCompiler; +import org.apache.sysds.hops.codegen.cplan.CNode; +import org.apache.sysds.hops.codegen.cplan.CNodeBinary; +import org.apache.sysds.hops.codegen.cplan.CNodeCell; +import org.apache.sysds.hops.codegen.cplan.CNodeData; +import org.apache.sysds.hops.codegen.cplan.CNodeNary; +import org.apache.sysds.hops.codegen.cplan.CNodeRow; +import org.apache.sysds.runtime.codegen.*; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.einsum.EinsumContext; +import org.apache.sysds.runtime.functionobjects.*; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.runtime.matrix.operators.SimpleOperator; + +import java.util.*; +import java.util.function.Predicate; + +public class EinsumCPInstruction extends BuiltinNaryCPInstruction { + public static boolean FORCE_CELL_TPL = false; + protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); + public String eqStr; + private final int _numThreads; + private final CPOperand[] _in; + + public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand out, CPOperand... inputs) + { + super(op, opcode, istr, out, inputs); + _numThreads = OptimizerUtils.getConstrainedNumThreads(-1); + _in = inputs; + this.eqStr = inputs[0].getName(); + Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); + } + + private EinsumContext einc = null; + + @Override + public void processInstruction(ExecutionContext ec) { + //get input matrices and scalars, incl pinning of matrices + ArrayList inputs = new ArrayList<>(); + for (CPOperand input : _in) { + if(input.getDataType()==DataType.MATRIX){ + MatrixBlock mb = ec.getMatrixInput(input.getName()); + if(mb instanceof CompressedMatrixBlock){ + mb = ((CompressedMatrixBlock) mb).getUncompressed("Spoof instruction"); + } + inputs.add(mb); + } + } + + EinsumContext einc = EinsumContext.getEinsumContext(eqStr, inputs); + + this.einc = einc; + String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : ""; + + if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols); + + ArrayList inputsChars = einc.newEquationStringInputsSplit; + + if(LOG.isTraceEnabled()) LOG.trace(String.join(",",einc.newEquationStringInputsSplit)); + + contractDimensionsAndComputeDiagonals(einc, inputs); + + //make all vetors col vectors + for(int i = 0; i < inputs.size(); i++){ + if(inputs.get(i) != null && inputsChars.get(i).length() == 1) EnsureMatrixBlockColumnVector(inputs.get(i)); + } + + if(LOG.isTraceEnabled()) for(Character c : einc.characterAppearanceIndexes.keySet()){ + ArrayList a = einc.characterAppearanceIndexes.get(c); + LOG.trace(c+" count= "+a.size()); + } + + // compute scalar by suming-all matrices: + Double scalar = null; + for(int i=0;i< inputs.size(); i++){ + String s = inputsChars.get(i); + if(s.equals("")){ + MatrixBlock mb = inputs.get(i); + if (scalar == null) scalar = mb.get(0,0); + else scalar*= mb.get(0,0); + inputs.set(i,null); + inputsChars.set(i,null); + } + } + + if (scalar != null) { + inputsChars.add(""); + inputs.add(new MatrixBlock(scalar)); + } + + HashMap characterToOccurences = new HashMap<>(); + for (Character key :einc.characterAppearanceIndexes.keySet()) { + characterToOccurences.put(key, einc.characterAppearanceIndexes.get(key).size()); + } + for (Character key :einc.charToDimensionSize.keySet()) { + if(!characterToOccurences.containsKey(key)) + characterToOccurences.put(key, 1); + } + + ArrayList eOpNodes = new ArrayList<>(inputsChars.size()); + for (int i = 0; i < inputsChars.size(); i++) { + if (inputsChars.get(i) == null) continue; + EOpNodeData n = new EOpNodeData(inputsChars.get(i).length() > 0 ? inputsChars.get(i).charAt(0) : null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i); + eOpNodes.add(n); + } + Pair > plan = FORCE_CELL_TPL ? null : generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); + + + ArrayList resMatrices = FORCE_CELL_TPL ? null : executePlan(plan.getRight(), inputs); +// ArrayList resMatrices = executePlan(plan.getRight(), inputs, true); + + if(!FORCE_CELL_TPL && resMatrices.size() == 1){ + EOpNode resNode = plan.getRight().get(0); + if (einc.outChar1 != null && einc.outChar2 != null){ + if(resNode.c1 == einc.outChar1 && resNode.c2 == einc.outChar2){ + ec.setMatrixOutput(output.getName(), resMatrices.get(0)); + } + else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + MatrixBlock resM = resMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0); + ec.setMatrixOutput(output.getName(), resM); + }else{ + if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2); + throw new RuntimeException("Einsum plan produced different result"); + } + }else if (einc.outChar1 != null){ + if(resNode.c1 == einc.outChar1 && resNode.c2 == null){ + ec.setMatrixOutput(output.getName(), resMatrices.get(0)); + }else{ + if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2); + throw new RuntimeException("Einsum plan produced different result"); + } + }else{ + if(resNode.c1 == null && resNode.c2 == null){ + ec.setScalarOutput(output.getName(), new DoubleObject(resMatrices.get(0).get(0, 0)));; + } + } + }else{ + // use cell template with loops for remaining + ArrayList mbs = resMatrices; + ArrayList chars = new ArrayList<>(); + + for (int i = 0; i < plan.getRight().size(); i++) { + String s; + if(plan.getRight().get(i).c1 == null) s = ""; + else if(plan.getRight().get(i).c2 == null) s = plan.getRight().get(i).c1.toString(); + else s = plan.getRight().get(i).c1.toString() + plan.getRight().get(i).c2; + chars.add(s); + } + + ArrayList summingChars = new ArrayList(); + for (Character c : einc.characterAppearanceIndexes.keySet()) { + if (c != einc.outChar1 && c != einc.outChar2) summingChars.add(c); + } + if(LOG.isTraceEnabled()) LOG.trace("finishing with cell tpl: "+String.join(",", chars)); + + MatrixBlock res = computeCellSummation(mbs, chars, resultString, einc.charToDimensionSize, summingChars, einc.outRows, einc.outCols); + + if (einc.outRows == 1 && einc.outCols == 1) + ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); + else ec.setMatrixOutput(output.getName(), res); + } + if(LOG.isTraceEnabled()) LOG.trace("EinsumCPInstruction Finished"); + + releaseMatrixInputs(ec); + + } + + private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList inputs) { + for(int i = 0; i< einc.contractDims.length; i++){ + //AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(),Types.CorrectionLocationType.LASTCOLUMN); + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + + if(einc.diagonalInputs[i]){ + ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); + inputs.set(i, inputs.get(i).reorgOperations(op, new MatrixBlock(),0,0,0)); + } + if (einc.contractDims[i] == null) continue; + switch (einc.contractDims[i]){ + case CONTRACT_BOTH: { + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); + MatrixBlock res = new MatrixBlock(1, 1, false); + inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); + inputs.set(i, res); + break; + } + case CONTRACT_RIGHT: { + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + MatrixBlock res = new MatrixBlock(inputs.get(i).getNumRows(), 1, false); + inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); + inputs.set(i, res); + break; + } + case CONTRACT_LEFT: { + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); + MatrixBlock res = new MatrixBlock(inputs.get(i).getNumColumns(), 1, false); + inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); + inputs.set(i, res); + break; + } + default: + break; + } + } + } + + private enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed + ////// summations: ////// + aB_a,// -> B + Ba_a, // -> B + Ba_aC, // mmult -> BC + aB_Ca, + Ba_Ca, // -> BC + aB_aC, // outer mult, possibly with transposing first -> BC + a_a,// dot -> + + ////// elementwisemult and sums, something like ij,ij->i ////// + aB_aB,// elemwise and colsum -> B + Ba_Ba, // elemwise and rowsum ->B + Ba_aB, // elemwise, either colsum or rowsum -> B +// aB_Ba, + + ////// elementwise, no summations: ////// + A_A,// v-elemwise -> A + AB_AB,// M-M elemwise -> AB + AB_BA, // M-M.T elemwise -> AB + AB_A, // M-v colwise -> BA!? + BA_A, // M-v rowwise -> BA + ab_ab,//M-M sum all + ab_ba, //M-M.T sum all + ////// other ////// + A_B, // outer mult -> AB + A_scalar, // v-scalar + AB_scalar, // m-scalar + scalar_scalar + } + private abstract class EOpNode { + public Character c1; + public Character c2; // nullable + public EOpNode(Character c1, Character c2){ + this.c1 = c1; + this.c2 = c2; + } + } + private class EOpNodeBinary extends EOpNode { + public EOpNode left; + public EOpNode right; + public EBinaryOperand operand; + public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ + super(c1,c2); + this.left = left; + this.right = right; + this.operand = operand; + } + } + private class EOpNodeData extends EOpNode { + public int matrixIdx; + public EOpNodeData(Character c1, Character c2, int matrixIdx){ + super(c1,c2); + this.matrixIdx = matrixIdx; + } + } + + private Pair /* ideally with one element */> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { + Integer minCost = cost; + List minNodes = operands; + + if (operands.size() == 2){ + boolean swap = (operands.get(0).c2 == null && operands.get(1).c2 != null) || operands.get(0).c1 == null; + EOpNode n1 = operands.get(!swap ? 0 : 1); + EOpNode n2 = operands.get(!swap ? 1 : 0); + Triple> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); + if (t != null) { + EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); + int thisCost = cost + t.getLeft(); + return Pair.of(thisCost, Arrays.asList(newNode)); + } + return Pair.of(cost, operands); + } + else if (operands.size() == 1){ + // check for transpose + return Pair.of(cost, operands); + } + + for(int i = 0; i < operands.size()-1; i++){ + for (int j = i+1; j < operands.size(); j++){ + boolean swap = (operands.get(i).c2 == null && operands.get(j).c2 != null) || operands.get(i).c1 == null; + EOpNode n1 = operands.get(!swap ? i : j); + EOpNode n2 = operands.get(!swap ? j : i); + + + Triple> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); + if (t != null){ + EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); + int thisCost = cost + t.getLeft(); + + if(n1.c1 != null) charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)-1); + if(n1.c2 != null) charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)-1); + if(n2.c1 != null) charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)-1); + if(n2.c2 != null) charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)-1); + + if(newNode.c1 != null) charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)+1); + if(newNode.c2 != null) charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)+1); + + ArrayList newOperands = new ArrayList<>(operands.size()-1); + for(int z = 0; z < operands.size(); z++){ + if(z != i && z != j) newOperands.add(operands.get(z)); + } + newOperands.add(newNode); + + Pair> furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); + if(furtherPlan.getRight().size() < (minNodes.size()) || furtherPlan.getLeft() < minCost){ + minCost = furtherPlan.getLeft(); + minNodes = furtherPlan.getRight(); + } + + if(n1.c1 != null) charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)+1); + if(n1.c2 != null) charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)+1); + if(n2.c1 != null) charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)+1); + if(n2.c2 != null) charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)+1); + if(newNode.c1 != null) charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)-1); + if(newNode.c2 != null) charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)-1); + } + } + } + + return Pair.of(minCost, minNodes); + } + + private static Triple> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2){ + Predicate cannotBeSummed = (c) -> + c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2; + + if(n1.c1 == null) { + // n2.c1 also has to be null + return Triple.of(1, EBinaryOperand.scalar_scalar, Pair.of(null, null)); + } + + if(n2.c1 == null) { + if(n1.c2 == null) + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_scalar, Pair.of(n1.c1, null)); + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2)); + } + + if(n1.c1 == n2.c1){ + if(n1.c2 != null){ + if ( n1.c2 == n2.c2){ + if( cannotBeSummed.test(n1.c1)){ + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null)); + } + + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_aB, Pair.of(n1.c2, null)); + } + + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ab, Pair.of(null, null)); + + } + + else if(n2.c2 == null){ + if(cannotBeSummed.test(n1.c1)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, n1.c2) + } + else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ + return null;// AB,AC + } + else { + return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2 + } + }else{ // n1.c2 = null -> c2.c2 = null + if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){ + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null)); + } + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.a_a, Pair.of(null, null)); + } + + + }else{ // n1.c1 != n2.c1 + if(n1.c2 == null) { + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1)); + } + else if(n2.c2 == null) { // ab,c + if (n1.c2 == n2.c1) { + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); + } + return null; // AB,C + } + else if (n1.c2 == n2.c1) { + if(n1.c1 == n2.c2){ // ab,ba + if(cannotBeSummed.test(n1.c1)){ + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_aB, Pair.of(n1.c1, null)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null)); + } + if(cannotBeSummed.test(n1.c2)){ + return null; // AB_B + }else{ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2)); +// if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ +// return null; // AB_B +// } +// return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); + } + } + if(n1.c1 == n2.c2) { + if(cannotBeSummed.test(n1.c1)){ + return null; // AB_B + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult + } + else if (n1.c2 == n2.c2) { + if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){ + return null; // BA_CA + }else{ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1 + } + } + else { // we have something like ab,cd + return null; + } + } + } + + private ArrayList executePlan(List plan, ArrayList inputs){ + return executePlan(plan, inputs, false); + } + private ArrayList executePlan(List plan, ArrayList inputs, boolean codegen) { + ArrayList res = new ArrayList<>(plan.size()); + for(EOpNode p : plan){ + if(codegen) res.add(ComputeEOpNodeCodegen(p, inputs)); + else res.add(ComputeEOpNode(p, inputs)); + } + return res; + } + + private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList inputs){ + if(eOpNode instanceof EOpNodeData eOpNodeData){ + return inputs.get(eOpNodeData.matrixIdx); + } + EOpNodeBinary bin = (EOpNodeBinary) eOpNode; + MatrixBlock left = ComputeEOpNode(bin.left, inputs); + MatrixBlock right = ComputeEOpNode(bin.right, inputs); + + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + + MatrixBlock res; + switch (bin.operand){ + case AB_AB -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + } + case A_A -> { + EnsureMatrixBlockColumnVector(left); + EnsureMatrixBlockColumnVector(right); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + } + case a_a -> { + EnsureMatrixBlockColumnVector(left); + EnsureMatrixBlockColumnVector(right); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + //////////// + case Ba_Ba -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + case aB_aB -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + EnsureMatrixBlockColumnVector(res); + } + case ab_ab -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + case ab_ba -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + case Ba_aB -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + + ///////// + case AB_BA -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + } + case Ba_aC -> { + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); + } + case aB_Ca -> { + res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), _numThreads); + } + case Ba_Ca -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); + } + case aB_aC -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); + } + case A_scalar, AB_scalar -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); + } + case BA_A -> { + EnsureMatrixBlockRowVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + } + case Ba_a -> { + EnsureMatrixBlockRowVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + + case AB_A -> { + EnsureMatrixBlockColumnVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + } + case aB_a -> { + EnsureMatrixBlockColumnVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + EnsureMatrixBlockColumnVector(res); + } + + case A_B -> { + EnsureMatrixBlockColumnVector(left); + EnsureMatrixBlockRowVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + } + case scalar_scalar -> { + return new MatrixBlock(left.get(0,0)*right.get(0,0)); + } + default -> { + throw new IllegalArgumentException("Unexpected value: " + bin.operand.toString()); + } + + } + return res; + } + + private static MatrixBlock ComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs){ + return rComputeEOpNodeCodegen(eOpNode, inputs); +// throw new NotImplementedException(); + } + private static CNodeData MatrixBlockToCNodeData(MatrixBlock mb, int id){ + return new CNodeData("ce"+id, id, mb.getNumRows(), mb.getNumColumns(), DataType.MATRIX); + } + private static MatrixBlock rComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs) { + if (eOpNode instanceof EOpNodeData eOpNodeData){ + return inputs.get(eOpNodeData.matrixIdx); +// return new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); + } + + EOpNodeBinary bin = (EOpNodeBinary) eOpNode; +// CNodeData dataLeft = null; +// if (bin.left instanceof EOpNodeData eOpNodeData) dataLeft = new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); +// CNodeData dataRight = null; +// if (bin.right instanceof EOpNodeData eOpNodeData) dataRight = new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); + + if(bin.operand == EBinaryOperand.AB_AB){ + if (bin.right instanceof EOpNodeBinary rBinary && rBinary.operand == EBinaryOperand.AB_AB){ + MatrixBlock left = rComputeEOpNodeCodegen(bin.left, inputs); + + MatrixBlock right1 = rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).left, inputs); + MatrixBlock right2 = rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).right, inputs); + + CNodeData d0 = MatrixBlockToCNodeData(left, 0); + CNodeData d1 = MatrixBlockToCNodeData(right1, 1); + CNodeData d2 = MatrixBlockToCNodeData(right2, 2); +// CNodeNary nary = new CNodeNary(cnodeIn, CNodeNary.NaryType.) + CNodeBinary rightBinary = new CNodeBinary(d1, d2, CNodeBinary.BinType.VECT_MULT); + CNodeBinary cNodeBinary = new CNodeBinary(d0, rightBinary, CNodeBinary.BinType.VECT_MULT); + ArrayList cnodeIn = new ArrayList<>(); + cnodeIn.add(d0); + cnodeIn.add(d1); + cnodeIn.add(d2); + + CNodeRow cnode = new CNodeRow(cnodeIn, cNodeBinary); + + cnode.setRowType(SpoofRowwise.RowType.NO_AGG); + cnode.renameInputs(); + + + String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); + if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); + + ArrayList scalars = new ArrayList<>(); + ArrayList mbs = new ArrayList<>(3); + mbs.add(left); + mbs.add(right1); + mbs.add(right2); + MatrixBlock out = op.execute(mbs, scalars, mb, 6); + + return out; + } + } + + throw new NotImplementedException(); + } + + + private void releaseMatrixInputs(ExecutionContext ec){ + for (CPOperand input : _in) + if(input.getDataType()==DataType.MATRIX) + ec.releaseMatrixInput(input.getName()); //todo release other + } + + private static void EnsureMatrixBlockColumnVector(MatrixBlock mb){ + if(mb.getNumColumns() > 1){ + mb.setNumRows(mb.getNumColumns()); + mb.setNumColumns(1); + mb.getDenseBlock().resetNoFill(mb.getNumRows(),1); + } + } + private static void EnsureMatrixBlockRowVector(MatrixBlock mb){ + if(mb.getNumRows() > 1){ + mb.setNumColumns(mb.getNumRows()); + mb.setNumRows(1); + mb.getDenseBlock().resetNoFill(1,mb.getNumColumns()); + } + } + + private static void indent(StringBuilder sb, int level) { + for (int i = 0; i < level; i++) { + sb.append(" "); + } + } + + private MatrixBlock computeCellSummation(ArrayList inputs, List inputsChars, String resultString, + HashMap charToDimensionSizeInt, List summingChars, int outRows, int outCols){ + ArrayList cnodeIn = new ArrayList<>(); + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); + CNodeCell cnode = new CNodeCell(cnodeIn, null); + StringBuilder sb = new StringBuilder(); + + int indent = 2; + indent(sb, indent); + + boolean needsSumming = summingChars.stream().anyMatch(x -> x != null); + + String itVar0 = cnode.createVarname(); + String outVar = itVar0; + if (needsSumming) { + sb.append("double "); + sb.append(outVar); + sb.append("=0;\n"); + } + + HashSet summedCharacters = new HashSet<>(); + Iterator hsIt = summingChars.iterator(); + while (hsIt.hasNext()) { + indent(sb, indent); + indent++; + Character c = hsIt.next(); + String itVar = itVar0 + c; + sb.append("for(int "); + sb.append(itVar); + sb.append("=0;"); + sb.append(itVar); + sb.append("<"); + sb.append(charToDimensionSizeInt.get(c)); + sb.append(";"); + sb.append(itVar); + sb.append("++){\n"); + } + indent(sb, indent); + if (needsSumming) { + sb.append(outVar); + sb.append("+="); + } + + for (int i = 0; i < inputsChars.size(); i++) { + if (inputsChars.get(i).length() == 0){ + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, 0,"); + } + + else if (summingChars.contains(inputsChars.get(i).charAt(0))) { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen,"); + sb.append(itVar0); + sb.append(inputsChars.get(i).charAt(0)); + sb.append(","); + } else if (resultString.length() >= 1 && inputsChars.get(i).charAt(0) == resultString.charAt(0)) { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, rix,"); + } else if (resultString.length() == 2 && inputsChars.get(i).charAt(0) == resultString.charAt(1)) { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, cix,"); + } else { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, 0,"); + } + + if (inputsChars.get(i).length() != 2){ + sb.append("0)"); + } + else if (summingChars.contains(inputsChars.get(i).charAt(1))) { + sb.append(itVar0); + sb.append(inputsChars.get(i).charAt(1)); + sb.append(")"); + } else if (resultString.length() >= 1 &&inputsChars.get(i).charAt(1) == resultString.charAt(0)) { + sb.append("rix)"); + } else if (resultString.length() == 2 && inputsChars.get(i).charAt(1) == resultString.charAt(1)) { + sb.append("cix)"); + } else { + sb.append("0)"); + } + + if (i < inputsChars.size() - 1) { + sb.append(" * "); + } + + } + if (needsSumming) { + sb.append(";\n"); + } + indent--; + for (int si = 0; si < summingChars.size(); si++) { + indent(sb, indent); + indent--; + sb.append("}\n"); + } + String src = CNodeCell.JAVA_TEMPLATE;// + src = src.replace("%TMP%", cnode.createVarname()); + src = src.replace("%TYPE%", "NO_AGG"); + src = src.replace("%SPARSE_SAFE%", "false"); + src = src.replace("%SEQ%", "true"); + src = src.replace("%AGG_OP_NAME%", "null"); + if (needsSumming) { + src = src.replace("%BODY_dense%", sb.toString()); + src = src.replace("%OUT%", outVar); + } else { + src = src.replace("%BODY_dense%", ""); + src = src.replace("%OUT%", sb.toString()); + } + + if( LOG.isTraceEnabled()) LOG.trace(src); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock resBlock = new MatrixBlock(); + resBlock.reset(outRows, outCols); + inputs.add(0, resBlock); + MatrixBlock out = op.execute(inputs, new ArrayList<>(), new MatrixBlock(), _numThreads); + + return out; + } + + public CPOperand[] getInputs() { + return _in; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java new file mode 100644 index 00000000000..04ea3b35a0e --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.einsum; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; + +@RunWith(Parameterized.class) +public class EinsumTest extends AutomatedTestBase +{ + final private static List TEST_CONFIGS = List.of( + new Config("ij,jk->ik", List.of(shape(50, 600), shape(600, 10))), // mm + new Config("ji,jk->ik", List.of(shape(600, 5), shape(600, 10))), + new Config("ji,kj->ik", List.of(shape(600, 5), shape(10, 600))), + new Config("ij,kj->ik", List.of(shape(5, 600), shape(10, 600))), + + new Config("ji,jk->i", List.of(shape(600, 5), shape(600, 10))), + new Config("ij,jk->i", List.of(shape(5, 600), shape(600, 10))), + + new Config("ji,jk->k", List.of(shape(600, 5), shape(600, 10))), + new Config("ij,jk->k", List.of(shape(5, 600), shape(600, 10))), + + new Config("ji,jk->j", List.of(shape(600, 5), shape(600, 10))), + + new Config("ji,ji->ji", List.of(shape(600, 5), shape(600, 5))), // elemwise mult + new Config("ji,ji,ji->ji", List.of(shape(600, 5),shape(600, 5), shape(600, 5)), + List.of(0.0001, 0.0005, 0.001)), + new Config("ji,ij->ji", List.of(shape(600, 5), shape(5, 600))), // elemwise mult + + + new Config("ij,i->ij", List.of(shape(100, 50), shape(100))), // col mult + new Config("ji,i->ij", List.of(shape(50, 100), shape(100))), // row mult + new Config("ij,i->i", List.of(shape(100, 50), shape(100))), + new Config("ij,i->j", List.of(shape(100, 50), shape(100))), + + new Config("i,i->", List.of(shape(50), shape(50))), + new Config("i,j->", List.of(shape(50), shape(80))), + new Config("i,j->ij", List.of(shape(50), shape(80))), // outer vect mult + new Config("i,j->ji", List.of(shape(50), shape(80))), // outer vect mult + + new Config("ij->", List.of(shape(100, 50))), // sum + new Config("ij->i", List.of(shape(100, 50))), // sum(1) + new Config("ij->j", List.of(shape(100, 50))), // sum(0) + new Config("ij->ji", List.of(shape(100, 50))), // T + + new Config("ab,cd->ba", List.of(shape( 600, 10), shape(6, 5))), + new Config("ab,cd,g->ba", List.of(shape( 600, 10), shape(6, 5), shape(3))), + + new Config("ab,bc,cd,de->ae", List.of(shape(5, 600), shape(600, 10), shape(10, 5), shape(5, 4))), // chain of mm + + new Config("ji,jz,zx->ix", List.of(shape(600, 5), shape( 600, 10), shape(10, 2))), + new Config("fx,fg,fz,xg->z", List.of(shape(600, 5), shape( 600, 10), shape(600, 6), shape(5, 10))), + new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl) + List.of(shape(5, 60), shape(5, 30), shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))), + + new Config("i->", List.of(shape(100))), + new Config("i->i", List.of(shape(100))) + ); + + private final int id; + private final String einsumStr; + private final List shapes; + private final File dmlFile; + private final File rFile; + private final boolean outputScalar; + + public EinsumTest(String einsumStr, List shapes, File dmlFile, File rFile, boolean outputScalar, int id){ + this.id = id; + this.einsumStr = einsumStr; + this.shapes = shapes; + this.dmlFile = dmlFile; + this.rFile = rFile; + this.outputScalar = outputScalar; + } + + @Parameterized.Parameters(name = "{index}: einsum={0}") + public static Collection data() throws IOException { + List parameters = new ArrayList<>(); + + int counter = 1; + + for (Config config : TEST_CONFIGS) { + List files = new ArrayList<>(); + String fullDMLScriptName = "SystemDS_einsum_test" + counter; + + File dmlFile = File.createTempFile(fullDMLScriptName, ".dml"); + dmlFile.deleteOnExit(); + + boolean outputScalar = config.einsumStr.trim().endsWith("->"); + + StringBuilder sb = createDmlFile(config, outputScalar); + + Files.writeString(dmlFile.toPath(), sb.toString()); + + File rFile = File.createTempFile(fullDMLScriptName, ".R"); + rFile.deleteOnExit(); + + sb = createRFile(config, outputScalar); + + Files.writeString(rFile.toPath(), sb.toString()); + + parameters.add(new Object[]{config.einsumStr, config.shapes, dmlFile, rFile, outputScalar, counter}); + + counter++; + } + + return parameters; + } + + private static StringBuilder createDmlFile(Config config, boolean outputScalar) { + StringBuilder sb = new StringBuilder(); + + for (int i = 0; i < config.shapes.size(); i++) { + int[] dims = config.shapes.get(i); + + double factor = config.factors != null ? config.factors.get(i) : 0.0001; + sb.append("A"); + sb.append(i); + + if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001 + sb.append(" = seq(1,"); + sb.append(dims[0]); + sb.append(") * "); + sb.append(factor); + } else { // A0 = matrix(seq(1,50000), 1000, 50) * 0.0001 + sb.append(" = matrix(seq(1, "); + sb.append(dims[0]*dims[1]); + sb.append("), "); + sb.append(dims[0]); + sb.append(", "); + sb.append(dims[1]); + + sb.append(") * "); + sb.append(factor); + } + sb.append("\n"); + } + sb.append("\n"); + + sb.append("R = einsum(\""); + sb.append(config.einsumStr); + sb.append("\", "); + + for (int i = 0; i < config.shapes.size() - 1; i++) { + sb.append("A"); + sb.append(i); + sb.append(", "); + } + sb.append("A"); + sb.append(config.shapes.size() - 1); + sb.append(")"); + + sb.append("\n\n"); + sb.append("write(R, $1)\n"); + return sb; + } + + private static StringBuilder createRFile(Config config, boolean outputScalar) { + StringBuilder sb = new StringBuilder(); + sb.append("args<-commandArgs(TRUE)\n"); + sb.append("options(digits=22)\n"); + sb.append("library(\"Matrix\")\n"); + sb.append("library(\"matrixStats\")\n"); + sb.append("library(\"einsum\")\n\n"); + + + for (int i = 0; i < config.shapes.size(); i++) { + int[] dims = config.shapes.get(i); + + double factor = config.factors != null ? config.factors.get(i) : 0.0001; + sb.append("A"); + sb.append(i); + + if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001 + sb.append(" = seq(1,"); + sb.append(dims[0]); + sb.append(") * "); + sb.append(factor); + } else { // A0 = matrix(seq(1,50000), 1000, 50, byrow=TRUE) * 0.0001 + sb.append(" = matrix(seq(1, "); + sb.append(dims[0]*dims[1]); + sb.append("), "); + sb.append(dims[0]); + sb.append(", "); + sb.append(dims[1]); + + sb.append(", byrow=TRUE) * "); + sb.append(factor); + } + sb.append("\n"); + } + sb.append("\n"); + + sb.append("R = einsum(\""); + sb.append(config.einsumStr); + sb.append("\", "); + + for (int i = 0; i < config.shapes.size()-1; i++) { + sb.append("A"); + sb.append(i); + sb.append(", "); + } + sb.append("A"); + sb.append(config.shapes.size()-1); + sb.append(")"); + + sb.append("\n\n"); + if(outputScalar){ + sb.append("write(R, paste(args[2], \"S\", sep=\"\"))\n"); + }else{ + sb.append("writeMM(as(R, \"CsparseMatrix\"), paste(args[2], \"S\", sep=\"\"))\n"); + } + return sb; + } + + @Test + public void testEinsumWithFiles() { + System.out.println("Testing einsum: " + this.einsumStr); + testCodegenIntegration(TEST_NAME_EINSUM+this.id); + } + @After + public void cleanUp() { + if (this.dmlFile.exists()) { + boolean deleted = this.dmlFile.delete(); + if (!deleted) { + System.err.println("Failed to delete temp file: " + this.dmlFile.getAbsolutePath()); + } + } + if (this.rFile.exists()) { + boolean deleted = this.rFile.delete(); + if (!deleted) { + System.err.println("Failed to delete temp file: " + this.rFile.getAbsolutePath()); + } + } + } + + private static class Config { + public List factors; + String einsumStr; + List shapes; + + Config(String einsum, List shapes) { + this.einsumStr = einsum; + this.shapes = shapes; + this.factors = null; + } + Config(String einsum, List shapes, List factors) { + this.einsumStr = einsum; + this.shapes = shapes; + this.factors = factors; + } + } + + private static int[] shape(int... dims) { + return dims; + } + private static final Log LOG = LogFactory.getLog(EinsumTest.class.getName()); + + private static final String TEST_NAME_EINSUM = "einsum"; + private static final String TEST_DIR = "functions/einsum/"; + private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; + private final static String TEST_CONF = "SystemDS-config-codegen.xml"; + private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF); + + private static double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + for(int i = 1; i<= TEST_CONFIGS.size(); i++) + addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); + } + + private void testCodegenIntegration( String testname) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + ExecMode platformOld = setExecMode(ExecType.CP); + + String testnameDml = this.dmlFile.getAbsolutePath(); + String testnameR = this.rFile.getAbsolutePath(); + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = testnameDml; + programArgs = new String[]{"-stats", "-explain", "-args", output("S") }; + + fullRScriptName = testnameR; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false; + + runTest(true, false, null, -1); + runRScript(true); + + if(outputScalar){ + HashMap dmlfile = readDMLScalarFromOutputDir("S"); + HashMap rfile = readRScalarFromExpectedDir("S"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + }else { + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("S"); + HashMap rfile = readRMatrixFromExpectedDir("S"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + } + } + finally { + resetExecMode(platformOld); + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true; + OptimizerUtils.ALLOW_OPERATOR_FUSION = true; + } + } + + + /** + * Override default configuration with custom test configuration to ensure + * scratch space and local temporary directory locations are also updated. + */ + @Override + protected File getConfigTemplateFile() { + // Instrumentation in this test's output log to show custom configuration file used for template. + LOG.debug("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); + return TEST_CONF_FILE; + } +} diff --git a/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml b/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml new file mode 100644 index 00000000000..626b31ebd76 --- /dev/null +++ b/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml @@ -0,0 +1,31 @@ + + + + /tmp/systemds + scratch_space + 2 + true + 1 + + + 16 + + auto + \ No newline at end of file diff --git a/src/test/scripts/installDependencies.R b/src/test/scripts/installDependencies.R index af89f2b936e..60642fa8ed4 100644 --- a/src/test/scripts/installDependencies.R +++ b/src/test/scripts/installDependencies.R @@ -64,6 +64,7 @@ custom_install("unbalanced"); custom_install("naivebayes"); custom_install("BiocManager"); custom_install("mltools"); +custom_install("einsum"); BiocManager::install("rhdf5"); print("Installation Done")