From ad3ec9d1b387bcfd80a9f126e516562397b933a4 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sun, 18 May 2025 21:55:32 +0200 Subject: [PATCH 01/28] initial working version --- .../org/apache/sysds/common/Builtins.java | 1 + .../java/org/apache/sysds/common/Opcodes.java | 1 + .../java/org/apache/sysds/common/Types.java | 4 +- .../java/org/apache/sysds/hops/NaryOp.java | 5 + src/main/java/org/apache/sysds/lops/Nary.java | 1 + .../parser/BuiltinFunctionExpression.java | 141 +++++ .../apache/sysds/parser/DMLTranslator.java | 5 +- .../instructions/CPInstructionParser.java | 4 +- .../instructions/cp/CPInstruction.java | 2 +- .../instructions/cp/EinsumCPInstruction.java | 482 ++++++++++++++++++ .../instructions/cp/EinsumContext.java | 201 ++++++++ .../test/functions/einsum/EinsumTest.java | 139 +++++ .../einsum/SystemDS-config-codegen.xml | 31 ++ src/test/scripts/functions/einsum/einsum1.R | 34 ++ src/test/scripts/functions/einsum/einsum1.dml | 30 ++ src/test/scripts/functions/einsum/einsum2.R | 34 ++ src/test/scripts/functions/einsum/einsum2.dml | 30 ++ src/test/scripts/functions/einsum/einsum3.R | 34 ++ src/test/scripts/functions/einsum/einsum3.dml | 30 ++ src/test/scripts/functions/einsum/einsum4.R | 34 ++ src/test/scripts/functions/einsum/einsum4.dml | 29 ++ src/test/scripts/functions/einsum/einsum5.R | 34 ++ src/test/scripts/functions/einsum/einsum5.dml | 30 ++ src/test/scripts/functions/einsum/einsum6.R | 34 ++ src/test/scripts/functions/einsum/einsum6.dml | 30 ++ src/test/scripts/functions/einsum/einsum7.R | 35 ++ src/test/scripts/functions/einsum/einsum7.dml | 28 + 27 files changed, 1458 insertions(+), 5 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java create mode 100644 src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java create mode 100644 src/test/scripts/functions/einsum/SystemDS-config-codegen.xml create mode 100644 src/test/scripts/functions/einsum/einsum1.R create mode 100644 src/test/scripts/functions/einsum/einsum1.dml create mode 100644 src/test/scripts/functions/einsum/einsum2.R create mode 100644 src/test/scripts/functions/einsum/einsum2.dml create mode 100644 src/test/scripts/functions/einsum/einsum3.R create mode 100644 src/test/scripts/functions/einsum/einsum3.dml create mode 100644 src/test/scripts/functions/einsum/einsum4.R create mode 100644 src/test/scripts/functions/einsum/einsum4.dml create mode 100644 src/test/scripts/functions/einsum/einsum5.R create mode 100644 src/test/scripts/functions/einsum/einsum5.dml create mode 100644 src/test/scripts/functions/einsum/einsum6.R create mode 100644 src/test/scripts/functions/einsum/einsum6.dml create mode 100644 src/test/scripts/functions/einsum/einsum7.R create mode 100644 src/test/scripts/functions/einsum/einsum7.dml diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 4ff5654de02..ba2fad7c17b 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -398,6 +398,7 @@ public enum Builtins { UNDER_SAMPLING("underSampling", true), UNIQUE("unique", false, true), UPPER_TRI("upper.tri", false, true), + EINSUM("einsum", false, false), XDUMMY1("xdummy1", true), //error handling test XDUMMY2("xdummy2", true); //error handling test diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index a878d3f0ace..d3dca547b87 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -200,6 +200,7 @@ public enum Opcodes { TRANSFORMCOLMAP("transformcolmap", CPType.ParameterizedBuiltin), TRANSFORMMETA("transformmeta", CPType.ParameterizedBuiltin), TRANSFORMENCODE("transformencode", CPType.MultiReturnParameterizedBuiltin), + EINSUM("einsum", CPType.EINSUM), //Ternary instruction opcodes PM("+*", CPType.Ternary), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index c9820a2c092..d8f47a9c41a 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -761,8 +761,8 @@ public String toString() { /** Operations that require a variable number of operands*/ public enum OpOpN { - PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST; - + PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST, EINSUM; + public boolean isCellOp() { return this == MIN || this == MAX || this == PLUS || this == MULT; } diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java index 1659b0dbc5e..438cf638515 100644 --- a/src/main/java/org/apache/sysds/hops/NaryOp.java +++ b/src/main/java/org/apache/sysds/hops/NaryOp.java @@ -235,6 +235,11 @@ public void refreshSizeInformation() { setDim1(getInput().size()); setDim2(1); break; + case EINSUM: + setDataType(DataType.MATRIX); + setDim1(getInput().size()); + setDim2(1); + break; case PRINTF: case EVAL: //do nothing: diff --git a/src/main/java/org/apache/sysds/lops/Nary.java b/src/main/java/org/apache/sysds/lops/Nary.java index e073bc68817..e5382ba0330 100644 --- a/src/main/java/org/apache/sysds/lops/Nary.java +++ b/src/main/java/org/apache/sysds/lops/Nary.java @@ -111,6 +111,7 @@ private String getOpcode() { case RBIND: case EVAL: case LIST: + case EINSUM: return operationType.name().toLowerCase(); case MIN: case MAX: diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 6a68f867f90..81294aae420 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -23,6 +23,8 @@ import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedList; +import java.util.Iterator; import org.antlr.v4.runtime.ParserRuleContext; import org.apache.commons.lang3.ArrayUtils; @@ -751,7 +753,10 @@ else if(((ConstIdentifier) getThirdExpr().getOutput()) else raiseValidateError("Compress/DeCompress instruction not allowed in dml script"); break; + case EINSUM: + validateEinsum((DataIdentifier) getOutputs()[0]); + break; default: //always unconditional raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false); } @@ -2032,7 +2037,10 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV output.setValueType(ValueType.INT64); output.setNnz(id.getDim2()); break; + case EINSUM: + validateEinsum(output); + break; default: if( isMathFunction() ) { checkMathFunctionParam(); @@ -2065,6 +2073,139 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV } } + private void validateEinsum(DataIdentifier output){ + if(getSecondExpr() == null) + raiseValidateError("Einsum: at least one input matrix required", false, + LanguageErrorCodes.INVALID_PARAMETERS); + + if(!(getFirstExpr() instanceof StringIdentifier)) + raiseValidateError("Einsum: first argument has to be equation str", false, + LanguageErrorCodes.INVALID_PARAMETERS); + + String eq_string = ((StringIdentifier)getFirstExpr()).getValue(); + + String[] parts = eq_string.split("->"); + + if(parts.length != 2) + raiseValidateError("Einsum: equation str should contain one '->' substring", false, + LanguageErrorCodes.INVALID_PARAMETERS); + + Expression[] expressions = getAllExpr(); + boolean allDimsKnown = true; + + LinkedList matrixBlocks = new LinkedList(); + for (int i=1;i charToDimensionSize = new HashMap<>(); + + Iterator it = matrixBlocks.iterator(); + Identifier curArr = it.next(); + int arrSizeIterator = 0; + int numberOfMatrices = 1; + for (int i = 0; i numberOfMatrices){ + raiseValidateError("Einsum: Provided more operands than specified in equation str", + false, LanguageErrorCodes.INVALID_PARAMETERS); + } + int numberOfDimensions = 0; + long dim1 = 0; + long dim2 = 0; + for (int i = 0; i2){ + raiseValidateError("Einsum: output matrices with with no. dims > 2 not supported", + false, LanguageErrorCodes.INVALID_PARAMETERS); + }else { + output.setDataType(DataType.MATRIX); + output.setDimensions(dim1, dim2); + } + }else{ + int numberOfMatrices = 1; + for (int i = 0; i < parts[0].length(); i++) { + if(parts[0].charAt(i) == ',') + numberOfMatrices++; + } + checkNumParameters(numberOfMatrices+1); + + int numberOfDimensions = 0; + + for (int i = 0; i2){ + raiseValidateError("Einsum: output matrices with with no. dims > 2 not supported", + false, LanguageErrorCodes.INVALID_PARAMETERS); + }else{ + output.setDataType(DataType.MATRIX); + output.setDimensions(-1, -1); + } + } + output.setValueType(ValueType.FP64); + output.setBlocksize(getSecondExpr().getOutput().getBlocksize()); + } + private void setBinaryOutputProperties(DataIdentifier output) { DataType dt1 = getFirstExpr().getOutput().getDataType(); DataType dt2 = getSecondExpr().getOutput().getDataType(); diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index cd58af62337..ce7adcbaa82 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2446,7 +2446,10 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D new NaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOpN.valueOf(source.getOpCode().name()), processAllExpressions(source.getAllExpr(), hops)); break; - + case EINSUM: + currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(), + OpOpN.valueOf(source.getOpCode().name()), processAllExpressions(source.getAllExpr(), hops)); + break; case PPRED: String sop = ((StringIdentifier)source.getThirdExpr()).getValue(); sop = sop.replace("\"", ""); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index e30183a5067..9a171cbdc71 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -64,6 +64,7 @@ import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction; import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; +import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; import org.apache.sysds.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction; public class CPInstructionParser extends InstructionParser { @@ -214,7 +215,8 @@ public static CPInstruction parseSingleInstruction ( CPType cptype, String str ) case EvictLineageCache: return EvictCPInstruction.parseInstruction(str); - + case EINSUM: + return EinsumCPInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype ); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java index f8527276a7a..ed55529cd6d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java @@ -44,7 +44,7 @@ public enum CPType { Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, Local, MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, Compression, DeCompression, SpoofFused, StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote, - EvictLineageCache, + EvictLineageCache, EINSUM, NoOp, } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java new file mode 100644 index 00000000000..b41b7da09e6 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.codegen.cplan.CNodeCell; +import org.apache.sysds.hops.codegen.cplan.CNodeRow; +import org.apache.sysds.runtime.codegen.*; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; +import org.apache.sysds.runtime.functionobjects.*; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.lineage.LineageCodegenItem; +import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.lineage.LineageItemUtils; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; + +import java.util.*; + +import static org.apache.sysds.runtime.instructions.cp.EinsumContext.getEinsumContext; + +public class EinsumCPInstruction extends ComputationCPInstruction { + + protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); + public String eqStr; + private final Class _class; + private final SpoofOperator _op; + private final int _numThreads; + private final CPOperand[] _in; + + private EinsumCPInstruction(int k, + CPOperand[] in, CPOperand out, String opcode, String str, String eqStr) + { + super(CPType.EINSUM, null, null, null, out, opcode, str); + _class =null; + _op = null; + _numThreads = k; + _in = in; + this.eqStr = eqStr; + } + + public SpoofOperator getSpoofOperator() { + return _op; + } + + public Class getOperatorClass() { + return _class; + } + + private static final int CONTRACT_LEFT = 1; + private static final int CONTRACT_RIGHT = 2; + private static final int CONTRACT_BOTH = 3; + + public static EinsumCPInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + + ArrayList inlist = new ArrayList<>(); + + for (int i = 2; i < parts.length - 1; i++) + { + inlist.add(new CPOperand(parts[i])); + } + + CPOperand out = new CPOperand(parts[parts.length-1]); +// int k = 1;//Integer.parseInt(parts[parts.length-1]); + int k = OptimizerUtils.getConstrainedNumThreads(-1); + + String eqString = new CPOperand(parts[1]).getName(); //todo change + + return new EinsumCPInstruction(k, inlist.toArray(new CPOperand[0]), out, parts[0], str, eqString); + } + + @Override + public void processInstruction(ExecutionContext ec) { + + //get input matrices and scalars, incl pinning of matrices + ArrayList inputs = new ArrayList<>(); + ArrayList scalars = new ArrayList<>(); + if( LOG.isDebugEnabled() ) + LOG.debug("executing spoof instruction " + _op); + for (CPOperand input : _in) { + if(input.getDataType()==DataType.MATRIX){ + MatrixBlock mb = ec.getMatrixInput(input.getName()); + //FIXME fused codegen operators already support compressed main inputs + if(mb instanceof CompressedMatrixBlock){ + mb = ((CompressedMatrixBlock) mb).getUncompressed("Spoof instruction"); + } + inputs.add(mb); + } + else if(input.getDataType()==DataType.SCALAR) { + //note: even if literal, it might be compiled as scalar placeholder + scalars.add(ec.getScalarInput(input)); + } + } + + EinsumContext einc = getEinsumContext(eqStr,inputs); + + + String[] parts = einc.equationString.split("->"); + String[] inputsChars = parts[0].split(","); + +// System.out.println("outrows:"+einc.outRows); +// System.out.println("outcols:"+einc.outCols); + + //todo move to separate op earlier: + for(int i=0;i... + // outer tmpl + CNodeRow cnode = new CNodeRow(new ArrayList<>(), null); +// cnode.setConstDim2(einc.outCols); +// cnode.setNumVectorIntermediates(1); + String src = tmpRow; + + if(einc.outCols == 1){ +// cnode.setRowType(SpoofRowwise.RowType.ROW_AGG); + src = src.replace("%TYPE%","ROW_AGG"); + + }else { +// cnode.setRowType(SpoofRowwise.RowType.COL_AGG_B1_T); + src = src.replace("%TYPE%","COL_AGG_B1_T"); + + } + src = src.replace("%TMP%", cnode.createVarname()); + +// String src= cnode.codegenEinsum(false, SpoofCompiler.GeneratorAPI.JAVA); + src = src.replace("%CONST_DIM2%",einc.outCols.toString());// super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + + +// System.out.println(src); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); +// Class cla = CodegenUtils.compileClass("codegen.TMP0", src); + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); + mb.reset(einc.outRows, einc.outCols, false); + mb.allocateDenseBlock(); + if(LibSpoofPrimitives.isFlipOuter(einc.outRows,einc.outCols)){ +// System.out.println("swapping"); + ArrayList m2 = new ArrayList(2); + + m2.add(inputs.get(1)); + m2.add(inputs.get(0)); + MatrixBlock out =op.execute(m2,scalars,mb,_numThreads); + ec.setMatrixOutput(output.getName(), out); + }else{ +// System.out.println("NOTswapping"); + MatrixBlock out =op.execute(inputs,scalars,mb,_numThreads); + ec.setMatrixOutput(output.getName(), out); + } + + } + else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].charAt(0)){ + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier + MatrixBlock first = (inputs.get(0)).reorgOperations(transpose, new MatrixBlock(), 0 ,0, 0); + + CNodeRow cnode = new CNodeRow(new ArrayList<>(), null); + String src = tmpRow; + + if(einc.outCols == 1){ + src = src.replace("%TYPE%","ROW_AGG"); + + }else { + src = src.replace("%TYPE%","COL_AGG_B1_T"); + + } + src = src.replace("%TMP%", cnode.createVarname()); + + src = src.replace("%CONST_DIM2%",einc.outCols.toString());// super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + + +// System.out.println(src); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); + mb.reset(einc.outRows, einc.outCols, false); + + mb.allocateDenseBlock(); + if(LibSpoofPrimitives.isFlipOuter(einc.outRows,einc.outCols)){ + ArrayList m2 = new ArrayList(2); + + m2.add(inputs.get(1)); + m2.add(first); + MatrixBlock out =op.execute(m2,scalars,mb,_numThreads); + ec.setMatrixOutput(output.getName(), out); + }else{ + ArrayList m2 = new ArrayList(2); + m2.add(first); + m2.add(inputs.get(1)); + MatrixBlock out =op.execute(m2,scalars,mb,_numThreads); + ec.setMatrixOutput(output.getName(), out); + } + } + else{ //fallback to cell + CNodeCell cnode = new CNodeCell(new ArrayList<>(), null); +// cnode.setCellType(SpoofCellwise.CellType.NO_AGG); + StringBuilder sb = new StringBuilder(); + + String outputChars = parts[1]; + if( outputChars.length()==2){ + einc.summingChars.remove( outputChars.charAt(0)); + einc.summingChars.remove( outputChars.charAt(1)); + + }else if( outputChars.length()==1){ + einc.summingChars.remove( outputChars.charAt(0)); + + } + boolean needsSumming = einc.summingChars.stream().anyMatch(x->x != null); + String itVar0 = "TMP123";//+new IDSequence().getNextID(); todo: generate this var + String outVar = null; + if(needsSumming){ + outVar = "TMP123";//+new IDSequence().getNextID(); + sb.append("double "); + sb.append(outVar); + sb.append("=0;\n"); + } + + HashSet summedCharacters = new HashSet<>(); + Iterator hsIt = einc.summingChars.iterator(); + while (hsIt.hasNext()) { + + Character c = hsIt.next(); + String itVar = itVar0+c; + sb.append("for(int "); + sb.append(itVar); + sb.append("=0;"); + sb.append(itVar); + sb.append("<"); + sb.append(einc.charToDimensionSizeInt.get(c)); + sb.append(";"); + sb.append(itVar); + sb.append("++){\n"); + } + if (needsSumming){ + sb.append(outVar); + sb.append("+="); + } + + if(parts[1].length()==2){ + for (int i=0;i< inputsChars.length;i++){ + if(einc.summingChars.contains(inputsChars[i].charAt(0))){ + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen,"); + sb.append(itVar0); + sb.append(inputsChars[i].charAt(0)); + sb.append(","); + }else if(inputsChars[i].charAt(0)==outputChars.charAt(0)) { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, rix,"); + }else if(inputsChars[i].charAt(0)==outputChars.charAt(1)) { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, cix,"); + }else { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, 0,"); + } + + if(einc.summingChars.contains(inputsChars[i].charAt(1))) { + sb.append(itVar0); + sb.append(inputsChars[i].charAt(1)); + sb.append(")"); + } + else if(inputsChars[i].charAt(1)==outputChars.charAt(0)){ + sb.append("rix)"); + }else if(inputsChars[i].charAt(1)==outputChars.charAt(1)){ + sb.append("cix)"); + + }else { + sb.append("0)"); + } + + + if(i getLineageItem(ExecutionContext ec) { + //return the lineage item if already traced once + LineageItem li = ec.getLineage().get(output.getName()); + if (li != null) + return Pair.of(output.getName(), li); + + //read and deepcopy the corresponding lineage DAG (pre-codegen) + LineageItem LIroot = LineageCodegenItem.getCodegenLTrace(getOperatorClass().getName()).deepCopy(); + + //replace the placeholders with original instruction inputs. + LineageItemUtils.replaceDagLeaves(ec, LIroot, _in); + + return Pair.of(output.getName(), LIroot); + } + + public CPOperand[] getInputs() { + return _in; + } + + private static final IDSequence _idSeqfn = new IDSequence(); + + private final static String tmpRow = "package codegen;\n" + + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" + + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" + + "import org.apache.sysds.runtime.codegen.SpoofRowwise;\n" + + "import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;\n" + + "import org.apache.commons.math3.util.FastMath;\n" + + "\n" + + "public final class %TMP% extends SpoofRowwise { \n" + + " public %TMP%() {\n" + + " super(RowType.%TYPE%, %CONST_DIM2%, false, 1);\n" + + " }\n" + + " protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { \n" + + "LibSpoofPrimitives.vectOuterMultAdd(a, b[0].values(rix), c, ai, b[0].pos(rix), 0, len, b[0].clen); }\n" + + " protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { \n" + + " }\n" + + "}\n"; + + private final static String tmpCell = + "package codegen;\n" + + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" + + "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n" + + "import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\n" + + "import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;\n" + + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" + + "import org.apache.commons.math3.util.FastMath;\n" + + "public final class %TMP% extends SpoofCellwise {\n" + + " public %TMP%() {\n" + + " super(CellType.NO_AGG, false, true, null);\n" + + " }\n" + + " protected double genexec(double a, SideInput[] b, double[] scalars, int m, int n, long grix, int rix, int cix) { \n" + + " %BODY_dense%" + + " return %OUT%;\n" + + " }\n" + + "}"; +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java new file mode 100644 index 00000000000..d8954b91884 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; + +public class EinsumContext { + public Integer outRows; + public Integer outCols; + public HashMap charToDimensionSizeInt; + public String equationString; + public Integer[] contractDims; + public Integer[] summingDims; + public HashSet summingChars; + + public static EinsumContext getEinsumContext(String eqStr, ArrayList inputs){ + EinsumContext res = new EinsumContext(); + + res.equationString = eqStr; + int i = 0; + res.charToDimensionSizeInt = new HashMap(); + Iterator it = inputs.iterator(); + MatrixBlock curArr = it.next(); + int arrSizeIterator=0; + HashSet summingChars = new HashSet<>(); + Integer[] contractDims = new Integer[inputs.size()];//0==nothing, 1 = right, 2=left, 3 = both + Integer[] summingDims = new Integer[inputs.size()];//0/null==nothing, 1 = right, 2=left, 3 = both + int arrIt = 0; + for (i = 0; true; i++){ + char c = eqStr.charAt(i); + if(c=='-'){ + i+=2; + break; + } + if(c==','){ + arrIt++; + curArr = it.next(); + arrSizeIterator = 0; + } + + else{ + if (res.charToDimensionSizeInt.containsKey(c)){ + // just check if dims match! + if(arrSizeIterator==0) + assert (res.charToDimensionSizeInt.get(c) == curArr.getNumRows()); + else if(arrSizeIterator==1) + assert (res.charToDimensionSizeInt.get(c) == curArr.getNumColumns()); + + summingChars.add(c); + + }else{ + if(arrSizeIterator==0) + res.charToDimensionSizeInt.put(c, curArr.getNumRows()); + else if(arrSizeIterator==1) + res.charToDimensionSizeInt.put(c, curArr.getNumColumns()); + } + arrSizeIterator++; + } + + //Process char + } + int rem = eqStr.length() - i; + arrSizeIterator = 0; + if (rem ==0){ + res.outRows=1; + res.outCols=1; + + arrIt=0; + for (i = 0; true; i++) { + char c = eqStr.charAt(i); + if (c == '-') { + break; + } + if (c == ',') { + arrIt++; + arrSizeIterator = 0; + continue; + } + + if(summingChars.contains(c)){ + + }else{ + if(contractDims[arrIt]==null){ + contractDims[arrIt]=arrSizeIterator +1; + + }else { + contractDims[arrIt] += arrSizeIterator + 1; + } + } + arrSizeIterator++; + + } + }else if (rem == 1){ + char c1= eqStr.charAt(i); + res.outRows=(res.charToDimensionSizeInt.get(c1)); + + res.outCols=1; + arrIt=0; + for (i = 0; true; i++) { + char c = eqStr.charAt(i); + if (c == '-') { + break; + } + if (c == ',') { + arrIt++; + arrSizeIterator = 0; + continue; + } + + if(summingChars.contains(c)){ + if(summingDims[arrIt] == null){ + summingDims[arrIt]=arrSizeIterator +1; // it=0->add 1, it==1->add 2 + }else{ + summingDims[arrIt]+=arrSizeIterator +1; // it=0->add 1, it==1->add 2 + + } + }else if(c==c1){ + // this dim is remaining + }else{ + if(contractDims[arrIt]==null){ + contractDims[arrIt]=arrSizeIterator +1; + + }else { + contractDims[arrIt] += arrSizeIterator + 1; + } + + } + arrSizeIterator++; + + } + }else if (rem==2){ + char c1= eqStr.charAt(i); + char c2= eqStr.charAt(i+1); + res.outRows=(res.charToDimensionSizeInt.get(c1)); + res.outCols=(res.charToDimensionSizeInt.get(c2)); + + arrIt=0; + for (i = 0; true; i++) { + char c = eqStr.charAt(i); + if (c == '-') { + break; + } + if (c == ',') { + arrIt++; + arrSizeIterator = 0; + continue; + + } + + if(summingChars.contains(c)){ + if(summingDims[arrIt] == null){ + summingDims[arrIt]=arrSizeIterator +1; // it=0->add 1, it==1->add 2 + }else{ + summingDims[arrIt]+=arrSizeIterator +1; // it=0->add 1, it==1->add 2 + + } + }else if(c==c1 || c==c2){ + // this dim is remaining + }else{ + if(contractDims[arrIt]==null){ + contractDims[arrIt]=arrSizeIterator +1; + + }else { + contractDims[arrIt] += arrSizeIterator + 1; + } + } + arrSizeIterator++; + + } + }else{ + throw new RuntimeException("output dim > 2 not supported for now"); + } + res.contractDims=contractDims; + res.summingDims=summingDims; + + res.summingChars = summingChars; + return res; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java new file mode 100644 index 00000000000..471aa47e4b8 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.einsum; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +import java.io.File; +import java.util.HashMap; + +public class EinsumTest extends AutomatedTestBase +{ + private static final Log LOG = LogFactory.getLog(EinsumTest.class.getName()); + + private static final String TEST_NAME_EINSUM = "einsum"; + private static final String TEST_EINSUM1 = TEST_NAME_EINSUM+"1"; + private static final String TEST_EINSUM2 = TEST_NAME_EINSUM+"2"; + private static final String TEST_EINSUM3 = TEST_NAME_EINSUM+"3"; + private static final String TEST_EINSUM4 = TEST_NAME_EINSUM+"4"; + private static final String TEST_EINSUM5 = TEST_NAME_EINSUM+"5"; + private static final String TEST_EINSUM6 = TEST_NAME_EINSUM+"6"; + private static final String TEST_EINSUM7 = TEST_NAME_EINSUM+"7"; + + private static final String TEST_DIR = "functions/einsum/"; + private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; + private final static String TEST_CONF = "SystemDS-config-codegen.xml"; + private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF); + + private static double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + for(int i=1; i<=7; i++) + addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); + } + @Test + public void testCodegenEinsum1CP() { + testCodegenIntegration( TEST_EINSUM1, false, ExecType.CP ); + } + @Test + public void testCodegenEinsum2CP() { + testCodegenIntegration( TEST_EINSUM2, false, ExecType.CP ); + } + @Test + public void testCodegenEinsum3CP() { + testCodegenIntegration( TEST_EINSUM3, false, ExecType.CP ); +} + @Test + public void testCodegenEinsum4CP() { + testCodegenIntegration( TEST_EINSUM4, false, ExecType.CP ); + } + @Test + public void testCodegenEinsum5CP() { + testCodegenIntegration( TEST_EINSUM5, false, ExecType.CP ); + } + @Test + public void testCodegenEinsum6CP() { + testCodegenIntegration( TEST_EINSUM6, false, ExecType.CP ); + } + @Test + public void testCodegenEinsum7CP() { + testCodegenIntegration( TEST_EINSUM7, false, ExecType.CP ); + } + + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + ExecMode platformOld = setExecMode(instType); + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{"-stats", "-explain", "-args", output("S") }; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("S"); + HashMap rfile = readRMatrixFromExpectedDir("S"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + } + finally { + resetExecMode(platformOld); + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true; + OptimizerUtils.ALLOW_OPERATOR_FUSION = true; + } + } + + /** + * Override default configuration with custom test configuration to ensure + * scratch space and local temporary directory locations are also updated. + */ + @Override + protected File getConfigTemplateFile() { + // Instrumentation in this test's output log to show custom configuration file used for template. + LOG.debug("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); + return TEST_CONF_FILE; + } +} diff --git a/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml b/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml new file mode 100644 index 00000000000..626b31ebd76 --- /dev/null +++ b/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml @@ -0,0 +1,31 @@ + + + + /tmp/systemds + scratch_space + 2 + true + 1 + + + 16 + + auto + \ No newline at end of file diff --git a/src/test/scripts/functions/einsum/einsum1.R b/src/test/scripts/functions/einsum/einsum1.R new file mode 100644 index 00000000000..98e52486f95 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum1.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +P = matrix(seq(1,3000), 600, 5, byrow=TRUE); +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); + +# R = t(P) %*% X; +R = einsum("ji,jz->iz",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum1.dml b/src/test/scripts/functions/einsum/einsum1.dml new file mode 100644 index 00000000000..1523339e21c --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum1.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +P = matrix(seq(1,3000), 600,5); +X = matrix(seq(1,6000), 600, 10) + +while(FALSE){} + +#R = t(P) %*% X ; + +R = einsum("ji,jz->iz",P,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum2.R b/src/test/scripts/functions/einsum/einsum2.R new file mode 100644 index 00000000000..15be5c772c2 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum2.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +P = matrix(seq(1,3000), 5, 600, byrow=TRUE); +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); + +# R = P %*% X; +R = einsum("ij,jz->iz",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum2.dml b/src/test/scripts/functions/einsum/einsum2.dml new file mode 100644 index 00000000000..eb47e19c807 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum2.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +P = matrix(seq(1,3000), 5, 600) +X = matrix(seq(1,6000), 600, 10); + +while(FALSE){} + +#R = P %*% X ; +R = einsum("ij,jz->iz",P,X) + +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum3.R b/src/test/scripts/functions/einsum/einsum3.R new file mode 100644 index 00000000000..0a802dc6fde --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum3.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); +P = matrix(seq(1,3000), 600, 5, byrow=TRUE); + +# R = sum(t(P) %*% X); +R = einsum("ji,jz->i",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum3.dml b/src/test/scripts/functions/einsum/einsum3.dml new file mode 100644 index 00000000000..9e8c96939f2 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum3.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +P = matrix(seq(1,3000), 600, 5) +X = matrix(seq(1,6000), 600, 10); + +while(FALSE){} + + +#R = sum(t(P) %*% X) ; + +R = einsum("ji,jz->i",P,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum4.R b/src/test/scripts/functions/einsum/einsum4.R new file mode 100644 index 00000000000..bb9caf31e57 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum4.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); +P = matrix(seq(1,3000), 600, 5, byrow=TRUE); + +# R = colSums(t(P) %*% X); +R = einsum("ji,jz->z",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum4.dml b/src/test/scripts/functions/einsum/einsum4.dml new file mode 100644 index 00000000000..84c3efbdfae --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum4.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +P = matrix(seq(1,3000), 600, 5) +X = matrix(seq(1,6000), 600, 10); + +while(FALSE){} + +#R = colSums(t(P) %*% X) ; + +R = einsum("ji,jz->z",P,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum5.R b/src/test/scripts/functions/einsum/einsum5.R new file mode 100644 index 00000000000..902941b9222 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum5.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); +P = matrix(seq(1,3000), 600, 5, byrow=TRUE); + +# R = rowSums(P) * rowSums(X) +R = einsum("ji,jz->j",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum5.dml b/src/test/scripts/functions/einsum/einsum5.dml new file mode 100644 index 00000000000..de5fb654e20 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum5.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,6000), 600, 10); +P = matrix(seq(1,3000), 600, 5) + +while(FALSE){} + +#R = colSums(t(P) %*% X) ; + +R = einsum("ji,jz->j",P,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum6.R b/src/test/scripts/functions/einsum/einsum6.R new file mode 100644 index 00000000000..0aa46d92a64 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum6.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); +P = matrix(seq(1,30), 6, 5, byrow=TRUE); + +# R = P * sum(X) +R = einsum("ab,cd->ab",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum6.dml b/src/test/scripts/functions/einsum/einsum6.dml new file mode 100644 index 00000000000..1385d0329eb --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum6.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,6000), 600, 10); +P = matrix(seq(1,30), 6, 5) + +while(FALSE){} + +#R = colSums(t(P) %*% X) ; + +R = einsum("ab,cd->ab",P,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum7.R b/src/test/scripts/functions/einsum/einsum7.R new file mode 100644 index 00000000000..f0075f86dc9 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum7.R @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +P = matrix(seq(1,3000), 600, 5, byrow=TRUE); +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); +Z = matrix(seq(1,20), 10, 2, byrow=TRUE); + +# R = t(P) %*% X; +# RR= R %*% Z +RR = einsum("ji,jz,zx->ix",P,X,Z) +writeMM(as(RR, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum7.dml b/src/test/scripts/functions/einsum/einsum7.dml new file mode 100644 index 00000000000..5a1b889ce29 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum7.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +P = matrix(seq(1,3000), 600, 5) +X = matrix(seq(1,6000), 600, 10); +Z = matrix(seq(1,20), 10, 2) + +while(FALSE){} + +R = einsum("ji,jz,zx->ix",P,X,Z) +write(R, $1) From bc69ee34fc619c412cbe6a93658e7d79de56eef2 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sun, 18 May 2025 22:02:57 +0200 Subject: [PATCH 02/28] add einsum R dependency --- src/test/scripts/installDependencies.R | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/scripts/installDependencies.R b/src/test/scripts/installDependencies.R index af89f2b936e..60642fa8ed4 100644 --- a/src/test/scripts/installDependencies.R +++ b/src/test/scripts/installDependencies.R @@ -64,6 +64,7 @@ custom_install("unbalanced"); custom_install("naivebayes"); custom_install("BiocManager"); custom_install("mltools"); +custom_install("einsum"); BiocManager::install("rhdf5"); print("Installation Done") From 8db99c9b4e3d603a2074039cc1531f9ec6f7fb6b Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Mon, 19 May 2025 13:12:09 +0200 Subject: [PATCH 03/28] fix merge --- .../apache/sysds/common/InstructionType.java | 1 + .../java/org/apache/sysds/common/Opcodes.java | 2 +- .../parser/BuiltinFunctionExpression.java | 8 +- .../instructions/CPInstructionParser.java | 2 - .../cp/BuiltinNaryCPInstruction.java | 5 +- .../instructions/cp/EinsumCPInstruction.java | 91 ++++++++++--------- 6 files changed, 56 insertions(+), 53 deletions(-) diff --git a/src/main/java/org/apache/sysds/common/InstructionType.java b/src/main/java/org/apache/sysds/common/InstructionType.java index 4dba1c5be09..3f61b74fb64 100644 --- a/src/main/java/org/apache/sysds/common/InstructionType.java +++ b/src/main/java/org/apache/sysds/common/InstructionType.java @@ -61,6 +61,7 @@ public enum InstructionType { MMTSJ, PMMJ, MMChain, + EINSUM, //SP Types MAPMM, diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 4cd2a658d6c..7b89689d117 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -172,7 +172,7 @@ public enum Opcodes { RBIND("rbind", InstructionType.BuiltinNary), EVAL("eval", InstructionType.BuiltinNary), LIST("list", InstructionType.BuiltinNary), - EINSUM("einsum", CPType.EINSUM), + EINSUM("einsum", InstructionType.BuiltinNary), //Parametrized builtin functions AUTODIFF("autoDiff", InstructionType.ParameterizedBuiltin), CONTAINS("contains", InstructionType.ParameterizedBuiltin), diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 5694caff308..aba7e4312fe 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -750,13 +750,13 @@ else if(((ConstIdentifier) getThirdExpr().getOutput()) metaOutput.setDataType(DataType.FRAME); metaOutput.setDimensions(compressInput1.getDim1(), -1); } - case EINSUM: - validateEinsum((DataIdentifier) getOutputs()[0]); - - break; else raiseValidateError("Compress/DeCompress instruction not allowed in dml script"); break; + case EINSUM: + validateEinsum((DataIdentifier) getOutputs()[0]); + + break; default: //always unconditional raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index 8ae35782ef0..0f6c5e9f57e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -219,8 +219,6 @@ public static CPInstruction parseSingleInstruction ( InstructionType cptype, Str case EvictLineageCache: return EvictCPInstruction.parseInstruction(str); - case EINSUM: - return EinsumCPInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype ); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java index 6d230a30f08..e7aa1b5fd76 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java @@ -91,10 +91,13 @@ else if( opcode.equals(Opcodes.NM.toString()) ) { return new MatrixBuiltinNaryCPInstruction( new SimpleOperator(Multiply.getMultiplyFnObject()), opcode, str, outputOperand, inputOperands); } + else if( opcode.equals(Opcodes.EINSUM.toString()) ) { + return new EinsumCPInstruction(null, opcode, str, outputOperand, inputOperands); + } else if (OpOpN.EVAL.name().equalsIgnoreCase(opcode)) { return new EvalNaryCPInstruction(null, opcode, str, outputOperand, inputOperands); } - + throw new DMLRuntimeException("Opcode (" + opcode + ") not recognized in BuiltinMultipleCPInstruction"); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index b41b7da09e6..014b1dcd02c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -38,13 +38,14 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; import java.util.*; import static org.apache.sysds.runtime.instructions.cp.EinsumContext.getEinsumContext; -public class EinsumCPInstruction extends ComputationCPInstruction { +public class EinsumCPInstruction extends BuiltinNaryCPInstruction { protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; @@ -53,17 +54,17 @@ public class EinsumCPInstruction extends ComputationCPInstruction { private final int _numThreads; private final CPOperand[] _in; - private EinsumCPInstruction(int k, - CPOperand[] in, CPOperand out, String opcode, String str, String eqStr) + public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand out, CPOperand... inputs) { - super(CPType.EINSUM, null, null, null, out, opcode, str); - _class =null; + super(op, opcode, istr, out, inputs); + _class = null; _op = null; - _numThreads = k; - _in = in; - this.eqStr = eqStr; + _numThreads = OptimizerUtils.getConstrainedNumThreads(-1); + _in = inputs; + this.eqStr = inputs[0].getName(); } + public SpoofOperator getSpoofOperator() { return _op; } @@ -76,24 +77,24 @@ public Class getOperatorClass() { private static final int CONTRACT_RIGHT = 2; private static final int CONTRACT_BOTH = 3; - public static EinsumCPInstruction parseInstruction(String str) { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - - ArrayList inlist = new ArrayList<>(); - - for (int i = 2; i < parts.length - 1; i++) - { - inlist.add(new CPOperand(parts[i])); - } - - CPOperand out = new CPOperand(parts[parts.length-1]); -// int k = 1;//Integer.parseInt(parts[parts.length-1]); - int k = OptimizerUtils.getConstrainedNumThreads(-1); - - String eqString = new CPOperand(parts[1]).getName(); //todo change - - return new EinsumCPInstruction(k, inlist.toArray(new CPOperand[0]), out, parts[0], str, eqString); - } +// public static EinsumCPInstruction parseInstruction(String str) { +// String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); +// +// ArrayList inlist = new ArrayList<>(); +// +// for (int i = 2; i < parts.length - 1; i++) +// { +// inlist.add(new CPOperand(parts[i])); +// } +// +// CPOperand out = new CPOperand(parts[parts.length-1]); +//// int k = 1;//Integer.parseInt(parts[parts.length-1]); +// int k = OptimizerUtils.getConstrainedNumThreads(-1); +// +// String eqString = new CPOperand(parts[1]).getName(); //todo change +// +// return new EinsumCPInstruction(k, inlist.toArray(new CPOperand[0]), out, parts[0], str, eqString); +// } @Override public void processInstruction(ExecutionContext ec) { @@ -112,10 +113,10 @@ public void processInstruction(ExecutionContext ec) { } inputs.add(mb); } - else if(input.getDataType()==DataType.SCALAR) { - //note: even if literal, it might be compiled as scalar placeholder - scalars.add(ec.getScalarInput(input)); - } +// else if(input.getDataType()==DataType.SCALAR) { +// //note: even if literal, it might be compiled as scalar placeholder +// scalars.add(ec.getScalarInput(input)); +// } } EinsumContext einc = getEinsumContext(eqStr,inputs); @@ -423,21 +424,21 @@ else if(inputsChars[i].charAt(1)==outputChars.charAt(0)){ ec.releaseMatrixInput(input.getName()); } - @Override - public Pair getLineageItem(ExecutionContext ec) { - //return the lineage item if already traced once - LineageItem li = ec.getLineage().get(output.getName()); - if (li != null) - return Pair.of(output.getName(), li); - - //read and deepcopy the corresponding lineage DAG (pre-codegen) - LineageItem LIroot = LineageCodegenItem.getCodegenLTrace(getOperatorClass().getName()).deepCopy(); - - //replace the placeholders with original instruction inputs. - LineageItemUtils.replaceDagLeaves(ec, LIroot, _in); - - return Pair.of(output.getName(), LIroot); - } +// @Override +// public Pair getLineageItem(ExecutionContext ec) { +// //return the lineage item if already traced once +// LineageItem li = ec.getLineage().get(output.getName()); +// if (li != null) +// return Pair.of(output.getName(), li); +// +// //read and deepcopy the corresponding lineage DAG (pre-codegen) +// LineageItem LIroot = LineageCodegenItem.getCodegenLTrace(getOperatorClass().getName()).deepCopy(); +// +// //replace the placeholders with original instruction inputs. +// LineageItemUtils.replaceDagLeaves(ec, LIroot, _in); +// +// return Pair.of(output.getName(), LIroot); +// } public CPOperand[] getInputs() { return _in; From c9c9947338f6b51e82e72a04de5d7adbb567dc67 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Mon, 19 May 2025 13:20:40 +0200 Subject: [PATCH 04/28] quick fix --- .../instructions/cp/EinsumCPInstruction.java | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 014b1dcd02c..6333b5f21f5 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -23,8 +23,11 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.hops.LiteralOp; import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.codegen.cplan.CNode; import org.apache.sysds.hops.codegen.cplan.CNodeCell; +import org.apache.sysds.hops.codegen.cplan.CNodeData; import org.apache.sysds.hops.codegen.cplan.CNodeRow; import org.apache.sysds.runtime.codegen.*; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; @@ -164,7 +167,9 @@ public void processInstruction(ExecutionContext ec) { if(inputsChars.length == 2 && inputsChars[0].charAt(0)==inputsChars[1].charAt(0) && !einc.summingChars.contains(parts[1].charAt(0))){// ja,jb->... // outer tmpl - CNodeRow cnode = new CNodeRow(new ArrayList<>(), null); + ArrayList cnodeIn = new ArrayList<>(); + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); + CNodeRow cnode = new CNodeRow(cnodeIn, null); // cnode.setConstDim2(einc.outCols); // cnode.setNumVectorIntermediates(1); String src = tmpRow; @@ -208,8 +213,9 @@ public void processInstruction(ExecutionContext ec) { else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].charAt(0)){ ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier MatrixBlock first = (inputs.get(0)).reorgOperations(transpose, new MatrixBlock(), 0 ,0, 0); - - CNodeRow cnode = new CNodeRow(new ArrayList<>(), null); + ArrayList cnodeIn = new ArrayList<>(); + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); + CNodeRow cnode = new CNodeRow(cnodeIn, null); String src = tmpRow; if(einc.outCols == 1){ @@ -246,7 +252,9 @@ else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].char } } else{ //fallback to cell - CNodeCell cnode = new CNodeCell(new ArrayList<>(), null); + ArrayList cnodeIn = new ArrayList<>(); + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); + CNodeCell cnode = new CNodeCell(cnodeIn, null); // cnode.setCellType(SpoofCellwise.CellType.NO_AGG); StringBuilder sb = new StringBuilder(); From 3a09dff981db8fadbf64cdc4806384d8f5d83972 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sun, 8 Jun 2025 23:54:51 +0200 Subject: [PATCH 05/28] more computations done using row tpl --- .../sysds/hops/codegen/cplan/CNodeBinary.java | 8 +- .../sysds/hops/codegen/cplan/CNodeData.java | 8 + .../instructions/cp/EinsumCPInstruction.java | 913 +++++++++++++++--- .../instructions/cp/EinsumContext.java | 14 +- .../test/functions/einsum/EinsumTest.java | 12 +- src/test/scripts/functions/einsum/einsum8.R | 34 + src/test/scripts/functions/einsum/einsum8.dml | 30 + src/test/scripts/functions/einsum/einsum9.R | 34 + src/test/scripts/functions/einsum/einsum9.dml | 30 + 9 files changed, 963 insertions(+), 120 deletions(-) create mode 100644 src/test/scripts/functions/einsum/einsum8.R create mode 100644 src/test/scripts/functions/einsum/einsum8.dml create mode 100644 src/test/scripts/functions/einsum/einsum9.R create mode 100644 src/test/scripts/functions/einsum/einsum9.dml diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java index b29d586c38a..c80242f2547 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java @@ -187,9 +187,13 @@ public String codegen(boolean sparse, GeneratorAPI api) { varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj + ".values(rix)" : (_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) : _inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj); - + + if(_type == BinType.VECT_OUTERMULT_ADD && (_inputs.get(j) instanceof CNodeData && _inputs.get(j).getDataType().isMatrix()) && + (varj.startsWith("b"))) + tmp = tmp.replace("%POS"+(j+1)+"%",varj + ".pos(rix)"); + else //replace start position of main input - tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData + tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : ((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise() && TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ? diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java index 9292972874f..b90789df5eb 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java @@ -56,6 +56,14 @@ public CNodeData(CNodeData node, String newName) { _cols = node.getNumCols(); _dataType = node.getDataType(); } + + public CNodeData(String name, long hopID, long rows, long cols, DataType dataType) { + _name = name; + _hopID = hopID; + _rows = rows; + _cols = cols; + _dataType = dataType; + } @Override public String getVarname() { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 6333b5f21f5..0e7195f7622 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -19,13 +19,16 @@ package org.apache.sysds.runtime.instructions.cp; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.hops.LiteralOp; import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.codegen.SpoofCompiler; import org.apache.sysds.hops.codegen.cplan.CNode; +import org.apache.sysds.hops.codegen.cplan.CNodeBinary; import org.apache.sysds.hops.codegen.cplan.CNodeCell; import org.apache.sysds.hops.codegen.cplan.CNodeData; import org.apache.sysds.hops.codegen.cplan.CNodeRow; @@ -34,22 +37,20 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.functionobjects.*; -import org.apache.sysds.runtime.instructions.InstructionUtils; -import org.apache.sysds.runtime.lineage.LineageCodegenItem; -import org.apache.sysds.runtime.lineage.LineageItem; -import org.apache.sysds.runtime.lineage.LineageItemUtils; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.checkerframework.checker.units.qual.A; import java.util.*; import static org.apache.sysds.runtime.instructions.cp.EinsumContext.getEinsumContext; public class EinsumCPInstruction extends BuiltinNaryCPInstruction { - + public static boolean oldCode= false; + public static boolean forceCell = false; protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; private final Class _class; @@ -104,6 +105,7 @@ public void processInstruction(ExecutionContext ec) { //get input matrices and scalars, incl pinning of matrices ArrayList inputs = new ArrayList<>(); + ArrayList inputsNames = new ArrayList<>(); ArrayList scalars = new ArrayList<>(); if( LOG.isDebugEnabled() ) LOG.debug("executing spoof instruction " + _op); @@ -115,6 +117,7 @@ public void processInstruction(ExecutionContext ec) { mb = ((CompressedMatrixBlock) mb).getUncompressed("Spoof instruction"); } inputs.add(mb); + inputsNames.add(input.getName()); } // else if(input.getDataType()==DataType.SCALAR) { // //note: even if literal, it might be compiled as scalar placeholder @@ -126,11 +129,73 @@ public void processInstruction(ExecutionContext ec) { String[] parts = einc.equationString.split("->"); - String[] inputsChars = parts[0].split(","); +// ArrayList inputsChars = new ArrayList<>(Arrays.asList(parts[0].split(","))); + + System.out.println("outrows:"+einc.outRows); + System.out.println("outcols:"+einc.outCols); + + Character outChar1 = null; + Character outChar2 = null; + + if(parts[1].length()>=2){ + outChar1 = parts[1].charAt(0); + outChar2 = parts[1].charAt(1); + }else if (parts[1].length()==1){ + outChar1 = parts[1].charAt(0); + } + HashMap partsCharactersCounter = new HashMap<>(); + HashMap> partsCharactersToIndices = new HashMap<>(); + ArrayList newEquationStringSplit = new ArrayList(); + + ArrayList diagMatrices = new ArrayList<>(); + int arrCounter=0; + for(int i=0;i()); + + partsCharactersToIndices.get(c).add(arrCounter); + s+=c; + } + if(i+1()); + + partsCharactersToIndices.get(c2).add(arrCounter); + s+=c2; + } + i++; + + } + newEquationStringSplit.add(s); + } + ArrayList inputsChars = newEquationStringSplit; + System.out.println(String.join(",",newEquationStringSplit)); //todo move to separate op earlier: for(int i=0;i go througth dims to count and try to do row mults and order them + for(Character c :partsCharactersToIndices.keySet()){ + ArrayList a = partsCharactersToIndices.get(c); - if(inputsChars.length == 2 && inputsChars[0].charAt(0)==inputsChars[1].charAt(0) && !einc.summingChars.contains(parts[1].charAt(0))){// ja,jb->... + System.out.println(c+" count= "+a.size()); + } + if(oldCode) { + if (!forceCell && inputsChars.size() == 2 && inputsChars.get(0).charAt(0) == inputsChars.get(1).charAt(0) && !einc.summingChars.contains(parts[1].charAt(0))) { + + + // ja,jb->... // outer tmpl - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); - CNodeRow cnode = new CNodeRow(cnodeIn, null); + boolean oldWay = false; + if (oldWay) { + // outer tmpl + ArrayList cnodeIn = new ArrayList<>(); + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); + CNodeRow cnode = new CNodeRow(cnodeIn, null); // cnode.setConstDim2(einc.outCols); // cnode.setNumVectorIntermediates(1); - String src = tmpRow; + String src = tmpRow; - if(einc.outCols == 1){ + if (einc.outCols == 1) { // cnode.setRowType(SpoofRowwise.RowType.ROW_AGG); - src = src.replace("%TYPE%","ROW_AGG"); +// src = src.replace("%TYPE%","ROW_AGG"); + src = src.replace("%TYPE%", "COL_AGG_B1_T"); - }else { + } else { // cnode.setRowType(SpoofRowwise.RowType.COL_AGG_B1_T); - src = src.replace("%TYPE%","COL_AGG_B1_T"); + src = src.replace("%TYPE%", "COL_AGG_B1_T"); - } - src = src.replace("%TMP%", cnode.createVarname()); + } + src = src.replace("%TMP%", cnode.createVarname()); // String src= cnode.codegenEinsum(false, SpoofCompiler.GeneratorAPI.JAVA); - src = src.replace("%CONST_DIM2%",einc.outCols.toString());// super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + + src = src.replace("%CONST_DIM2%", einc.outCols.toString());// super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + -// System.out.println(src); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + System.out.println(src); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); // Class cla = CodegenUtils.compileClass("codegen.TMP0", src); - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); - mb.reset(einc.outRows, einc.outCols, false); - mb.allocateDenseBlock(); - if(LibSpoofPrimitives.isFlipOuter(einc.outRows,einc.outCols)){ + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); + mb.reset(einc.outRows, einc.outCols, false); + mb.allocateDenseBlock(); + if (LibSpoofPrimitives.isFlipOuter(einc.outRows, einc.outCols)) { // System.out.println("swapping"); - ArrayList m2 = new ArrayList(2); + ArrayList m2 = new ArrayList(2); - m2.add(inputs.get(1)); - m2.add(inputs.get(0)); - MatrixBlock out =op.execute(m2,scalars,mb,_numThreads); - ec.setMatrixOutput(output.getName(), out); - }else{ + m2.add(inputs.get(1)); + m2.add(inputs.get(0)); + MatrixBlock out = op.execute(m2, scalars, mb, _numThreads); + ec.setMatrixOutput(output.getName(), out); + } else { // System.out.println("NOTswapping"); - MatrixBlock out =op.execute(inputs,scalars,mb,_numThreads); + MatrixBlock out = op.execute(inputs, scalars, mb, _numThreads); + ec.setMatrixOutput(output.getName(), out); + } + } else { + MatrixBlock first; + MatrixBlock second; + String firstName; + String secondName; + if (inputs.get(0).getNumColumns() == einc.outRows) { + first = inputs.get(0); + second = inputs.get(1); + firstName = inputsNames.get(0); + secondName = inputsNames.get(1); + } else { + first = inputs.get(1); + second = inputs.get(0); + firstName = inputsNames.get(1); + secondName = inputsNames.get(0); + } + + ArrayList thisInputs = new ArrayList<>(Arrays.asList(first, second)); + + ArrayList cnodeIn = new ArrayList<>(); + + CNode c1 = new CNodeData(firstName, 1, first.getNumRows(), first.getNumColumns(), DataType.MATRIX); + CNode c2 = new CNodeData(secondName, 2, second.getNumRows(), second.getNumColumns(), DataType.MATRIX); + cnodeIn.add(c1); + cnodeIn.add(c2); + CNode cnodeOut = new CNodeBinary(c1, c2, CNodeBinary.BinType.VECT_OUTERMULT_ADD); + CNodeRow cnode = new CNodeRow(cnodeIn, cnodeOut); + if (inputs.get(0).getNumColumns() == einc.outRows) { + cnode.setConstDim2(einc.outCols); // output will be in1.cols,in2.cols + + } else { + cnode.setConstDim2(einc.outCols); + + } + cnode.setRowType(SpoofRowwise.RowType.COL_AGG_B1_T); +// if (einc.outCols == 1) { +// cnode.setRowType(SpoofRowwise.RowType.COL_AGG_B1R); +// +//// cnode.setConstDim2(1); +//// cnode.setNumVectorIntermediates(1); +// +// } else { +// cnode.setRowType(SpoofRowwise.RowType.COL_AGG_B1_T); +// } + cnode.renameInputs(); + + String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); + + System.out.println(src); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); + mb.reset(einc.outRows, einc.outCols, false); + mb.allocateDenseBlock(); + +// MatrixBlock out = op.execute(inputs, scalars, mb, _numThreads); +// ec.setMatrixOutput(output.getName(), out); +// if(LibSpoofPrimitives.isFlipOuter(einc.outRows,einc.outCols)){ +// ArrayList m2 = new ArrayList(2); +// +// m2.add(inputs.get(1)); +// m2.add(inputs.get(0)); +// MatrixBlock out =op.execute(m2,scalars,mb,_numThreads); +// ec.setMatrixOutput(output.getName(), out); +// }else{ +//// ArrayList m2 = new ArrayList(2); +//// . + MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); ec.setMatrixOutput(output.getName(), out); +// } } + } else if (!forceCell && inputsChars.size() == 2 && inputsChars.get(0).charAt(1) == inputsChars.get(1).charAt(0)) { + if (false) { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier + MatrixBlock first = (inputs.get(0)).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + ArrayList cnodeIn = new ArrayList<>(); + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); + CNodeRow cnode = new CNodeRow(cnodeIn, null); + String src = tmpRow; - } - else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].charAt(0)){ - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier - MatrixBlock first = (inputs.get(0)).reorgOperations(transpose, new MatrixBlock(), 0 ,0, 0); - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); - CNodeRow cnode = new CNodeRow(cnodeIn, null); - String src = tmpRow; + if (einc.outCols == 1) { + src = src.replace("%TYPE%", "ROW_AGG"); + + } else { + src = src.replace("%TYPE%", "COL_AGG_B1_T"); + + } + src = src.replace("%TMP%", cnode.createVarname()); + + src = src.replace("%CONST_DIM2%", einc.outCols.toString());// super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + - if(einc.outCols == 1){ - src = src.replace("%TYPE%","ROW_AGG"); + System.out.println(src); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); + mb.reset(einc.outRows, einc.outCols, false); - }else { - src = src.replace("%TYPE%","COL_AGG_B1_T"); + mb.allocateDenseBlock(); + if (LibSpoofPrimitives.isFlipOuter(einc.outRows, einc.outCols)) { + ArrayList m2 = new ArrayList(2); + m2.add(inputs.get(1)); + m2.add(first); + MatrixBlock out = op.execute(m2, scalars, mb, _numThreads); + ec.setMatrixOutput(output.getName(), out); + } else { + ArrayList m2 = new ArrayList(2); + m2.add(first); + m2.add(inputs.get(1)); + MatrixBlock out = op.execute(m2, scalars, mb, _numThreads); + ec.setMatrixOutput(output.getName(), out); + } } - src = src.replace("%TMP%", cnode.createVarname()); + else { + MatrixBlock first = inputs.get(0); + MatrixBlock second = inputs.get(1); + String firstName = inputsNames.get(0); + String secondName = inputsNames.get(1); + + ArrayList thisInputs = new ArrayList<>(Arrays.asList(first, second)); + + ArrayList cnodeIn = new ArrayList<>(); + + CNode c1 = new CNodeData(firstName, 1, first.getNumRows(), first.getNumColumns(), DataType.MATRIX); + CNode c2 = new CNodeData(secondName, 2, second.getNumRows(), second.getNumColumns(), DataType.MATRIX); + cnodeIn.add(c1); + cnodeIn.add(c2); + CNode cnodeOut = new CNodeBinary(c1, c2, CNodeBinary.BinType.VECT_MATRIXMULT); + CNodeRow cnode = new CNodeRow(cnodeIn, cnodeOut); +// if(inputs.get(0).getNumColumns() == einc.outRows) { +// cnode.setConstDim2(einc.outCols); // output will be in1.cols,in2.cols +// +// }else{ +// cnode.setConstDim2(einc.outCols); +// +// } + cnode.setRowType(SpoofRowwise.RowType.NO_AGG_B1); - src = src.replace("%CONST_DIM2%",einc.outCols.toString());// super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + + cnode.renameInputs(); -// System.out.println(src); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); - mb.reset(einc.outRows, einc.outCols, false); + String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); - mb.allocateDenseBlock(); - if(LibSpoofPrimitives.isFlipOuter(einc.outRows,einc.outCols)){ - ArrayList m2 = new ArrayList(2); + System.out.println(src); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - m2.add(inputs.get(1)); - m2.add(first); - MatrixBlock out =op.execute(m2,scalars,mb,_numThreads); - ec.setMatrixOutput(output.getName(), out); - }else{ - ArrayList m2 = new ArrayList(2); - m2.add(first); - m2.add(inputs.get(1)); - MatrixBlock out =op.execute(m2,scalars,mb,_numThreads); + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); + mb.reset(einc.outRows, einc.outCols, false); + mb.allocateDenseBlock(); + + + MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); ec.setMatrixOutput(output.getName(), out); } - } - else{ //fallback to cell + } else { //fallback to cell ArrayList cnodeIn = new ArrayList<>(); cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); CNodeCell cnode = new CNodeCell(cnodeIn, null); @@ -259,18 +455,18 @@ else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].char StringBuilder sb = new StringBuilder(); String outputChars = parts[1]; - if( outputChars.length()==2){ - einc.summingChars.remove( outputChars.charAt(0)); - einc.summingChars.remove( outputChars.charAt(1)); + if (outputChars.length() == 2) { + einc.summingChars.remove(outputChars.charAt(0)); + einc.summingChars.remove(outputChars.charAt(1)); - }else if( outputChars.length()==1){ - einc.summingChars.remove( outputChars.charAt(0)); + } else if (outputChars.length() == 1) { + einc.summingChars.remove(outputChars.charAt(0)); } - boolean needsSumming = einc.summingChars.stream().anyMatch(x->x != null); + boolean needsSumming = einc.summingChars.stream().anyMatch(x -> x != null); String itVar0 = "TMP123";//+new IDSequence().getNextID(); todo: generate this var String outVar = null; - if(needsSumming){ + if (needsSumming) { outVar = "TMP123";//+new IDSequence().getNextID(); sb.append("double "); sb.append(outVar); @@ -282,7 +478,7 @@ else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].char while (hsIt.hasNext()) { Character c = hsIt.next(); - String itVar = itVar0+c; + String itVar = itVar0 + c; sb.append("for(int "); sb.append(itVar); sb.append("=0;"); @@ -293,35 +489,35 @@ else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].char sb.append(itVar); sb.append("++){\n"); } - if (needsSumming){ + if (needsSumming) { sb.append(outVar); sb.append("+="); } - if(parts[1].length()==2){ - for (int i=0;i< inputsChars.length;i++){ - if(einc.summingChars.contains(inputsChars[i].charAt(0))){ + if (parts[1].length() == 2) { + for (int i = 0; i < inputsChars.size(); i++) { + if (einc.summingChars.contains(inputsChars.get(i).charAt(0))) { sb.append("getValue(b["); sb.append(i); sb.append("],b["); sb.append(i); sb.append("].clen,"); sb.append(itVar0); - sb.append(inputsChars[i].charAt(0)); + sb.append(inputsChars.get(i).charAt(0)); sb.append(","); - }else if(inputsChars[i].charAt(0)==outputChars.charAt(0)) { + } else if (inputsChars.get(i).charAt(0) == outputChars.charAt(0)) { sb.append("getValue(b["); sb.append(i); sb.append("],b["); sb.append(i); sb.append("].clen, rix,"); - }else if(inputsChars[i].charAt(0)==outputChars.charAt(1)) { + } else if (inputsChars.get(i).charAt(0) == outputChars.charAt(1)) { sb.append("getValue(b["); sb.append(i); sb.append("],b["); sb.append(i); sb.append("].clen, cix,"); - }else { + } else { sb.append("getValue(b["); sb.append(i); sb.append("],b["); @@ -329,51 +525,50 @@ else if(inputsChars.length == 2 && inputsChars[0].charAt(1)==inputsChars[1].char sb.append("].clen, 0,"); } - if(einc.summingChars.contains(inputsChars[i].charAt(1))) { + if (einc.summingChars.contains(inputsChars.get(i).charAt(1))) { sb.append(itVar0); - sb.append(inputsChars[i].charAt(1)); + sb.append(inputsChars.get(i).charAt(1)); sb.append(")"); - } - else if(inputsChars[i].charAt(1)==outputChars.charAt(0)){ - sb.append("rix)"); - }else if(inputsChars[i].charAt(1)==outputChars.charAt(1)){ + } else if (inputsChars.get(i).charAt(1) == outputChars.charAt(0)) { + sb.append("rix)"); + } else if (inputsChars.get(i).charAt(1) == outputChars.charAt(1)) { sb.append("cix)"); - }else { + } else { sb.append("0)"); } - if(i toSum = null; + Character sumC = null; + anyCouldNotDo = false; + Character cInOut = null; + for (Character c : partsCharactersToIndices.keySet()) { + if (c == outChar1 || c == outChar2) + continue; + toSum = partsCharactersToIndices.get(c).stream() + .filter(Objects::nonNull).toList(); + if (toSum.size() > 2) { + anyCouldNotDo = true; + continue; } - else{ - ec.setMatrixOutput(output.getName(), out); + if (toSum.size() != 2) + continue; + sumC = c; + break; + } + if (anyCouldNotDo) { + break; + } + if (sumC == null) { + //check if maybe there are out-put characters only terms like a,a,ab->ba + List remStrings = inputsChars.stream() + .filter(Objects::nonNull).toList(); + List remMbs = inputs.stream() + .filter(Objects::nonNull).toList(); + if(remStrings.size() > 1){ + Pair res = computRowSummationsOutputCharsOnly(remMbs, remStrings, parts[1],scalar); + scalar = null; + inputs = new ArrayList<>(Arrays.asList(res.getLeft())); + inputsChars = new ArrayList<>(Arrays.asList(res.getRight())); + } + break; //nothing left to sum + } + + Pair res = computeRowSummation(toSum, inputs, inputsChars, scalar); + scalar = null; + String newS = res.getRight(); + + var iter = toSum.listIterator(); + Integer ii = iter.next(); + for (Integer idx : toSum) { + inputs.set(idx, null); + inputsChars.set(idx, null); + } + inputs.add(res.getLeft()); + inputsChars.add(newS); + + for (int i = 0; i < newS.length(); i++) { + char c = newS.charAt(i); +// partsCharactersToIndices.get(c).remove(c); + partsCharactersToIndices.get(c).add(inputs.size() - 1); + } + + +// for(int i=0;i remStrings = inputsChars.stream() + .filter(Objects::nonNull).toList(); + List remMbs = inputs.stream() + .filter(Objects::nonNull).toList(); + MatrixBlock res; + if(remStrings.size() == 1){ + String s = remStrings.get(0); + if(s.equals(parts[1])){ + res=remMbs.get(0); + }else if(s.charAt(0)==s.charAt(1)) { + // diagonal needed + ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); + res= remMbs.get(0).reorgOperations(op, new MatrixBlock(),0,0,0); + }else{ + //it has to be transpose: ab->ba + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier + res = remMbs.get(0).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + } + }else{ + throw new RuntimeException("did not expect this!"); + } + ec.setMatrixOutput(output.getName(), res); + } + else { + ArrayList mbs = new ArrayList<>(); + ArrayList chars = new ArrayList<>(); + for (int i = 0; i < inputs.size(); i++) { + MatrixBlock mb = inputs.get(i); + if (mb != null) { + mbs.add(mb); + chars.add(inputsChars.get(i)); } } +// HashSet summingChars = new HashSet<>(); +// for(String s : inputsChars){ +// if(s == null) continue; +// if(s.length() == 1) summingChars.add(s.charAt(0)); +// if(s.length() == 2) { +// summingChars.add(s.charAt(0)); +// summingChars.add(s.charAt(1)); +// } +// } + ArrayList summingChars = new ArrayList(); + for (Character c : partsCharactersToIndices.keySet()) { + if (c != outChar1 && c != outChar2) summingChars.add(c); + } + //computeCellSummation(ArrayList inputs, List inputsChars, String resultString, + // HashMap charToDimensionSizeInt, List summingChars) + MatrixBlock res = computeCellSummation(mbs, chars, parts[1], einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); + + if (einc.outRows == 1 && einc.outCols == 1) + ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); + else ec.setMatrixOutput(output.getName(), res); + + } + } + //final operation // release input matrices @@ -432,6 +778,341 @@ else if(inputsChars[i].charAt(1)==outputChars.charAt(0)){ ec.releaseMatrixInput(input.getName()); } + private enum SumOperation { + aB_a, + Ba_a, + Ba_aC, // mmult +// aB_Ca, + Ba_Ca, + aB_aC, // outer mult + a_a, + } + + private enum AggregateAtEnd{ + Left, + Right, + Both, + None, + } + private Pair computRowSummationsOutputCharsOnly(List inputs, List inputsChars, String resString, Double scalar ){ + if(resString.length() == 1){ + // dont expect more than two of these, throw error if happens + if(inputs.size() != 2) throw new RuntimeException("did not expects this, please fix me"); + MatrixBlock mb = getCodegenMatrixBlock(inputs.get(0), inputs.get(1), CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG); + return Pair.of(mb, inputsChars.get(0)); + }else{ // resString.length() == 2 + // something like a,a,b,b,ab,ba + // group them + + ArrayList a = new ArrayList<>(); + ArrayList b = new ArrayList<>(); + ArrayList ab = new ArrayList<>(); + ArrayList ba = new ArrayList<>(); + for(int i =0;i< inputs.size(); i++){ + String s = inputsChars.get(i); + if(s.length() == 2){ + if(s.equals(resString)) ab.add(inputs.get(i)); + else ba.add(inputs.get(i)); + }else{ + if(s.charAt(0)==resString.charAt(0)) a.add(inputs.get(i)); + else b.add(inputs.get(i)); + } + } + // mult all a-s: + // mult all b-s: + // check if there is ab or ba + // if no: + // then do outer product axb or bxa + // if there is then: + // mult ba and a + // mult ab and b + // transp ba into ab + // mult 2 ab and ab + return Pair.of( ab.get(0) ,resString); +// throw new NotImplementedException("todo"); +// return null; + } + } + private Pair computeRowSummation(List toSum, ArrayList inputs, List inputsChars) { + return computeRowSummation(toSum,inputs,inputsChars, null); + } + private Pair computeRowSummation(List toSum, ArrayList inputs, List inputsChars, Double scalar) { + + if(toSum.size() != 2){ + return null; + } + + String s1 = inputsChars.get(toSum.get(0)); + String s2 = inputsChars.get(toSum.get(1)); + + + MatrixBlock first = null; + MatrixBlock second = null; + + String resS; + SumOperation sumOp; + + if(s1.length()==1 && s2.length() == 1){ //remove never happening here + sumOp = SumOperation.a_a; + resS = ""; + } + else if(s2.length() == 1 || s1.length() == 1){ + if(s1.length() == 1){ + String sTemp = s1; + s1=s2; + s2=sTemp; + + first = inputs.get(toSum.get(1)); + second = inputs.get(toSum.get(0)); + }else{ + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); + } + + if(s1.charAt(0) == s2.charAt(0)){ + sumOp = SumOperation.aB_a; + resS = String.valueOf(s1.charAt(1)); + }else{ + sumOp = SumOperation.Ba_a; + resS = String.valueOf(s1.charAt(0)); + } + } + else if(s1.charAt(0) == s2.charAt(0)){ + sumOp = SumOperation.aB_aC; + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); + resS = String.valueOf(s1.charAt(1))+String.valueOf(s2.charAt(1)); + + } + else if(s1.charAt(1) == s2.charAt(1)){ + sumOp = SumOperation.Ba_Ca; + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); + resS = String.valueOf(s1.charAt(0))+String.valueOf(s2.charAt(0)); + } + else if(s1.charAt(0) == s2.charAt(1)){ + sumOp = SumOperation.Ba_aC; + String sTemp = s1; + s1=s2; + s2=sTemp; + + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); + resS = String.valueOf(s1.charAt(0))+String.valueOf(s2.charAt(1)); + + } + else if(s1.charAt(1) == s2.charAt(0)){ + sumOp = SumOperation.Ba_aC; + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); + resS = String.valueOf(s1.charAt(0))+String.valueOf(s2.charAt(1)); + + }else{ + throw new RuntimeException("Error when choosing row multiplication operation"); + } + MatrixBlock out; + + switch (sumOp) { + case Ba_a: + throw new NotImplementedException(); + case Ba_aC: { + out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG, null, scalar); + break; + } + case Ba_Ca: + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG, null, scalar); + break; + case aB_a: + case aB_aC: { + out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_OUTERMULT_ADD, SpoofRowwise.RowType.COL_AGG_B1_T, Long.valueOf( second.getNumColumns()),scalar); + break; + } + case a_a: + out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG,null, scalar); + break; + default: + throw new IllegalStateException("Unexpected value: " + sumOp); + } + return Pair.of(out , resS); + } + private MatrixBlock getCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType){ + return getCodegenMatrixBlock(first, second, binaryType,rowType,null, null); + } + private MatrixBlock getCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType, Long secondDim, Double scalar) { + ArrayList thisInputs = new ArrayList<>(Arrays.asList(first, second)); + + ArrayList cnodeIn = new ArrayList<>(); + + CNode c1 = new CNodeData("c1", 1, first.getNumRows(), first.getNumColumns(), DataType.MATRIX); + CNode c2 = new CNodeData("c2", 2, second.getNumRows(), second.getNumColumns(), DataType.MATRIX); + cnodeIn.add(c1); + cnodeIn.add(c2); + CNode cnodeOut = new CNodeBinary(c1, c2, binaryType); + CNodeRow cnode = new CNodeRow(cnodeIn, cnodeOut); + + cnode.setRowType(rowType); + + if(secondDim != null) cnode.setConstDim2(secondDim); + cnode.renameInputs(); + + String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); + + System.out.println(src); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); +// mb.reset(einc.outRows, einc.outCols, false); +// mb.allocateDenseBlock(); + + ArrayList scalars = new ArrayList<>(); + if(scalar != null) scalars.add(new DoubleObject(scalar)); + MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); + return out; + } + + private MatrixBlock computeCellSummation(ArrayList inputs, List inputsChars, String resultString, + HashMap charToDimensionSizeInt, List summingChars, int outRows, int outCols){ + ArrayList cnodeIn = new ArrayList<>(); + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); + CNodeCell cnode = new CNodeCell(cnodeIn, null); +// cnode.setCellType(SpoofCellwise.CellType.NO_AGG); + StringBuilder sb = new StringBuilder(); + +// if (resultString.length() == 2) { +// summingChars.remove(resultString.charAt(0)); +// summingChars.remove(resultString.charAt(1)); +// +// } else if (resultString.length() == 1) { +// summingChars.remove(resultString.charAt(0)); +// +// } + boolean needsSumming = summingChars.stream().anyMatch(x -> x != null); + String itVar0 = "TMP123";//+new IDSequence().getNextID(); todo: generate this var + String outVar = null; + if (needsSumming) { + outVar = "TMP123";//+new IDSequence().getNextID(); + sb.append("double "); + sb.append(outVar); + sb.append("=0;\n"); + } + + HashSet summedCharacters = new HashSet<>(); + Iterator hsIt = summingChars.iterator(); + while (hsIt.hasNext()) { + + Character c = hsIt.next(); + String itVar = itVar0 + c; + sb.append("for(int "); + sb.append(itVar); + sb.append("=0;"); + sb.append(itVar); + sb.append("<"); + sb.append(charToDimensionSizeInt.get(c)); + sb.append(";"); + sb.append(itVar); + sb.append("++){\n"); + } + if (needsSumming) { + sb.append(outVar); + sb.append("+="); + } + + for (int i = 0; i < inputsChars.size(); i++) { + if (inputsChars.get(i).length() == 0){ + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, 0,"); + } + + else if (summingChars.contains(inputsChars.get(i).charAt(0))) { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen,"); + sb.append(itVar0); + sb.append(inputsChars.get(i).charAt(0)); + sb.append(","); + } else if (resultString.length() >= 1 && inputsChars.get(i).charAt(0) == resultString.charAt(0)) { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, rix,"); + } else if (resultString.length() == 2 && inputsChars.get(i).charAt(0) == resultString.charAt(1)) { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, cix,"); + } else { + sb.append("getValue(b["); + sb.append(i); + sb.append("],b["); + sb.append(i); + sb.append("].clen, 0,"); + } + + if (inputsChars.get(i).length() != 2){ + sb.append("0)"); + } + else if (summingChars.contains(inputsChars.get(i).charAt(1))) { + sb.append(itVar0); + sb.append(inputsChars.get(i).charAt(1)); + sb.append(")"); + } else if (resultString.length() >= 1 &&inputsChars.get(i).charAt(1) == resultString.charAt(0)) { + sb.append("rix)"); + } else if (resultString.length() == 2 && inputsChars.get(i).charAt(1) == resultString.charAt(1)) { + sb.append("cix)"); + } else { + sb.append("0)"); + } + + + if (i < inputsChars.size() - 1) { + sb.append("*"); + } + + } + if (needsSumming) { + sb.append(";"); + } + for (int si = 0; si < summingChars.size(); si++) { + sb.append("}\n"); + } + String src = tmpCell; + src = src.replace("%TMP%", cnode.createVarname()); + if (needsSumming) { + src = src.replace("%BODY_dense%", sb.toString()); + src = src.replace("%OUT%", outVar); + } else { + src = src.replace("%BODY_dense%", ""); + src = src.replace("%OUT%", sb.toString()); + } + +// String src = needsSumming ? cnode.codegenEinsum(false, SpoofCompiler.GeneratorAPI.JAVA, sb.toString(), outVar) : cnode.codegenEinsum(false, SpoofCompiler.GeneratorAPI.JAVA, "", sb.toString()); + System.out.println(src); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock resBlock = new MatrixBlock(); + resBlock.reset(outRows, outCols); + inputs.add(0, resBlock); + MatrixBlock out = op.execute(inputs, new ArrayList<>(), new MatrixBlock(), _numThreads); + if (outRows == 1 && outCols == 1) { +// ec.setScalarOutput(output.getName(), new DoubleObject(out.get(0, 0))); + return out; + } else { +// ec.setMatrixOutput(output.getName(), out); + return out; + + } + } + // @Override // public Pair getLineageItem(ExecutionContext ec) { // //return the lineage item if already traced once diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java index d8954b91884..81f690eccb9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java @@ -34,6 +34,7 @@ public class EinsumContext { public Integer[] contractDims; public Integer[] summingDims; public HashSet summingChars; + public HashSet contractDimsSet; public static EinsumContext getEinsumContext(String eqStr, ArrayList inputs){ EinsumContext res = new EinsumContext(); @@ -47,6 +48,9 @@ public static EinsumContext getEinsumContext(String eqStr, ArrayList summingChars = new HashSet<>(); Integer[] contractDims = new Integer[inputs.size()];//0==nothing, 1 = right, 2=left, 3 = both Integer[] summingDims = new Integer[inputs.size()];//0/null==nothing, 1 = right, 2=left, 3 = both + HashSet contractDimsSet = new HashSet(); + + int arrIt = 0; for (i = 0; true; i++){ char c = eqStr.charAt(i); @@ -55,6 +59,7 @@ public static EinsumContext getEinsumContext(String eqStr, ArrayListadd 1, it==1->add 2 }else{ @@ -139,6 +146,8 @@ else if(arrSizeIterator==1) }else if(c==c1){ // this dim is remaining }else{ + contractDimsSet.add(c); + if(contractDims[arrIt]==null){ contractDims[arrIt]=arrSizeIterator +1; @@ -179,6 +188,8 @@ else if(arrSizeIterator==1) }else if(c==c1 || c==c2){ // this dim is remaining }else{ + contractDimsSet.add(c); + if(contractDims[arrIt]==null){ contractDims[arrIt]=arrSizeIterator +1; @@ -193,6 +204,7 @@ else if(arrSizeIterator==1) throw new RuntimeException("output dim > 2 not supported for now"); } res.contractDims=contractDims; + res.contractDimsSet = contractDimsSet; res.summingDims=summingDims; res.summingChars = summingChars; diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 471aa47e4b8..6fab39dd2cf 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -48,6 +48,8 @@ public class EinsumTest extends AutomatedTestBase private static final String TEST_EINSUM5 = TEST_NAME_EINSUM+"5"; private static final String TEST_EINSUM6 = TEST_NAME_EINSUM+"6"; private static final String TEST_EINSUM7 = TEST_NAME_EINSUM+"7"; + private static final String TEST_EINSUM8 = TEST_NAME_EINSUM+"8"; + private static final String TEST_EINSUM9 = TEST_NAME_EINSUM+"9"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; @@ -59,7 +61,7 @@ public class EinsumTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=7; i++) + for(int i=1; i<=9; i++) addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } @Test @@ -90,6 +92,14 @@ public void testCodegenEinsum6CP() { public void testCodegenEinsum7CP() { testCodegenIntegration( TEST_EINSUM7, false, ExecType.CP ); } + @Test + public void testCodegenEinsum8CP() { + testCodegenIntegration( TEST_EINSUM8, false, ExecType.CP ); + } + @Test + public void testCodegenEinsum9CP() { + testCodegenIntegration( TEST_EINSUM9, false, ExecType.CP ); + } private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { diff --git a/src/test/scripts/functions/einsum/einsum8.R b/src/test/scripts/functions/einsum/einsum8.R new file mode 100644 index 00000000000..5587c9d878f --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum8.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +P = matrix(seq(1,3000), 600, 5, byrow=TRUE); +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); + +# R = t(P) %*% X; +R = einsum("ji,jz->zi",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum8.dml b/src/test/scripts/functions/einsum/einsum8.dml new file mode 100644 index 00000000000..2e47614800b --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum8.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +P = matrix(seq(1,3000), 600,5); +X = matrix(seq(1,6000), 600, 10) + +while(FALSE){} + +#R = t(P) %*% X ; + +R = einsum("ji,jz->zi",P,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum9.R b/src/test/scripts/functions/einsum/einsum9.R new file mode 100644 index 00000000000..61cb86a96a7 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum9.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +P = matrix(seq(1,3000), 5, 600, byrow=TRUE); +X = matrix(seq(1,6000), 10, 600, byrow=TRUE); + +# R = t(P) %*% X; +R = einsum("ij,zj->iz",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum9.dml b/src/test/scripts/functions/einsum/einsum9.dml new file mode 100644 index 00000000000..a22c03fdc9b --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum9.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +P = matrix(seq(1,3000), 5,600); +X = matrix(seq(1,6000), 10, 600) + +while(FALSE){} + +#R = t(P) %*% X ; + +R = einsum("ij,zj->iz",P,X) +write(R, $1) From 6b3c67f1244af9e84f82819e62f484edb5250b02 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Mon, 9 Jun 2025 00:00:24 +0200 Subject: [PATCH 06/28] removed old code --- .../instructions/cp/EinsumCPInstruction.java | 638 +++--------------- 1 file changed, 104 insertions(+), 534 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 0e7195f7622..d280653be58 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -49,7 +49,6 @@ import static org.apache.sysds.runtime.instructions.cp.EinsumContext.getEinsumContext; public class EinsumCPInstruction extends BuiltinNaryCPInstruction { - public static boolean oldCode= false; public static boolean forceCell = false; protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; @@ -81,24 +80,6 @@ public Class getOperatorClass() { private static final int CONTRACT_RIGHT = 2; private static final int CONTRACT_BOTH = 3; -// public static EinsumCPInstruction parseInstruction(String str) { -// String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); -// -// ArrayList inlist = new ArrayList<>(); -// -// for (int i = 2; i < parts.length - 1; i++) -// { -// inlist.add(new CPOperand(parts[i])); -// } -// -// CPOperand out = new CPOperand(parts[parts.length-1]); -//// int k = 1;//Integer.parseInt(parts[parts.length-1]); -// int k = OptimizerUtils.getConstrainedNumThreads(-1); -// -// String eqString = new CPOperand(parts[1]).getName(); //todo change -// -// return new EinsumCPInstruction(k, inlist.toArray(new CPOperand[0]), out, parts[0], str, eqString); -// } @Override public void processInstruction(ExecutionContext ec) { @@ -235,464 +216,85 @@ else if(!einc.contractDimsSet.contains(c2)){ } -//todo -> go througth dims to count and try to do row mults and order them for(Character c :partsCharactersToIndices.keySet()){ ArrayList a = partsCharactersToIndices.get(c); System.out.println(c+" count= "+a.size()); } - if(oldCode) { - if (!forceCell && inputsChars.size() == 2 && inputsChars.get(0).charAt(0) == inputsChars.get(1).charAt(0) && !einc.summingChars.contains(parts[1].charAt(0))) { - - - // ja,jb->... - // outer tmpl - boolean oldWay = false; - if (oldWay) { - // outer tmpl - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); - CNodeRow cnode = new CNodeRow(cnodeIn, null); -// cnode.setConstDim2(einc.outCols); -// cnode.setNumVectorIntermediates(1); - String src = tmpRow; - - if (einc.outCols == 1) { -// cnode.setRowType(SpoofRowwise.RowType.ROW_AGG); -// src = src.replace("%TYPE%","ROW_AGG"); - src = src.replace("%TYPE%", "COL_AGG_B1_T"); - - } else { -// cnode.setRowType(SpoofRowwise.RowType.COL_AGG_B1_T); - src = src.replace("%TYPE%", "COL_AGG_B1_T"); - - } - src = src.replace("%TMP%", cnode.createVarname()); - -// String src= cnode.codegenEinsum(false, SpoofCompiler.GeneratorAPI.JAVA); - src = src.replace("%CONST_DIM2%", einc.outCols.toString());// super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + - - System.out.println(src); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); -// Class cla = CodegenUtils.compileClass("codegen.TMP0", src); - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); - mb.reset(einc.outRows, einc.outCols, false); - mb.allocateDenseBlock(); - if (LibSpoofPrimitives.isFlipOuter(einc.outRows, einc.outCols)) { -// System.out.println("swapping"); - ArrayList m2 = new ArrayList(2); - - m2.add(inputs.get(1)); - m2.add(inputs.get(0)); - MatrixBlock out = op.execute(m2, scalars, mb, _numThreads); - ec.setMatrixOutput(output.getName(), out); - } else { -// System.out.println("NOTswapping"); - MatrixBlock out = op.execute(inputs, scalars, mb, _numThreads); - ec.setMatrixOutput(output.getName(), out); - } - } else { - MatrixBlock first; - MatrixBlock second; - String firstName; - String secondName; - if (inputs.get(0).getNumColumns() == einc.outRows) { - first = inputs.get(0); - second = inputs.get(1); - firstName = inputsNames.get(0); - secondName = inputsNames.get(1); - } else { - first = inputs.get(1); - second = inputs.get(0); - firstName = inputsNames.get(1); - secondName = inputsNames.get(0); - } - - ArrayList thisInputs = new ArrayList<>(Arrays.asList(first, second)); - - ArrayList cnodeIn = new ArrayList<>(); - - CNode c1 = new CNodeData(firstName, 1, first.getNumRows(), first.getNumColumns(), DataType.MATRIX); - CNode c2 = new CNodeData(secondName, 2, second.getNumRows(), second.getNumColumns(), DataType.MATRIX); - cnodeIn.add(c1); - cnodeIn.add(c2); - CNode cnodeOut = new CNodeBinary(c1, c2, CNodeBinary.BinType.VECT_OUTERMULT_ADD); - CNodeRow cnode = new CNodeRow(cnodeIn, cnodeOut); - if (inputs.get(0).getNumColumns() == einc.outRows) { - cnode.setConstDim2(einc.outCols); // output will be in1.cols,in2.cols - - } else { - cnode.setConstDim2(einc.outCols); - - } - cnode.setRowType(SpoofRowwise.RowType.COL_AGG_B1_T); -// if (einc.outCols == 1) { -// cnode.setRowType(SpoofRowwise.RowType.COL_AGG_B1R); -// -//// cnode.setConstDim2(1); -//// cnode.setNumVectorIntermediates(1); -// -// } else { -// cnode.setRowType(SpoofRowwise.RowType.COL_AGG_B1_T); -// } - cnode.renameInputs(); - - String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); - - System.out.println(src); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); - mb.reset(einc.outRows, einc.outCols, false); - mb.allocateDenseBlock(); - -// MatrixBlock out = op.execute(inputs, scalars, mb, _numThreads); -// ec.setMatrixOutput(output.getName(), out); -// if(LibSpoofPrimitives.isFlipOuter(einc.outRows,einc.outCols)){ -// ArrayList m2 = new ArrayList(2); -// -// m2.add(inputs.get(1)); -// m2.add(inputs.get(0)); -// MatrixBlock out =op.execute(m2,scalars,mb,_numThreads); -// ec.setMatrixOutput(output.getName(), out); -// }else{ -//// ArrayList m2 = new ArrayList(2); -//// . - MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); - ec.setMatrixOutput(output.getName(), out); -// } - } - } else if (!forceCell && inputsChars.size() == 2 && inputsChars.get(0).charAt(1) == inputsChars.get(1).charAt(0)) { - if (false) { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier - MatrixBlock first = (inputs.get(0)).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); - CNodeRow cnode = new CNodeRow(cnodeIn, null); - String src = tmpRow; - - if (einc.outCols == 1) { - src = src.replace("%TYPE%", "ROW_AGG"); - - } else { - src = src.replace("%TYPE%", "COL_AGG_B1_T"); - } - src = src.replace("%TMP%", cnode.createVarname()); - - src = src.replace("%CONST_DIM2%", einc.outCols.toString());// super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + - - System.out.println(src); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); - mb.reset(einc.outRows, einc.outCols, false); - - mb.allocateDenseBlock(); - if (LibSpoofPrimitives.isFlipOuter(einc.outRows, einc.outCols)) { - ArrayList m2 = new ArrayList(2); - - m2.add(inputs.get(1)); - m2.add(first); - MatrixBlock out = op.execute(m2, scalars, mb, _numThreads); - ec.setMatrixOutput(output.getName(), out); - } else { - ArrayList m2 = new ArrayList(2); - m2.add(first); - m2.add(inputs.get(1)); - MatrixBlock out = op.execute(m2, scalars, mb, _numThreads); - ec.setMatrixOutput(output.getName(), out); - } - } - else { - MatrixBlock first = inputs.get(0); - MatrixBlock second = inputs.get(1); - String firstName = inputsNames.get(0); - String secondName = inputsNames.get(1); - - ArrayList thisInputs = new ArrayList<>(Arrays.asList(first, second)); - - ArrayList cnodeIn = new ArrayList<>(); - - CNode c1 = new CNodeData(firstName, 1, first.getNumRows(), first.getNumColumns(), DataType.MATRIX); - CNode c2 = new CNodeData(secondName, 2, second.getNumRows(), second.getNumColumns(), DataType.MATRIX); - cnodeIn.add(c1); - cnodeIn.add(c2); - CNode cnodeOut = new CNodeBinary(c1, c2, CNodeBinary.BinType.VECT_MATRIXMULT); - CNodeRow cnode = new CNodeRow(cnodeIn, cnodeOut); -// if(inputs.get(0).getNumColumns() == einc.outRows) { -// cnode.setConstDim2(einc.outCols); // output will be in1.cols,in2.cols -// -// }else{ -// cnode.setConstDim2(einc.outCols); -// -// } - cnode.setRowType(SpoofRowwise.RowType.NO_AGG_B1); - - cnode.renameInputs(); - - String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); - - System.out.println(src); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); - mb.reset(einc.outRows, einc.outCols, false); - mb.allocateDenseBlock(); - - - MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); - ec.setMatrixOutput(output.getName(), out); - } - } else { //fallback to cell - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); - CNodeCell cnode = new CNodeCell(cnodeIn, null); -// cnode.setCellType(SpoofCellwise.CellType.NO_AGG); - StringBuilder sb = new StringBuilder(); - - String outputChars = parts[1]; - if (outputChars.length() == 2) { - einc.summingChars.remove(outputChars.charAt(0)); - einc.summingChars.remove(outputChars.charAt(1)); - - } else if (outputChars.length() == 1) { - einc.summingChars.remove(outputChars.charAt(0)); - - } - boolean needsSumming = einc.summingChars.stream().anyMatch(x -> x != null); - String itVar0 = "TMP123";//+new IDSequence().getNextID(); todo: generate this var - String outVar = null; - if (needsSumming) { - outVar = "TMP123";//+new IDSequence().getNextID(); - sb.append("double "); - sb.append(outVar); - sb.append("=0;\n"); - } - - HashSet summedCharacters = new HashSet<>(); - Iterator hsIt = einc.summingChars.iterator(); - while (hsIt.hasNext()) { - - Character c = hsIt.next(); - String itVar = itVar0 + c; - sb.append("for(int "); - sb.append(itVar); - sb.append("=0;"); - sb.append(itVar); - sb.append("<"); - sb.append(einc.charToDimensionSizeInt.get(c)); - sb.append(";"); - sb.append(itVar); - sb.append("++){\n"); - } - if (needsSumming) { - sb.append(outVar); - sb.append("+="); - } - - if (parts[1].length() == 2) { - for (int i = 0; i < inputsChars.size(); i++) { - if (einc.summingChars.contains(inputsChars.get(i).charAt(0))) { - sb.append("getValue(b["); - sb.append(i); - sb.append("],b["); - sb.append(i); - sb.append("].clen,"); - sb.append(itVar0); - sb.append(inputsChars.get(i).charAt(0)); - sb.append(","); - } else if (inputsChars.get(i).charAt(0) == outputChars.charAt(0)) { - sb.append("getValue(b["); - sb.append(i); - sb.append("],b["); - sb.append(i); - sb.append("].clen, rix,"); - } else if (inputsChars.get(i).charAt(0) == outputChars.charAt(1)) { - sb.append("getValue(b["); - sb.append(i); - sb.append("],b["); - sb.append(i); - sb.append("].clen, cix,"); - } else { - sb.append("getValue(b["); - sb.append(i); - sb.append("],b["); - sb.append(i); - sb.append("].clen, 0,"); - } - - if (einc.summingChars.contains(inputsChars.get(i).charAt(1))) { - sb.append(itVar0); - sb.append(inputsChars.get(i).charAt(1)); - sb.append(")"); - } else if (inputsChars.get(i).charAt(1) == outputChars.charAt(0)) { - sb.append("rix)"); - } else if (inputsChars.get(i).charAt(1) == outputChars.charAt(1)) { - sb.append("cix)"); - - } else { - sb.append("0)"); - } - - - if (i < inputsChars.size() - 1) { - sb.append("*"); - } - - } - } else { - for (int i = 0; i < inputsChars.size(); i++) { - if (inputsChars.size() == 2) { - if (einc.summingChars.contains(inputsChars.get(i).charAt(0))) { - sb.append("getValue(b["); - sb.append(i); - sb.append("],b["); - sb.append(i); - sb.append("].clen,"); - sb.append(itVar0); - sb.append(inputsChars.get(i).charAt(0)); - sb.append(","); - } else if (inputsChars.get(i).charAt(0) == outputChars.charAt(0)) { - sb.append("getValue(b["); - sb.append(i); - sb.append("],b["); - sb.append(i); - sb.append("].clen, rix,"); - } else if (inputsChars.get(i).charAt(0) == outputChars.charAt(1)) { - sb.append("getValue(b["); - sb.append(i); - sb.append("],b["); - sb.append(i); - sb.append("].clen, cix,"); - } else { - sb.append("getValue(b["); - sb.append(i); - sb.append("],b["); - sb.append(i); - sb.append("].clen, 0,"); - } - sb.append("0)"); - - if (i < inputsChars.size() - 1) { - sb.append("*"); - } - - } - } - } - if (needsSumming) { - sb.append(";"); - } - for (int si = 0; si < einc.summingChars.size(); si++) { - sb.append("}\n"); - } - String src = tmpCell; - src = src.replace("%TMP%", cnode.createVarname()); - if (needsSumming) { - src = src.replace("%BODY_dense%", sb.toString()); - src = src.replace("%OUT%", outVar); - } else { - src = src.replace("%BODY_dense%", ""); - src = src.replace("%OUT%", sb.toString()); - } -// String src = needsSumming ? cnode.codegenEinsum(false, SpoofCompiler.GeneratorAPI.JAVA, sb.toString(), outVar) : cnode.codegenEinsum(false, SpoofCompiler.GeneratorAPI.JAVA, "", sb.toString()); - System.out.println(src); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock resBlock = new MatrixBlock(); - resBlock.reset(einc.outRows, einc.outCols); - inputs.add(0, resBlock); - MatrixBlock out = op.execute(inputs, scalars, new MatrixBlock(), _numThreads); - if (einc.outRows == 1 && einc.outCols == 1) { - ec.setScalarOutput(output.getName(), new DoubleObject(out.get(0, 0))); - - } else { - ec.setMatrixOutput(output.getName(), out); - } + // compute scalars: + Double scalar = null; + for(int i=0;i< inputs.size(); i++){ + String s = inputsChars.get(i); + if(s.equals("")){ + MatrixBlock mb = inputs.get(i); + if (scalar == null) scalar = mb.get(0,0); + else scalar*= mb.get(0,0); + inputs.set(i,null); + inputsChars.set(i,null); } } - // sum the characters: - - boolean anyCouldNotDo = true; - - if (!oldCode) { - // compute scalars: - Double scalar = null; - for(int i=0;i< inputs.size(); i++){ - String s = inputsChars.get(i); - if(s.equals("")){ - MatrixBlock mb = inputs.get(i); - if (scalar == null) scalar = mb.get(0,0); - else scalar*= mb.get(0,0); - inputs.set(i,null); - inputsChars.set(i,null); + boolean anyCouldNotDo = true; // information to do cell tpl for remaining ones + + while (!forceCell) { + List toSum = null; + Character sumC = null; + anyCouldNotDo = false; + Character cInOut = null; + for (Character c : partsCharactersToIndices.keySet()) { // sum on dim at the time + if (c == outChar1 || c == outChar2) + continue; + toSum = partsCharactersToIndices.get(c).stream() + .filter(Objects::nonNull).toList(); + if (toSum.size() > 2) { + anyCouldNotDo = true; + continue; } - + if (toSum.size() != 2) + continue; + sumC = c; + break; } - - while (!forceCell) { - List toSum = null; - Character sumC = null; - anyCouldNotDo = false; - Character cInOut = null; - for (Character c : partsCharactersToIndices.keySet()) { - if (c == outChar1 || c == outChar2) - continue; - toSum = partsCharactersToIndices.get(c).stream() - .filter(Objects::nonNull).toList(); - if (toSum.size() > 2) { - anyCouldNotDo = true; - continue; - } - if (toSum.size() != 2) - continue; - sumC = c; - break; - } - if (anyCouldNotDo) { - break; - } - if (sumC == null) { - //check if maybe there are out-put characters only terms like a,a,ab->ba - List remStrings = inputsChars.stream() - .filter(Objects::nonNull).toList(); - List remMbs = inputs.stream() - .filter(Objects::nonNull).toList(); - if(remStrings.size() > 1){ - Pair res = computRowSummationsOutputCharsOnly(remMbs, remStrings, parts[1],scalar); - scalar = null; - inputs = new ArrayList<>(Arrays.asList(res.getLeft())); - inputsChars = new ArrayList<>(Arrays.asList(res.getRight())); - } - break; //nothing left to sum + if (anyCouldNotDo) { + break; + } + if (sumC == null) { + //check if maybe there are out-put characters only terms like a,a,ab->ba + List remStrings = inputsChars.stream() + .filter(Objects::nonNull).toList(); + List remMbs = inputs.stream() + .filter(Objects::nonNull).toList(); + if(remStrings.size() > 1){ + Pair res = computRowSummationsOutputCharsOnly(remMbs, remStrings, parts[1],scalar); + scalar = null; + inputs = new ArrayList<>(Arrays.asList(res.getLeft())); + inputsChars = new ArrayList<>(Arrays.asList(res.getRight())); } + break; //nothing left to sum + } - Pair res = computeRowSummation(toSum, inputs, inputsChars, scalar); - scalar = null; - String newS = res.getRight(); + Pair res = computeRowSummation(toSum, inputs, inputsChars, scalar); + scalar = null; + String newS = res.getRight(); - var iter = toSum.listIterator(); - Integer ii = iter.next(); - for (Integer idx : toSum) { - inputs.set(idx, null); - inputsChars.set(idx, null); - } - inputs.add(res.getLeft()); - inputsChars.add(newS); +// var iter = toSum.listIterator(); +// Integer ii = iter.next(); + for (Integer idx : toSum) { + inputs.set(idx, null); + inputsChars.set(idx, null); + } + inputs.add(res.getLeft()); + inputsChars.add(newS); - for (int i = 0; i < newS.length(); i++) { - char c = newS.charAt(i); + for (int i = 0; i < newS.length(); i++) { + char c = newS.charAt(i); // partsCharactersToIndices.get(c).remove(c); - partsCharactersToIndices.get(c).add(inputs.size() - 1); - } + partsCharactersToIndices.get(c).add(inputs.size() - 1); + } // for(int i=0;i remStrings = inputsChars.stream() - .filter(Objects::nonNull).toList(); - List remMbs = inputs.stream() - .filter(Objects::nonNull).toList(); - MatrixBlock res; - if(remStrings.size() == 1){ - String s = remStrings.get(0); - if(s.equals(parts[1])){ - res=remMbs.get(0); - }else if(s.charAt(0)==s.charAt(1)) { - // diagonal needed - ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); - res= remMbs.get(0).reorgOperations(op, new MatrixBlock(),0,0,0); - }else{ - //it has to be transpose: ab->ba - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier - res = remMbs.get(0).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - } + partsCharactersToIndices.remove(sumC); + } + if (!anyCouldNotDo){ + //check if any operations to do that were not-output dimension summations: + List remStrings = inputsChars.stream() + .filter(Objects::nonNull).toList(); + List remMbs = inputs.stream() + .filter(Objects::nonNull).toList(); + MatrixBlock res; + if(remStrings.size() == 1){ + String s = remStrings.get(0); + if(s.equals(parts[1])){ + res=remMbs.get(0); + }else if(s.charAt(0)==s.charAt(1)) { + // diagonal needed + ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); + res= remMbs.get(0).reorgOperations(op, new MatrixBlock(),0,0,0); }else{ - throw new RuntimeException("did not expect this!"); + //it has to be transpose: ab->ba + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier + res = remMbs.get(0).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); } - ec.setMatrixOutput(output.getName(), res); + }else{ + throw new RuntimeException("did not expect this!"); } + ec.setMatrixOutput(output.getName(), res); + } - else { - ArrayList mbs = new ArrayList<>(); - ArrayList chars = new ArrayList<>(); - for (int i = 0; i < inputs.size(); i++) { - MatrixBlock mb = inputs.get(i); - if (mb != null) { - mbs.add(mb); - chars.add(inputsChars.get(i)); - } + else { + ArrayList mbs = new ArrayList<>(); + ArrayList chars = new ArrayList<>(); + for (int i = 0; i < inputs.size(); i++) { + MatrixBlock mb = inputs.get(i); + if (mb != null) { + mbs.add(mb); + chars.add(inputsChars.get(i)); } + } // HashSet summingChars = new HashSet<>(); // for(String s : inputsChars){ // if(s == null) continue; @@ -754,21 +356,21 @@ else if(!einc.contractDimsSet.contains(c2)){ // summingChars.add(s.charAt(1)); // } // } - ArrayList summingChars = new ArrayList(); - for (Character c : partsCharactersToIndices.keySet()) { - if (c != outChar1 && c != outChar2) summingChars.add(c); + ArrayList summingChars = new ArrayList(); + for (Character c : partsCharactersToIndices.keySet()) { + if (c != outChar1 && c != outChar2) summingChars.add(c); - } - //computeCellSummation(ArrayList inputs, List inputsChars, String resultString, - // HashMap charToDimensionSizeInt, List summingChars) - MatrixBlock res = computeCellSummation(mbs, chars, parts[1], einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); + } + //computeCellSummation(ArrayList inputs, List inputsChars, String resultString, + // HashMap charToDimensionSizeInt, List summingChars) + MatrixBlock res = computeCellSummation(mbs, chars, parts[1], einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); - if (einc.outRows == 1 && einc.outCols == 1) - ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); - else ec.setMatrixOutput(output.getName(), res); + if (einc.outRows == 1 && einc.outCols == 1) + ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); + else ec.setMatrixOutput(output.getName(), res); - } } + //final operation @@ -1113,21 +715,6 @@ else if (summingChars.contains(inputsChars.get(i).charAt(1))) { } } -// @Override -// public Pair getLineageItem(ExecutionContext ec) { -// //return the lineage item if already traced once -// LineageItem li = ec.getLineage().get(output.getName()); -// if (li != null) -// return Pair.of(output.getName(), li); -// -// //read and deepcopy the corresponding lineage DAG (pre-codegen) -// LineageItem LIroot = LineageCodegenItem.getCodegenLTrace(getOperatorClass().getName()).deepCopy(); -// -// //replace the placeholders with original instruction inputs. -// LineageItemUtils.replaceDagLeaves(ec, LIroot, _in); -// -// return Pair.of(output.getName(), LIroot); -// } public CPOperand[] getInputs() { return _in; @@ -1135,23 +722,6 @@ public CPOperand[] getInputs() { private static final IDSequence _idSeqfn = new IDSequence(); - private final static String tmpRow = "package codegen;\n" + - "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" + - "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" + - "import org.apache.sysds.runtime.codegen.SpoofRowwise;\n" + - "import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;\n" + - "import org.apache.commons.math3.util.FastMath;\n" + - "\n" + - "public final class %TMP% extends SpoofRowwise { \n" + - " public %TMP%() {\n" + - " super(RowType.%TYPE%, %CONST_DIM2%, false, 1);\n" + - " }\n" + - " protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { \n" + - "LibSpoofPrimitives.vectOuterMultAdd(a, b[0].values(rix), c, ai, b[0].pos(rix), 0, len, b[0].clen); }\n" + - " protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { \n" + - " }\n" + - "}\n"; - private final static String tmpCell = "package codegen;\n" + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" + From 786d217b0ee4aceb4fee09f1d8194d5bddc01424 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Mon, 9 Jun 2025 00:33:46 +0200 Subject: [PATCH 07/28] fix bugs --- .../instructions/cp/EinsumCPInstruction.java | 22 ++++++++---- .../test/functions/einsum/EinsumTest.java | 11 +++--- src/test/scripts/functions/einsum/einsum10.R | 35 +++++++++++++++++++ .../scripts/functions/einsum/einsum10.dml | 31 ++++++++++++++++ 4 files changed, 88 insertions(+), 11 deletions(-) create mode 100644 src/test/scripts/functions/einsum/einsum10.R create mode 100644 src/test/scripts/functions/einsum/einsum10.dml diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index d280653be58..d4c94b39b02 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -240,7 +240,8 @@ else if(!einc.contractDimsSet.contains(c2)){ boolean anyCouldNotDo = true; // information to do cell tpl for remaining ones - while (!forceCell) { + if(!forceCell) + while (true) { List toSum = null; Character sumC = null; anyCouldNotDo = false; @@ -248,8 +249,15 @@ else if(!einc.contractDimsSet.contains(c2)){ for (Character c : partsCharactersToIndices.keySet()) { // sum on dim at the time if (c == outChar1 || c == outChar2) continue; - toSum = partsCharactersToIndices.get(c).stream() - .filter(Objects::nonNull).toList(); + toSum = new ArrayList<>(); + int count =0; //very temp solution, this part will be replaced anyway later probably + for(Integer idx : partsCharactersToIndices.get(c).stream() + .filter(Objects::nonNull).toList()){ + if(inputs.get(idx) != null){ + count++; + toSum.add(idx); + } + } if (toSum.size() > 2) { anyCouldNotDo = true; continue; @@ -498,8 +506,8 @@ else if(s1.charAt(0) == s2.charAt(1)){ s1=s2; s2=sTemp; - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); + first = inputs.get(toSum.get(1)); + second = inputs.get(toSum.get(0)); resS = String.valueOf(s1.charAt(0))+String.valueOf(s2.charAt(1)); } @@ -518,13 +526,13 @@ else if(s1.charAt(1) == s2.charAt(0)){ case Ba_a: throw new NotImplementedException(); case Ba_aC: { - out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG, null, scalar); + out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf( second.getNumColumns()), scalar); break; } case Ba_Ca: ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG, null, scalar); + out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf( second.getNumColumns()), scalar); break; case aB_a: case aB_aC: { diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 6fab39dd2cf..d45562385f8 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -50,6 +50,7 @@ public class EinsumTest extends AutomatedTestBase private static final String TEST_EINSUM7 = TEST_NAME_EINSUM+"7"; private static final String TEST_EINSUM8 = TEST_NAME_EINSUM+"8"; private static final String TEST_EINSUM9 = TEST_NAME_EINSUM+"9"; + private static final String TEST_EINSUM10 = TEST_NAME_EINSUM+"10"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; @@ -61,7 +62,7 @@ public class EinsumTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=9; i++) + for(int i=1; i<=10; i++) addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } @Test @@ -93,13 +94,15 @@ public void testCodegenEinsum7CP() { testCodegenIntegration( TEST_EINSUM7, false, ExecType.CP ); } @Test - public void testCodegenEinsum8CP() { - testCodegenIntegration( TEST_EINSUM8, false, ExecType.CP ); - } + public void testCodegenEinsum8CP() { testCodegenIntegration( TEST_EINSUM8, false, ExecType.CP ); } @Test public void testCodegenEinsum9CP() { testCodegenIntegration( TEST_EINSUM9, false, ExecType.CP ); } + @Test + public void testCodegenEinsum91CP() { + testCodegenIntegration( TEST_EINSUM10, false, ExecType.CP ); + } private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { diff --git a/src/test/scripts/functions/einsum/einsum10.R b/src/test/scripts/functions/einsum/einsum10.R new file mode 100644 index 00000000000..9fe7fa58b33 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum10.R @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +A = matrix(seq(1,3000), 5, 600, byrow=TRUE); +B = matrix(seq(1,6000), 600, 10, byrow=TRUE); +C = matrix(seq(1,50), 10, 5, byrow=TRUE); +D = matrix(seq(1,20), 5, 4, byrow=TRUE); + +R = einsum("ab,bc,cd,de->ae",A,B,C,D) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum10.dml b/src/test/scripts/functions/einsum/einsum10.dml new file mode 100644 index 00000000000..b7d2e931da5 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum10.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = matrix(seq(1,3000), 5, 600) +B = matrix(seq(1,6000), 600, 10); +C = matrix(seq(1,50), 10, 5); +D = matrix(seq(1,20), 5, 4); + +while(FALSE){} + +R = einsum("ab,bc,cd,de->ae",A,B,C,D) + +write(R, $1) From 980d7a8420dc2f730bdcf4e712dffcf1bd6c1dde Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Tue, 17 Jun 2025 20:04:05 +0200 Subject: [PATCH 08/28] cleanup validate code --- .../java/org/apache/sysds/hops/NaryOp.java | 17 ++++- .../parser/BuiltinFunctionExpression.java | 69 +++++++------------ 2 files changed, 40 insertions(+), 46 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java index 438cf638515..466ae6236b7 100644 --- a/src/main/java/org/apache/sysds/hops/NaryOp.java +++ b/src/main/java/org/apache/sysds/hops/NaryOp.java @@ -236,9 +236,20 @@ public void refreshSizeInformation() { setDim2(1); break; case EINSUM: - setDataType(DataType.MATRIX); - setDim1(getInput().size()); - setDim2(1); + String outStr = ((LiteralOp) _input.get(0)).getStringValue().split("->")[1]; + int count = 0; + for (int i = 0; i < outStr.length(); i++){ + if(outStr.charAt(i) != ' ') count++; + } + if(count==0) { + setDataType(DataType.SCALAR); + setDim1(0); + setDim2(0); + } + else{ + setDim1( HopRewriteUtils.getMaxInputDim(this, true)); + setDim2(count==1 ? 1 : HopRewriteUtils.getMaxInputDim(this, false)); + } break; case PRINTF: case EVAL: diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index aba7e4312fe..c73ce9cdcdf 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -755,7 +755,6 @@ else if(((ConstIdentifier) getThirdExpr().getOutput()) break; case EINSUM: validateEinsum((DataIdentifier) getOutputs()[0]); - break; default: //always unconditional raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false); @@ -2069,7 +2068,6 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV break; case EINSUM: validateEinsum(output); - break; default: if( isMathFunction() ) { @@ -2113,10 +2111,9 @@ private void validateEinsum(DataIdentifier output){ LanguageErrorCodes.INVALID_PARAMETERS); String eq_string = ((StringIdentifier)getFirstExpr()).getValue(); + String[] eqStringParts = eq_string.split("->"); - String[] parts = eq_string.split("->"); - - if(parts.length != 2) + if(eqStringParts.length != 2) raiseValidateError("Einsum: equation str should contain one '->' substring", false, LanguageErrorCodes.INVALID_PARAMETERS); @@ -2126,96 +2123,82 @@ private void validateEinsum(DataIdentifier output){ LinkedList matrixBlocks = new LinkedList(); for (int i=1;i charToDimensionSize = new HashMap<>(); - Iterator it = matrixBlocks.iterator(); - Identifier curArr = it.next(); + Identifier currArr = it.next(); int arrSizeIterator = 0; int numberOfMatrices = 1; - for (int i = 0; i numberOfMatrices){ + if (getAllExpr().length - 1 > numberOfMatrices) raiseValidateError("Einsum: Provided more operands than specified in equation str", false, LanguageErrorCodes.INVALID_PARAMETERS); - } + int numberOfDimensions = 0; - long dim1 = 0; - long dim2 = 0; - for (int i = 0; i2){ + }else if(numberOfDimensions > 2){ raiseValidateError("Einsum: output matrices with with no. dims > 2 not supported", false, LanguageErrorCodes.INVALID_PARAMETERS); }else { output.setDataType(DataType.MATRIX); output.setDimensions(dim1, dim2); } - }else{ + } else { // dimensions unknown int numberOfMatrices = 1; - for (int i = 0; i < parts[0].length(); i++) { - if(parts[0].charAt(i) == ',') + for (int i = 0; i < eqStringParts[0].length(); i++) { + if(eqStringParts[0].charAt(i) == ',') numberOfMatrices++; } checkNumParameters(numberOfMatrices+1); int numberOfDimensions = 0; - - for (int i = 0; i Date: Tue, 17 Jun 2025 20:05:45 +0200 Subject: [PATCH 09/28] add tests --- .../test/functions/einsum/EinsumTest.java | 12 +++--- src/test/scripts/functions/einsum/einsum11.R | 38 +++++++++++++++++++ .../scripts/functions/einsum/einsum11.dml | 33 ++++++++++++++++ 3 files changed, 78 insertions(+), 5 deletions(-) create mode 100644 src/test/scripts/functions/einsum/einsum11.R create mode 100644 src/test/scripts/functions/einsum/einsum11.dml diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index d45562385f8..51f77059836 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -21,7 +21,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.hops.OptimizerUtils; @@ -29,8 +28,6 @@ import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; -import org.junit.Assert; -import org.junit.Ignore; import org.junit.Test; import java.io.File; @@ -51,6 +48,7 @@ public class EinsumTest extends AutomatedTestBase private static final String TEST_EINSUM8 = TEST_NAME_EINSUM+"8"; private static final String TEST_EINSUM9 = TEST_NAME_EINSUM+"9"; private static final String TEST_EINSUM10 = TEST_NAME_EINSUM+"10"; + private static final String TEST_EINSUM11 = TEST_NAME_EINSUM+"11"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; @@ -62,7 +60,7 @@ public class EinsumTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=10; i++) + for(int i=1; i<=11; i++) addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } @Test @@ -100,9 +98,13 @@ public void testCodegenEinsum9CP() { testCodegenIntegration( TEST_EINSUM9, false, ExecType.CP ); } @Test - public void testCodegenEinsum91CP() { + public void testCodegenEinsum10CP() { testCodegenIntegration( TEST_EINSUM10, false, ExecType.CP ); } + @Test + public void testCodegenEinsum11CP() { + testCodegenIntegration( TEST_EINSUM11, false, ExecType.CP ); + } private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { diff --git a/src/test/scripts/functions/einsum/einsum11.R b/src/test/scripts/functions/einsum/einsum11.R new file mode 100644 index 00000000000..45c2a647799 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum11.R @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + + +A = matrix(seq(1,300), 5, 60, byrow=TRUE)/1000; +B = matrix(seq(1,150), 5, 30, byrow=TRUE)/1000; +C = matrix(seq(1,500), 5, 100, byrow=TRUE)/1000; +D = matrix(seq(1,1800), 60, 30, byrow=TRUE)/1000; +E = matrix(seq(1,6000), 100, 60, byrow=TRUE)/1000; +F = matrix(seq(1,3000), 100, 30, byrow=TRUE)/1000; + +R = einsum("fx,fg,fz,xg,zx,zg->g",A,B,C,D,E,F) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum11.dml b/src/test/scripts/functions/einsum/einsum11.dml new file mode 100644 index 00000000000..7c2b3e0e9f2 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum11.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = matrix(seq(1,300), 5, 60)/1000; +B = matrix(seq(1,150), 5, 30)/1000; +C = matrix(seq(1,500), 5, 100)/1000; +D = matrix(seq(1,1800), 60, 30)/1000; +E = matrix(seq(1,6000), 100, 60)/1000; +F = matrix(seq(1,3000), 100, 30)/1000; + +while(FALSE){} + +R = einsum("fx,fg,fz,xg,zx,zg->g",A,B,C,D,E,F) + +write(R, $1) From 109480503e802d56fe14ff63f035c3b3252c05fc Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Thu, 19 Jun 2025 00:49:11 +0200 Subject: [PATCH 10/28] better plan generation for einsum, handle more cases, added test that does matrix muls and elemwise muls with sums --- .../sysds/hops/codegen/cplan/CNodeCell.java | 2 +- .../instructions/cp/EinsumCPInstruction.java | 488 +++++++++++------- .../test/functions/einsum/EinsumTest.java | 8 +- src/test/scripts/functions/einsum/einsum12.R | 35 ++ .../scripts/functions/einsum/einsum12.dml | 31 ++ 5 files changed, 381 insertions(+), 183 deletions(-) create mode 100644 src/test/scripts/functions/einsum/einsum12.R create mode 100644 src/test/scripts/functions/einsum/einsum12.dml diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java index 3d2c19ef4c8..2482ac77e22 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java @@ -31,7 +31,7 @@ public class CNodeCell extends CNodeTpl { - protected static final String JAVA_TEMPLATE = + public static final String JAVA_TEMPLATE = "package codegen;\n" + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" + "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n" diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index d4c94b39b02..6d22a1ffd5b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -23,6 +23,8 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.hops.LiteralOp; import org.apache.sysds.hops.OptimizerUtils; @@ -42,14 +44,13 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; -import org.checkerframework.checker.units.qual.A; import java.util.*; import static org.apache.sysds.runtime.instructions.cp.EinsumContext.getEinsumContext; public class EinsumCPInstruction extends BuiltinNaryCPInstruction { - public static boolean forceCell = false; + public static boolean FORCE_CELL_TPL = false; protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; private final Class _class; @@ -65,6 +66,7 @@ public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand ou _numThreads = OptimizerUtils.getConstrainedNumThreads(-1); _in = inputs; this.eqStr = inputs[0].getName(); + Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); } @@ -80,6 +82,7 @@ public Class getOperatorClass() { private static final int CONTRACT_RIGHT = 2; private static final int CONTRACT_BOTH = 3; + private EinsumContext einc = null; @Override public void processInstruction(ExecutionContext ec) { @@ -87,9 +90,6 @@ public void processInstruction(ExecutionContext ec) { //get input matrices and scalars, incl pinning of matrices ArrayList inputs = new ArrayList<>(); ArrayList inputsNames = new ArrayList<>(); - ArrayList scalars = new ArrayList<>(); - if( LOG.isDebugEnabled() ) - LOG.debug("executing spoof instruction " + _op); for (CPOperand input : _in) { if(input.getDataType()==DataType.MATRIX){ MatrixBlock mb = ec.getMatrixInput(input.getName()); @@ -100,20 +100,16 @@ public void processInstruction(ExecutionContext ec) { inputs.add(mb); inputsNames.add(input.getName()); } -// else if(input.getDataType()==DataType.SCALAR) { -// //note: even if literal, it might be compiled as scalar placeholder -// scalars.add(ec.getScalarInput(input)); -// } } EinsumContext einc = getEinsumContext(eqStr,inputs); - + this.einc = einc; + //todo not throwing err when output char is not in input String[] parts = einc.equationString.split("->"); // ArrayList inputsChars = new ArrayList<>(Arrays.asList(parts[0].split(","))); - System.out.println("outrows:"+einc.outRows); - System.out.println("outcols:"+einc.outCols); + if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols); Character outChar1 = null; Character outChar2 = null; @@ -124,7 +120,6 @@ public void processInstruction(ExecutionContext ec) { }else if (parts[1].length()==1){ outChar1 = parts[1].charAt(0); } - HashMap partsCharactersCounter = new HashMap<>(); HashMap> partsCharactersToIndices = new HashMap<>(); ArrayList newEquationStringSplit = new ArrayList(); @@ -138,10 +133,6 @@ public void processInstruction(ExecutionContext ec) { } String s=""; if(!einc.contractDimsSet.contains(c)){ -// if(partsCharactersCounter.containsKey(c)) -// partsCharactersCounter.put(c, partsCharactersCounter.get(c)+1); -// else partsCharactersCounter.put(c, 1); -// s+=c; if(!partsCharactersToIndices.containsKey(c)) partsCharactersToIndices.put(c, new ArrayList<>()); @@ -176,7 +167,7 @@ else if(!einc.contractDimsSet.contains(c2)){ newEquationStringSplit.add(s); } ArrayList inputsChars = newEquationStringSplit; - System.out.println(String.join(",",newEquationStringSplit)); + LOG.trace(String.join(",",newEquationStringSplit)); //todo move to separate op earlier: for(int i=0;i a = partsCharactersToIndices.get(c); + if(LOG.isTraceEnabled()) + for(Character c :partsCharactersToIndices.keySet()){ + ArrayList a = partsCharactersToIndices.get(c); + LOG.trace(c+" count= "+a.size()); + } + + + + Double scalar = null; + boolean anyCouldNotDo = FORCE_CELL_TPL ? true : generatePlanAndExecute(partsCharactersToIndices, inputs, inputsChars, outChar1, outChar2); // information to do cell tpl for remaining ones - System.out.println(c+" count= "+a.size()); + if (!anyCouldNotDo){ + //check if any operations to do that were not-output dimension summations: + List remStrings = inputsChars.stream() + .filter(Objects::nonNull).toList(); + List remMbs = inputs.stream() + .filter(Objects::nonNull).toList(); + MatrixBlock res; + if(remStrings.size() == 1){ + String s = remStrings.get(0); + if(s.equals(parts[1])){ + res=remMbs.get(0); + }else if(s.charAt(0)==s.charAt(1)) { + // diagonal needed + ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); + res= remMbs.get(0).reorgOperations(op, new MatrixBlock(),0,0,0); + }else{ + //it has to be transpose: ab->ba + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + res = remMbs.get(0).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + } + }else{ + throw new RuntimeException("did not expect this!"); + } + ec.setMatrixOutput(output.getName(), res); } + else { + ArrayList mbs = new ArrayList<>(); + ArrayList chars = new ArrayList<>(); + for (int i = 0; i < inputs.size(); i++) { + MatrixBlock mb = inputs.get(i); + if (mb != null) { + mbs.add(mb); + chars.add(inputsChars.get(i)); + } + } + if(chars.size()==1 && chars.get(0).equals(parts[1])){ // maybe result is correct after all... (need to improve logic if cell is needed but for now just check if maybe all is OK) + ec.setMatrixOutput(output.getName(), mbs.get(0)); + }else { + ArrayList summingChars = new ArrayList(); + for (Character c : partsCharactersToIndices.keySet()) { + if (c != outChar1 && c != outChar2) summingChars.add(c); + } + + MatrixBlock res = computeCellSummation(mbs, chars, parts[1], einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); + + if (einc.outRows == 1 && einc.outCols == 1) + ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); + else ec.setMatrixOutput(output.getName(), res); + } + } + //final operation + + + // release input matrices + for (CPOperand input : _in) + if(input.getDataType()==DataType.MATRIX) + ec.releaseMatrixInput(input.getName()); + } + private boolean generatePlanAndExecute(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2) { // compute scalars: Double scalar = null; for(int i=0;i< inputs.size(); i++){ @@ -237,24 +293,105 @@ else if(!einc.contractDimsSet.contains(c2)){ } } + boolean anyCouldNotDo; + boolean didAnything = false; + do { + anyCouldNotDo = sumCharactersWherePossible(partsCharactersToIndices, inputs, inputsChars, outChar1, outChar2); + didAnything = false; + if(inputsChars.stream().filter(Objects::nonNull).count() > 1) + didAnything = multiplyTerms(partsCharactersToIndices, inputs, inputsChars, outChar1, outChar2); + } + while(didAnything); + + return anyCouldNotDo; + } + + /* handle situation: ji,ji or ij,ji */ + private boolean multiplyTerms(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2 ) { + ArrayList multiplyIdxs = new ArrayList<>(); + ArrayList transposeMultiplyIdxs = new ArrayList<>(); + + HashMap> stringToIndex = new HashMap<>(); + + for(int i = 0; i < inputsChars.size(); i++){ + String s = inputsChars.get(i); + if(s==null) continue; +// if(s.length() != 2) continue; + + if (stringToIndex.containsKey(s)) stringToIndex.get(s).add(i); + else { ArrayList list = new ArrayList<>(); list.add(i); stringToIndex.put(s, list); } + } + + boolean doneAnything = false; + + for(var s : stringToIndex.keySet()){ + if(!stringToIndex.containsKey(s)) continue; // entries can be removed - boolean anyCouldNotDo = true; // information to do cell tpl for remaining ones + String sT = s.length() == 2 ? String.valueOf(s.charAt(1)) + s.charAt(0) : null; + ArrayList idxs = stringToIndex.get(s); + ArrayList idxsT = sT != null ? stringToIndex.containsKey(sT) ? stringToIndex.get(sT) : null : null; + + if(idxs.size() <= 1 && idxsT == null) continue; + + doneAnything = true; + + // do decision if transpose idxs or idxsT: right now just alway transpose second + ArrayList mbs = new ArrayList<>(); + if(LOG.isTraceEnabled()){ + StringBuilder sb = new StringBuilder(); + for(Integer idx : idxs){ + sb.append(inputsChars.get(idx)); + sb.append(","); + } + if(idxsT != null) + for(Integer idx : idxsT){ + sb.append(inputsChars.get(idx)); + sb.append(","); + } + LOG.trace("Element wise multiplying: "+sb.toString()); + } + for(Integer idx : idxs){ + mbs.add(inputs.get(idx)); + inputs.set(idx, null); + inputsChars.set(idx, null); + } + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + if(idxsT != null) + for(Integer idx : idxsT){ + mbs.add(inputs.get(idx).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0)); + inputs.set(idx, null); + inputsChars.set(idx, null); + } + + MatrixBlock mb = getCodegenElemwiseMult(mbs); + + inputs.add(mb); + inputsChars.add(s); + + for (int i = 0; i < s.length(); i++) { // for each char in string, add pointer to newly created entry + char c = s.charAt(i); + partsCharactersToIndices.get(c).add(inputs.size() - 1); + } + + if(idxsT != null) + stringToIndex.remove(sT); + } + + return doneAnything; + } + + private boolean sumCharactersWherePossible(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2) { + boolean anyCouldNotDo = false; - if(!forceCell) while (true) { List toSum = null; Character sumC = null; - anyCouldNotDo = false; - Character cInOut = null; - for (Character c : partsCharactersToIndices.keySet()) { // sum on dim at the time + for (Character c : partsCharactersToIndices.keySet()) { // sum one dim at the time if (c == outChar1 || c == outChar2) continue; toSum = new ArrayList<>(); - int count =0; //very temp solution, this part will be replaced anyway later probably - for(Integer idx : partsCharactersToIndices.get(c).stream() - .filter(Objects::nonNull).toList()){ - if(inputs.get(idx) != null){ - count++; + for (Integer idx : partsCharactersToIndices.get(c).stream().filter(Objects::nonNull).toList()) { + if (inputs.get(idx) != null) { toSum.add(idx); } } @@ -267,30 +404,12 @@ else if(!einc.contractDimsSet.contains(c2)){ sumC = c; break; } - if (anyCouldNotDo) { - break; - } - if (sumC == null) { - //check if maybe there are out-put characters only terms like a,a,ab->ba - List remStrings = inputsChars.stream() - .filter(Objects::nonNull).toList(); - List remMbs = inputs.stream() - .filter(Objects::nonNull).toList(); - if(remStrings.size() > 1){ - Pair res = computRowSummationsOutputCharsOnly(remMbs, remStrings, parts[1],scalar); - scalar = null; - inputs = new ArrayList<>(Arrays.asList(res.getLeft())); - inputsChars = new ArrayList<>(Arrays.asList(res.getRight())); - } - break; //nothing left to sum - } - Pair res = computeRowSummation(toSum, inputs, inputsChars, scalar); - scalar = null; + if(sumC == null) break; + + Pair res = computeRowSummation(toSum, inputs, inputsChars, sumC); String newS = res.getRight(); -// var iter = toSum.listIterator(); -// Integer ii = iter.next(); for (Integer idx : toSum) { inputs.set(idx, null); inputsChars.set(idx, null); @@ -298,104 +417,24 @@ else if(!einc.contractDimsSet.contains(c2)){ inputs.add(res.getLeft()); inputsChars.add(newS); - for (int i = 0; i < newS.length(); i++) { + for (int i = 0; i < newS.length(); i++) { // for each char in string, add pointer to newly created entry char c = newS.charAt(i); -// partsCharactersToIndices.get(c).remove(c); partsCharactersToIndices.get(c).add(inputs.size() - 1); } - - -// for(int i=0;i remStrings = inputsChars.stream() - .filter(Objects::nonNull).toList(); - List remMbs = inputs.stream() - .filter(Objects::nonNull).toList(); - MatrixBlock res; - if(remStrings.size() == 1){ - String s = remStrings.get(0); - if(s.equals(parts[1])){ - res=remMbs.get(0); - }else if(s.charAt(0)==s.charAt(1)) { - // diagonal needed - ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); - res= remMbs.get(0).reorgOperations(op, new MatrixBlock(),0,0,0); - }else{ - //it has to be transpose: ab->ba - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);//todo move to separate op earlier - res = remMbs.get(0).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - } - }else{ - throw new RuntimeException("did not expect this!"); - } - ec.setMatrixOutput(output.getName(), res); - } - - else { - ArrayList mbs = new ArrayList<>(); - ArrayList chars = new ArrayList<>(); - for (int i = 0; i < inputs.size(); i++) { - MatrixBlock mb = inputs.get(i); - if (mb != null) { - mbs.add(mb); - chars.add(inputsChars.get(i)); - } - } -// HashSet summingChars = new HashSet<>(); -// for(String s : inputsChars){ -// if(s == null) continue; -// if(s.length() == 1) summingChars.add(s.charAt(0)); -// if(s.length() == 2) { -// summingChars.add(s.charAt(0)); -// summingChars.add(s.charAt(1)); -// } -// } - ArrayList summingChars = new ArrayList(); - for (Character c : partsCharactersToIndices.keySet()) { - if (c != outChar1 && c != outChar2) summingChars.add(c); - - } - //computeCellSummation(ArrayList inputs, List inputsChars, String resultString, - // HashMap charToDimensionSizeInt, List summingChars) - MatrixBlock res = computeCellSummation(mbs, chars, parts[1], einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); - - if (einc.outRows == 1 && einc.outCols == 1) - ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); - else ec.setMatrixOutput(output.getName(), res); - - } - - //final operation - - - // release input matrices - for (CPOperand input : _in) - if(input.getDataType()==DataType.MATRIX) - ec.releaseMatrixInput(input.getName()); + return anyCouldNotDo; } private enum SumOperation { aB_a, Ba_a, - Ba_aC, // mmult + Ba_aC, // mmult -> BC // aB_Ca, Ba_Ca, aB_aC, // outer mult a_a, + aB_aB, Ba_Ba, Ba_aB, aB_Ba,// mult and sums, something like ij,ij->i } private enum AggregateAtEnd{ @@ -408,7 +447,7 @@ private Pair computRowSummationsOutputCharsOnly(List computRowSummationsOutputCharsOnly(List computeRowSummation(List toSum, ArrayList inputs, List inputsChars) { - return computeRowSummation(toSum,inputs,inputsChars, null); + private Pair computeRowSummation(List toSum, ArrayList inputs, List inputsChars, Character sumChar) { + return computeRowSummation(toSum,inputs,inputsChars, null, sumChar); } - private Pair computeRowSummation(List toSum, ArrayList inputs, List inputsChars, Double scalar) { + private Pair computeRowSummation(List toSum, ArrayList inputs, List inputsChars, Double scalar, Character sumChar) { if(toSum.size() != 2){ return null; @@ -455,7 +494,6 @@ private Pair computeRowSummation(List toSum, Array String s1 = inputsChars.get(toSum.get(0)); String s2 = inputsChars.get(toSum.get(1)); - MatrixBlock first = null; MatrixBlock second = null; @@ -486,8 +524,31 @@ else if(s2.length() == 1 || s1.length() == 1){ sumOp = SumOperation.Ba_a; resS = String.valueOf(s1.charAt(0)); } - } - else if(s1.charAt(0) == s2.charAt(0)){ + } else if (s1.equals(s2)) { + if(s1.charAt(0) == sumChar){ + sumOp = SumOperation.aB_aB; + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); + resS = String.valueOf(s1.charAt(1)); + }else{ + sumOp = SumOperation.Ba_Ba; + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); + resS = String.valueOf(s1.charAt(0)); + } + }else if (s1.charAt(0) == s2.charAt(1) && s1.charAt(1) == s2.charAt(0)) { + if(s1.charAt(0) == sumChar){ + sumOp = SumOperation.aB_Ba; + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); + resS = String.valueOf(s1.charAt(1)); + }else{ + sumOp = SumOperation.Ba_aB; + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); + resS = String.valueOf(s1.charAt(0)); + } + } else if(s1.charAt(0) == s2.charAt(0)){ sumOp = SumOperation.aB_aC; first = inputs.get(toSum.get(0)); second = inputs.get(toSum.get(1)); @@ -522,35 +583,86 @@ else if(s1.charAt(1) == s2.charAt(0)){ } MatrixBlock out; + if(LOG.isTraceEnabled()) LOG.trace("Summing: "+s1+","+s2+"->"+resS); switch (sumOp) { case Ba_a: throw new NotImplementedException(); case Ba_aC: { - out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf( second.getNumColumns()), scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf( second.getNumColumns()), scalar); break; } - case Ba_Ca: + case Ba_Ca: { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf( second.getNumColumns()), scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf(second.getNumColumns()), scalar); break; + } case aB_a: case aB_aC: { - out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_OUTERMULT_ADD, SpoofRowwise.RowType.COL_AGG_B1_T, Long.valueOf( second.getNumColumns()),scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_OUTERMULT_ADD, SpoofRowwise.RowType.COL_AGG_B1_T, Long.valueOf( second.getNumColumns()),scalar); break; } case a_a: - out = getCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG,null, scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG,null, scalar); + break; + case aB_aB: { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + first = first.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.COL_AGG, null, scalar); break; + } + case Ba_Ba: { + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG, null, scalar); + break; + } + case aB_Ba: { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + first = first.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.ROW_AGG,null, scalar); + break; + } + case Ba_aB: { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG,null, scalar); + break; + } + default: throw new IllegalStateException("Unexpected value: " + sumOp); } return Pair.of(out , resS); } - private MatrixBlock getCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType){ - return getCodegenMatrixBlock(first, second, binaryType,rowType,null, null); + private MatrixBlock getRowCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType){ + return getRowCodegenMatrixBlock(first, second, binaryType,rowType,null, null); + } + private MatrixBlock getCodegenElemwiseMult(ArrayList mbs) { + + ArrayList cnodeIn = new ArrayList<>(); + for(int i=0;i scalars = new ArrayList<>(); + MatrixBlock out = op.execute(mbs, scalars, mb, _numThreads); + return out; } - private MatrixBlock getCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType, Long secondDim, Double scalar) { + private MatrixBlock getRowCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType, Long secondDim, Double scalar) { ArrayList thisInputs = new ArrayList<>(Arrays.asList(first, second)); ArrayList cnodeIn = new ArrayList<>(); @@ -568,42 +680,38 @@ private MatrixBlock getCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, cnode.renameInputs(); String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); - - System.out.println(src); + if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); SpoofOperator op = CodegenUtils.createInstance(cla); MatrixBlock mb = new MatrixBlock(); -// mb.reset(einc.outRows, einc.outCols, false); -// mb.allocateDenseBlock(); ArrayList scalars = new ArrayList<>(); if(scalar != null) scalars.add(new DoubleObject(scalar)); MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); return out; } + private static void indent(StringBuilder sb, int level) { + for (int i = 0; i < level; i++) { + sb.append(" "); + } + } private MatrixBlock computeCellSummation(ArrayList inputs, List inputsChars, String resultString, HashMap charToDimensionSizeInt, List summingChars, int outRows, int outCols){ ArrayList cnodeIn = new ArrayList<>(); cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); CNodeCell cnode = new CNodeCell(cnodeIn, null); -// cnode.setCellType(SpoofCellwise.CellType.NO_AGG); StringBuilder sb = new StringBuilder(); -// if (resultString.length() == 2) { -// summingChars.remove(resultString.charAt(0)); -// summingChars.remove(resultString.charAt(1)); -// -// } else if (resultString.length() == 1) { -// summingChars.remove(resultString.charAt(0)); -// -// } + int indent = 2; + indent(sb, indent); + boolean needsSumming = summingChars.stream().anyMatch(x -> x != null); - String itVar0 = "TMP123";//+new IDSequence().getNextID(); todo: generate this var - String outVar = null; + ; + String itVar0 = cnode.createVarname(); + String outVar = itVar0; if (needsSumming) { - outVar = "TMP123";//+new IDSequence().getNextID(); sb.append("double "); sb.append(outVar); sb.append("=0;\n"); @@ -612,7 +720,8 @@ private MatrixBlock computeCellSummation(ArrayList inputs, List summedCharacters = new HashSet<>(); Iterator hsIt = summingChars.iterator(); while (hsIt.hasNext()) { - + indent(sb, indent); + indent++; Character c = hsIt.next(); String itVar = itVar0 + c; sb.append("for(int "); @@ -625,6 +734,7 @@ private MatrixBlock computeCellSummation(ArrayList inputs, List(), new MatrixBlock(), _numThreads); - if (outRows == 1 && outCols == 1) { -// ec.setScalarOutput(output.getName(), new DoubleObject(out.get(0, 0))); - return out; - } else { -// ec.setMatrixOutput(output.getName(), out); - return out; - } + return out; } @@ -747,4 +857,22 @@ public CPOperand[] getInputs() { " return %OUT%;\n" + " }\n" + "}"; + public static final String JAVA_TEMPLATE = + "package codegen;\n" + + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" + + "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n" + + "import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\n" + + "import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;\n" + + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" + + "import org.apache.commons.math3.util.FastMath;\n" + + "\n" + + "public final class %TMP% extends SpoofCellwise {\n" + + " public %TMP%() {\n" + + " super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, %AGG_OP_NAME%);\n" + + " }\n" + + " protected double genexec(double a, SideInput[] b, double[] scalars, int m, int n, long grix, int rix, int cix) { \n" + + "%BODY_dense%" + + " return %OUT%;\n" + + " }\n" + + "}\n"; } diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 51f77059836..24b0fc2b14b 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -49,6 +49,7 @@ public class EinsumTest extends AutomatedTestBase private static final String TEST_EINSUM9 = TEST_NAME_EINSUM+"9"; private static final String TEST_EINSUM10 = TEST_NAME_EINSUM+"10"; private static final String TEST_EINSUM11 = TEST_NAME_EINSUM+"11"; + private static final String TEST_EINSUM12 = TEST_NAME_EINSUM+"12"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; @@ -60,7 +61,7 @@ public class EinsumTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=11; i++) + for(int i=1; i<=12; i++) addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } @Test @@ -105,7 +106,10 @@ public void testCodegenEinsum10CP() { public void testCodegenEinsum11CP() { testCodegenIntegration( TEST_EINSUM11, false, ExecType.CP ); } - + @Test + public void testCodegenEinsum12CP() { + testCodegenIntegration( TEST_EINSUM12, false, ExecType.CP ); + } private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; diff --git a/src/test/scripts/functions/einsum/einsum12.R b/src/test/scripts/functions/einsum/einsum12.R new file mode 100644 index 00000000000..b2abb5bf642 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum12.R @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +A = matrix(seq(1,3000), 600, 5, byrow=TRUE)/1000; +B = matrix(seq(1,6000), 600, 10, byrow=TRUE)/1000; +C = matrix(seq(1,3600), 600, 6, byrow=TRUE)/1000; +D = matrix(seq(1,50), 5, 10, byrow=TRUE)/1000; + +R = einsum("fx,fg,fz,xg->z",A,B,C,D) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum12.dml b/src/test/scripts/functions/einsum/einsum12.dml new file mode 100644 index 00000000000..f8bbd275100 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum12.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = matrix(seq(1,3000), 600, 5)/1000; +B = matrix(seq(1,6000), 600, 10)/1000; +C = matrix(seq(1,3600), 600, 6)/1000; +D = matrix(seq(1,50), 5, 10)/1000; + +while(FALSE){} + +R = einsum("fx,fg,fz,xg->z",A,B,C,D) + +write(R, $1) From b46f23a23735ec3d19e0ef5f7a65940c0a416ea1 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Thu, 26 Jun 2025 13:55:36 +0200 Subject: [PATCH 11/28] support for scalar multiplication, better validation --- .../java/org/apache/sysds/hops/NaryOp.java | 19 ++-- .../parser/BuiltinFunctionExpression.java | 102 ++++++++++------- .../instructions/cp/EinsumCPInstruction.java | 105 +++++++++++------- .../test/functions/einsum/EinsumTest.java | 65 ++++++----- src/test/scripts/functions/einsum/einsum13.R | 33 ++++++ .../scripts/functions/einsum/einsum13.dml | 28 +++++ src/test/scripts/functions/einsum/einsum6.R | 2 +- src/test/scripts/functions/einsum/einsum6.dml | 4 +- 8 files changed, 236 insertions(+), 122 deletions(-) create mode 100644 src/test/scripts/functions/einsum/einsum13.R create mode 100644 src/test/scripts/functions/einsum/einsum13.dml diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java index 466ae6236b7..aff2572a7e8 100644 --- a/src/main/java/org/apache/sysds/hops/NaryOp.java +++ b/src/main/java/org/apache/sysds/hops/NaryOp.java @@ -236,20 +236,21 @@ public void refreshSizeInformation() { setDim2(1); break; case EINSUM: - String outStr = ((LiteralOp) _input.get(0)).getStringValue().split("->")[1]; - int count = 0; - for (int i = 0; i < outStr.length(); i++){ - if(outStr.charAt(i) != ' ') count++; - } - if(count==0) { + String eqString = ((LiteralOp) _input.get(0)).getStringValue(); + if (eqString.charAt(eqString.length()-1)=='>'){ setDataType(DataType.SCALAR); setDim1(0); setDim2(0); + break; } - else{ - setDim1( HopRewriteUtils.getMaxInputDim(this, true)); - setDim2(count==1 ? 1 : HopRewriteUtils.getMaxInputDim(this, false)); + String outStr = eqString.split("->")[1]; + int count = 0; + for (int i = 0; i < outStr.length(); i++){ + if(outStr.charAt(i) != ' ') count++; } + // not true: todo later - set correct out size + setDim1( HopRewriteUtils.getMaxInputDim(this, true)); + setDim2(count==1 ? 1 : HopRewriteUtils.getMaxInputDim(this, false)); break; case PRINTF: case EVAL: diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index c73ce9cdcdf..166fad4bc09 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -2111,11 +2111,12 @@ private void validateEinsum(DataIdentifier output){ LanguageErrorCodes.INVALID_PARAMETERS); String eq_string = ((StringIdentifier)getFirstExpr()).getValue(); - String[] eqStringParts = eq_string.split("->"); - if(eqStringParts.length != 2) - raiseValidateError("Einsum: equation str should contain one '->' substring", false, - LanguageErrorCodes.INVALID_PARAMETERS); + if (eq_string.length() == 0) raiseValidateError("Einsum: equation str too short", false, LanguageErrorCodes.INVALID_PARAMETERS); + if (eq_string.charAt(0) == '-' || eq_string.charAt(0) == ',') raiseValidateError("Einsum: equation str invalid", false, LanguageErrorCodes.INVALID_PARAMETERS); + + String[] eqStringParts = eq_string.split("->"); // length 2 if "...->..." , length 1 if "...->" + boolean isResultScalar = eqStringParts.length == 1; Expression[] expressions = getAllExpr(); boolean allDimsKnown = true; @@ -2130,6 +2131,8 @@ private void validateEinsum(DataIdentifier output){ matrixBlocks.add((expressions[i].getOutput())); } + StringBuilder newEqString = new StringBuilder(); + if(allDimsKnown) { // validate dimension sizes as well HashMap charToDimensionSize = new HashMap<>(); Iterator it = matrixBlocks.iterator(); @@ -2138,6 +2141,8 @@ private void validateEinsum(DataIdentifier output){ int numberOfMatrices = 1; for (int i = 0; i < eqStringParts[0].length(); i++) { char c = eq_string.charAt(i); + if(c==' ') continue; + newEqString.append(c); if(c==','){ if(!it.hasNext()) raiseValidateError("Einsum: Provided less operands than specified in equation str", @@ -2145,10 +2150,7 @@ private void validateEinsum(DataIdentifier output){ currArr = it.next(); arrSizeIterator = 0; numberOfMatrices++; - }else if(c==' '){ - continue; - } - else{ + } else{ long thisCharDimension = arrSizeIterator == 0 ? currArr.getDim1() : currArr.getDim2(); if (charToDimensionSize.containsKey(c)){ if (charToDimensionSize.get(c) != thisCharDimension) @@ -2163,60 +2165,76 @@ private void validateEinsum(DataIdentifier output){ if (getAllExpr().length - 1 > numberOfMatrices) raiseValidateError("Einsum: Provided more operands than specified in equation str", false, LanguageErrorCodes.INVALID_PARAMETERS); + newEqString.append("->"); - int numberOfDimensions = 0; - long dim1 = 1; - long dim2 = 1; - for (int i = 0; i < eqStringParts[1].length(); i++) { - char c = eqStringParts[i].charAt(i); - if(c!=' '){ - if(numberOfDimensions == 0){ + if (isResultScalar){ + output.setDataType(DataType.SCALAR); + output.setDimensions(-1, -1); + }else { + int numberOfOutDimensions = 0; + Character dim1Char = null; + long dim1 = 1; + long dim2 = 1; + for (int i = 0; i < eqStringParts[1].length(); i++) { + char c = eqStringParts[1].charAt(i); + if (c == ' ') continue; + newEqString.append(c); + if (numberOfOutDimensions == 0) { + dim1Char = c; dim1 = charToDimensionSize.get(c); - }else{ + } else { + if(c==dim1Char) raiseValidateError("Einsum: output character "+c+" provided multiple times",false, LanguageErrorCodes.INVALID_PARAMETERS); dim2 = charToDimensionSize.get(c); } - numberOfDimensions++; + numberOfOutDimensions++; + } + if (numberOfOutDimensions > 2) { + raiseValidateError("Einsum: output matrices with with no. dims > 2 not supported",false, LanguageErrorCodes.INVALID_PARAMETERS); + } else { + output.setDataType(DataType.MATRIX); + output.setDimensions(dim1, dim2); } - } - if(numberOfDimensions == 0){ - output.setDataType(DataType.SCALAR); - output.setDimensions(-1, -1); - }else if(numberOfDimensions > 2){ - raiseValidateError("Einsum: output matrices with with no. dims > 2 not supported", - false, LanguageErrorCodes.INVALID_PARAMETERS); - }else { - output.setDataType(DataType.MATRIX); - output.setDimensions(dim1, dim2); } } else { // dimensions unknown int numberOfMatrices = 1; for (int i = 0; i < eqStringParts[0].length(); i++) { - if(eqStringParts[0].charAt(i) == ',') + char c = eqStringParts[0].charAt(i); + if(c == ' ') continue; + newEqString.append(c); + if(c == ',') numberOfMatrices++; } checkNumParameters(numberOfMatrices+1); + newEqString.append("->"); - int numberOfDimensions = 0; - for (int i = 0; i < eqStringParts[1].length(); i++) { - char c = eqStringParts[i].charAt(i); - if(c!=' '){ + if(isResultScalar){ + output.setDataType(DataType.SCALAR); + output.setDimensions(-1, -1); + }else { + int numberOfDimensions = 0; + Character dim1Char = null; + for (int i = 0; i < eqStringParts[1].length(); i++) { + char c = eqStringParts[i].charAt(i); + if(c == ' ') continue; + newEqString.append(c); numberOfDimensions++; + if (numberOfDimensions == 1 && c == dim1Char) + raiseValidateError("Einsum: output character "+c+" provided multiple times",false, LanguageErrorCodes.INVALID_PARAMETERS); + dim1Char = c; } - } - if(numberOfDimensions==0){ - output.setDataType(DataType.SCALAR); - output.setDimensions(-1, -1); - }else if(numberOfDimensions>2){ - raiseValidateError("Einsum: output matrices with with no. dims > 2 not supported", - false, LanguageErrorCodes.INVALID_PARAMETERS); - }else{ - output.setDataType(DataType.MATRIX); - output.setDimensions(-1, -1); + if (numberOfDimensions > 2) { + raiseValidateError("Einsum: output matrices with with no. dims > 2 not supported", + false, LanguageErrorCodes.INVALID_PARAMETERS); + } else { + output.setDataType(DataType.MATRIX); + output.setDimensions(-1, -1); + } } } output.setValueType(ValueType.FP64); output.setBlocksize(getSecondExpr().getOutput().getBlocksize()); + ((StringIdentifier) getFirstExpr()).setValue(newEqString.toString()); } private void setBinaryOutputProperties(DataIdentifier output) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 6d22a1ffd5b..c63fca5e184 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -86,35 +86,30 @@ public Class getOperatorClass() { @Override public void processInstruction(ExecutionContext ec) { - //get input matrices and scalars, incl pinning of matrices ArrayList inputs = new ArrayList<>(); - ArrayList inputsNames = new ArrayList<>(); for (CPOperand input : _in) { if(input.getDataType()==DataType.MATRIX){ MatrixBlock mb = ec.getMatrixInput(input.getName()); - //FIXME fused codegen operators already support compressed main inputs if(mb instanceof CompressedMatrixBlock){ mb = ((CompressedMatrixBlock) mb).getUncompressed("Spoof instruction"); } inputs.add(mb); - inputsNames.add(input.getName()); } } EinsumContext einc = getEinsumContext(eqStr,inputs); this.einc = einc; - //todo not throwing err when output char is not in input String[] parts = einc.equationString.split("->"); -// ArrayList inputsChars = new ArrayList<>(Arrays.asList(parts[0].split(","))); if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols); Character outChar1 = null; Character outChar2 = null; - if(parts[1].length()>=2){ + if(parts.length == 1){ } + else if(parts[1].length() >= 2){ outChar1 = parts[1].charAt(0); outChar2 = parts[1].charAt(1); }else if (parts[1].length()==1){ @@ -146,14 +141,9 @@ public void processInstruction(ExecutionContext ec) { arrCounter++; } else if(!einc.contractDimsSet.contains(c2)){ - if (c2==c ){ + if (c2 == c){ diagMatrices.add(arrCounter); } - -// if(partsCharactersCounter.containsKey(c2)) -// partsCharactersCounter.put(c2, partsCharactersCounter.get(c2)+1); -// else partsCharactersCounter.put(c2, 1); - if(!partsCharactersToIndices.containsKey(c2)) partsCharactersToIndices.put(c2, new ArrayList<>()); @@ -162,17 +152,14 @@ else if(!einc.contractDimsSet.contains(c2)){ } i++; - } newEquationStringSplit.add(s); } ArrayList inputsChars = newEquationStringSplit; LOG.trace(String.join(",",newEquationStringSplit)); - //todo move to separate op earlier: for(int i=0;i a = partsCharactersToIndices.get(c); + LOG.trace(c+" count= "+a.size()); + } - if(LOG.isTraceEnabled()) - for(Character c :partsCharactersToIndices.keySet()){ - ArrayList a = partsCharactersToIndices.get(c); - LOG.trace(c+" count= "+a.size()); + // compute scalar by suming-all matrices: + Double scalar = null; + for(int i=0;i< inputs.size(); i++){ + String s = inputsChars.get(i); + if(s.equals("")){ + MatrixBlock mb = inputs.get(i); + if (scalar == null) scalar = mb.get(0,0); + else scalar*= mb.get(0,0); + inputs.set(i,null); + inputsChars.set(i,null); } + } + if (scalar != null) { + boolean appliedToSomeMatrix = false; + for(int i = 0; i < inputs.size(); i++){ + if(inputs.get(i) != null){ + inputs.set(i, getScalarMultiplyMatrixBlock(inputs.get(i), scalar)); + appliedToSomeMatrix = true; break; + } + } + if(!appliedToSomeMatrix){ + ec.setScalarOutput(output.getName(), new DoubleObject(scalar)); + releaseMatrixInputs(ec); + return; + } + } - - Double scalar = null; boolean anyCouldNotDo = FORCE_CELL_TPL ? true : generatePlanAndExecute(partsCharactersToIndices, inputs, inputsChars, outChar1, outChar2); // information to do cell tpl for remaining ones if (!anyCouldNotDo){ @@ -270,29 +280,16 @@ else if(!einc.contractDimsSet.contains(c2)){ } } - //final operation - + releaseMatrixInputs(ec); + } - // release input matrices + private void releaseMatrixInputs(ExecutionContext ec){ for (CPOperand input : _in) if(input.getDataType()==DataType.MATRIX) ec.releaseMatrixInput(input.getName()); } private boolean generatePlanAndExecute(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2) { - // compute scalars: - Double scalar = null; - for(int i=0;i< inputs.size(); i++){ - String s = inputsChars.get(i); - if(s.equals("")){ - MatrixBlock mb = inputs.get(i); - if (scalar == null) scalar = mb.get(0,0); - else scalar*= mb.get(0,0); - inputs.set(i,null); - inputsChars.set(i,null); - } - - } boolean anyCouldNotDo; boolean didAnything = false; do { @@ -691,6 +688,34 @@ private MatrixBlock getRowCodegenMatrixBlock(MatrixBlock first, MatrixBlock seco MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); return out; } + + private MatrixBlock getScalarMultiplyMatrixBlock(MatrixBlock mbIn, Double scalar){ + ArrayList thisInputs = new ArrayList<>(Arrays.asList(mbIn)); + + ArrayList cnodeIn = new ArrayList<>(); + + CNode c1 = new CNodeData("c1", 1, mbIn.getNumRows(), mbIn.getNumColumns(), DataType.MATRIX); + CNode c2 = new CNodeData(new LiteralOp(scalar), 0, 0, DataType.SCALAR); + cnodeIn.add(c1); + cnodeIn.add(c2); + + CNode cnodeOut = new CNodeBinary(c1,c2, CNodeBinary.BinType.MULT); + CNodeCell cnode = new CNodeCell(cnodeIn, cnodeOut); + cnode.setCellType(SpoofCellwise.CellType.NO_AGG); + cnode.renameInputs(); + + String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); + if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); + + ArrayList scalars = new ArrayList<>(); + if(scalar != null) scalars.add(new DoubleObject(scalar)); + MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); + return out; + } private static void indent(StringBuilder sb, int level) { for (int i = 0; i < level; i++) { sb.append(" "); diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 24b0fc2b14b..9d533a3ca7a 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -50,92 +50,103 @@ public class EinsumTest extends AutomatedTestBase private static final String TEST_EINSUM10 = TEST_NAME_EINSUM+"10"; private static final String TEST_EINSUM11 = TEST_NAME_EINSUM+"11"; private static final String TEST_EINSUM12 = TEST_NAME_EINSUM+"12"; - + private static final String TEST_EINSUM13 = TEST_NAME_EINSUM+"13"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; private final static String TEST_CONF = "SystemDS-config-codegen.xml"; private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF); private static double eps = Math.pow(10, -10); - + @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=12; i++) + for(int i=1; i<=13; i++) addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } @Test public void testCodegenEinsum1CP() { - testCodegenIntegration( TEST_EINSUM1, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM1, ExecType.CP ); } @Test public void testCodegenEinsum2CP() { - testCodegenIntegration( TEST_EINSUM2, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM2, ExecType.CP ); } @Test public void testCodegenEinsum3CP() { - testCodegenIntegration( TEST_EINSUM3, false, ExecType.CP ); -} + testCodegenIntegration( TEST_EINSUM3, ExecType.CP ); + } @Test public void testCodegenEinsum4CP() { - testCodegenIntegration( TEST_EINSUM4, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM4, ExecType.CP ); } @Test public void testCodegenEinsum5CP() { - testCodegenIntegration( TEST_EINSUM5, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM5, ExecType.CP ); } @Test public void testCodegenEinsum6CP() { - testCodegenIntegration( TEST_EINSUM6, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM6, ExecType.CP ); } @Test public void testCodegenEinsum7CP() { - testCodegenIntegration( TEST_EINSUM7, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM7, ExecType.CP ); } @Test - public void testCodegenEinsum8CP() { testCodegenIntegration( TEST_EINSUM8, false, ExecType.CP ); } + public void testCodegenEinsum8CP() { testCodegenIntegration( TEST_EINSUM8, ExecType.CP ); } @Test public void testCodegenEinsum9CP() { - testCodegenIntegration( TEST_EINSUM9, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM9, ExecType.CP ); } @Test public void testCodegenEinsum10CP() { - testCodegenIntegration( TEST_EINSUM10, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM10, ExecType.CP ); } @Test public void testCodegenEinsum11CP() { - testCodegenIntegration( TEST_EINSUM11, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM11, ExecType.CP ); } @Test public void testCodegenEinsum12CP() { - testCodegenIntegration( TEST_EINSUM12, false, ExecType.CP ); + testCodegenIntegration( TEST_EINSUM12, ExecType.CP ); + } + @Test + public void testCodegenEinsum13CP() { + testCodegenIntegration( TEST_EINSUM13, ExecType.CP, true ); } - private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) + private void testCodegenIntegration( String testname, ExecType instType) { testCodegenIntegration(testname, instType, false); } + private void testCodegenIntegration( String testname, ExecType instType, boolean outputScalar ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; ExecMode platformOld = setExecMode(instType); - + try { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); - + String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; programArgs = new String[]{"-stats", "-explain", "-args", output("S") }; - + fullRScriptName = HOME + testname + ".R"; rCmd = getRCmd(inputDir(), expectedDir()); - OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; - + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false; + runTest(true, false, null, -1); runRScript(true); - - //compare matrices - HashMap dmlfile = readDMLMatrixFromOutputDir("S"); - HashMap rfile = readRMatrixFromExpectedDir("S"); - TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(outputScalar){ + HashMap dmlfile = readDMLScalarFromExpectedDir("S"); + HashMap rfile = readRScalarFromExpectedDir("S"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + }else { + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("S"); + HashMap rfile = readRMatrixFromExpectedDir("S"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + } } finally { resetExecMode(platformOld); diff --git a/src/test/scripts/functions/einsum/einsum13.R b/src/test/scripts/functions/einsum/einsum13.R new file mode 100644 index 00000000000..e71b18a5bc9 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum13.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); +P = matrix(seq(1,30), 6, 5, byrow=TRUE); + +R = einsum("ab,cd->",P,X) + +write(R, paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum13.dml b/src/test/scripts/functions/einsum/einsum13.dml new file mode 100644 index 00000000000..337a5cc5573 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum13.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,6000), 600, 10); +P = matrix(seq(1,30), 6, 5) + +while(FALSE){} + +R = einsum("ab,cd->",P,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum6.R b/src/test/scripts/functions/einsum/einsum6.R index 0aa46d92a64..7d0b7eef07e 100644 --- a/src/test/scripts/functions/einsum/einsum6.R +++ b/src/test/scripts/functions/einsum/einsum6.R @@ -29,6 +29,6 @@ X = matrix(seq(1,6000), 600, 10, byrow=TRUE); P = matrix(seq(1,30), 6, 5, byrow=TRUE); # R = P * sum(X) -R = einsum("ab,cd->ab",P,X) +R = einsum("ab,cd->ba",P,X) writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum6.dml b/src/test/scripts/functions/einsum/einsum6.dml index 1385d0329eb..fae1e38b5e4 100644 --- a/src/test/scripts/functions/einsum/einsum6.dml +++ b/src/test/scripts/functions/einsum/einsum6.dml @@ -24,7 +24,5 @@ P = matrix(seq(1,30), 6, 5) while(FALSE){} -#R = colSums(t(P) %*% X) ; - -R = einsum("ab,cd->ab",P,X) +R = einsum("ab,cd->ba",P,X) write(R, $1) From d9d52eff18fd19aeab4cc3ea010c62a9ca5eea6f Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Thu, 26 Jun 2025 17:04:30 +0200 Subject: [PATCH 12/28] add test and removing unnecessary changes --- .../instructions/CPInstructionParser.java | 2 +- .../instructions/cp/EinsumCPInstruction.java | 25 ++++++++++---- .../test/functions/einsum/EinsumTest.java | 7 +++- src/test/scripts/functions/einsum/einsum14.R | 34 +++++++++++++++++++ .../scripts/functions/einsum/einsum14.dml | 28 +++++++++++++++ src/test/scripts/functions/einsum/einsum3.R | 2 +- src/test/scripts/functions/einsum/einsum4.R | 5 ++- src/test/scripts/functions/einsum/einsum4.dml | 5 ++- src/test/scripts/functions/einsum/einsum5.R | 4 +-- src/test/scripts/functions/einsum/einsum5.dml | 5 ++- src/test/scripts/functions/einsum/einsum6.R | 6 ++-- src/test/scripts/functions/einsum/einsum6.dml | 6 ++-- 12 files changed, 104 insertions(+), 25 deletions(-) create mode 100644 src/test/scripts/functions/einsum/einsum14.R create mode 100644 src/test/scripts/functions/einsum/einsum14.dml diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index 0f6c5e9f57e..ed4e4cfbbbc 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -64,7 +64,6 @@ import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction; import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; -import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; import org.apache.sysds.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction; public class CPInstructionParser extends InstructionParser { @@ -219,6 +218,7 @@ public static CPInstruction parseSingleInstruction ( InstructionType cptype, Str case EvictLineageCache: return EvictCPInstruction.parseInstruction(str); + default: throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype ); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index c63fca5e184..3f400fc091f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -39,6 +39,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.functionobjects.*; +import org.apache.sysds.runtime.matrix.data.LibMatrixAgg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; @@ -165,8 +166,12 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); - MatrixBlock newB = (MatrixBlock)inputs.get(i).aggregateUnaryOperations(aggun,new MatrixBlock(),inputs.get(i).getNumRows(),null); - inputs.set(i, newB); + MatrixBlock res = new MatrixBlock(); + res.setNumRows(1); + res.setNumColumns(1); + LibMatrixAgg.aggregateUnaryMatrix(inputs.get(i), res, aggun, _numThreads); +// MatrixBlock newB = (MatrixBlock)inputs.get(i).aggregateUnaryOperations(aggun,new MatrixBlock(),inputs.get(i).getNumRows(),null); + inputs.set(i, res); }else if(einc.contractDims[i] == CONTRACT_RIGHT){ //rowSums (remove 2nd dim) @@ -174,8 +179,12 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - MatrixBlock newB = (MatrixBlock)inputs.get(i).aggregateUnaryOperations(aggun,new MatrixBlock(),inputs.get(i).getNumRows(),null); - inputs.set(i, newB); + MatrixBlock res = new MatrixBlock(); + res.setNumRows(inputs.get(i).getNumRows()); + res.setNumColumns(1); + LibMatrixAgg.aggregateUnaryMatrix(inputs.get(i), res, aggun, _numThreads); +// MatrixBlock newB = (MatrixBlock)inputs.get(i).aggregateUnaryOperations(aggun,new MatrixBlock(),inputs.get(i).getNumRows(),null); + inputs.set(i, res); }else if(einc.contractDims[i] == CONTRACT_LEFT){ //colSums (remove 1st dim) @@ -183,8 +192,12 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); - MatrixBlock newB = (MatrixBlock)inputs.get(i).aggregateUnaryOperations(aggun,new MatrixBlock(),inputs.get(i).getNumColumns(),null); - inputs.set(i, newB); + MatrixBlock res = new MatrixBlock(); + res.setNumRows(inputs.get(i).getNumColumns()); + res.setNumColumns(1); + LibMatrixAgg.aggregateUnaryMatrix(inputs.get(i), res, aggun, _numThreads); +// MatrixBlock newB = (MatrixBlock)inputs.get(i).aggregateUnaryOperations(aggun,new MatrixBlock(),inputs.get(i).getNumColumns(),null); + inputs.set(i, res); } } for(Integer idx : diagMatrices){ diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 9d533a3ca7a..69452a0859e 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -51,6 +51,7 @@ public class EinsumTest extends AutomatedTestBase private static final String TEST_EINSUM11 = TEST_NAME_EINSUM+"11"; private static final String TEST_EINSUM12 = TEST_NAME_EINSUM+"12"; private static final String TEST_EINSUM13 = TEST_NAME_EINSUM+"13"; + private static final String TEST_EINSUM14 = TEST_NAME_EINSUM+"14"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; private final static String TEST_CONF = "SystemDS-config-codegen.xml"; @@ -61,7 +62,7 @@ public class EinsumTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=13; i++) + for(int i=1; i<=14; i++) addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } @Test @@ -114,6 +115,10 @@ public void testCodegenEinsum12CP() { public void testCodegenEinsum13CP() { testCodegenIntegration( TEST_EINSUM13, ExecType.CP, true ); } + @Test + public void testCodegenEinsum14CP() { + testCodegenIntegration( TEST_EINSUM14, ExecType.CP); + } private void testCodegenIntegration( String testname, ExecType instType) { testCodegenIntegration(testname, instType, false); } private void testCodegenIntegration( String testname, ExecType instType, boolean outputScalar ) { diff --git a/src/test/scripts/functions/einsum/einsum14.R b/src/test/scripts/functions/einsum/einsum14.R new file mode 100644 index 00000000000..7d0b7eef07e --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum14.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); +P = matrix(seq(1,30), 6, 5, byrow=TRUE); + +# R = P * sum(X) +R = einsum("ab,cd->ba",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum14.dml b/src/test/scripts/functions/einsum/einsum14.dml new file mode 100644 index 00000000000..fae1e38b5e4 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum14.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,6000), 600, 10); +P = matrix(seq(1,30), 6, 5) + +while(FALSE){} + +R = einsum("ab,cd->ba",P,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum3.R b/src/test/scripts/functions/einsum/einsum3.R index 0a802dc6fde..ce8a34f51ac 100644 --- a/src/test/scripts/functions/einsum/einsum3.R +++ b/src/test/scripts/functions/einsum/einsum3.R @@ -25,8 +25,8 @@ library("Matrix") library("matrixStats") library("einsum") -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); P = matrix(seq(1,3000), 600, 5, byrow=TRUE); +X = matrix(seq(1,6000), 600, 10, byrow=TRUE); # R = sum(t(P) %*% X); R = einsum("ji,jz->i",P,X) diff --git a/src/test/scripts/functions/einsum/einsum4.R b/src/test/scripts/functions/einsum/einsum4.R index bb9caf31e57..74f5560464d 100644 --- a/src/test/scripts/functions/einsum/einsum4.R +++ b/src/test/scripts/functions/einsum/einsum4.R @@ -25,10 +25,9 @@ library("Matrix") library("matrixStats") library("einsum") -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); P = matrix(seq(1,3000), 600, 5, byrow=TRUE); +X = matrix(seq(1,6000), 10, 600, byrow=TRUE); -# R = colSums(t(P) %*% X); -R = einsum("ji,jz->z",P,X) +R = einsum("ji,zj->i",P,X) writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum4.dml b/src/test/scripts/functions/einsum/einsum4.dml index 84c3efbdfae..3806db22ffe 100644 --- a/src/test/scripts/functions/einsum/einsum4.dml +++ b/src/test/scripts/functions/einsum/einsum4.dml @@ -19,11 +19,10 @@ # #------------------------------------------------------------- P = matrix(seq(1,3000), 600, 5) -X = matrix(seq(1,6000), 600, 10); +X = matrix(seq(1,6000), 10, 600); while(FALSE){} -#R = colSums(t(P) %*% X) ; -R = einsum("ji,jz->z",P,X) +R = einsum("ji,zj->i",P,X) write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum5.R b/src/test/scripts/functions/einsum/einsum5.R index 902941b9222..bb9caf31e57 100644 --- a/src/test/scripts/functions/einsum/einsum5.R +++ b/src/test/scripts/functions/einsum/einsum5.R @@ -28,7 +28,7 @@ library("einsum") X = matrix(seq(1,6000), 600, 10, byrow=TRUE); P = matrix(seq(1,3000), 600, 5, byrow=TRUE); -# R = rowSums(P) * rowSums(X) -R = einsum("ji,jz->j",P,X) +# R = colSums(t(P) %*% X); +R = einsum("ji,jz->z",P,X) writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum5.dml b/src/test/scripts/functions/einsum/einsum5.dml index de5fb654e20..84c3efbdfae 100644 --- a/src/test/scripts/functions/einsum/einsum5.dml +++ b/src/test/scripts/functions/einsum/einsum5.dml @@ -18,13 +18,12 @@ # under the License. # #------------------------------------------------------------- - -X = matrix(seq(1,6000), 600, 10); P = matrix(seq(1,3000), 600, 5) +X = matrix(seq(1,6000), 600, 10); while(FALSE){} #R = colSums(t(P) %*% X) ; -R = einsum("ji,jz->j",P,X) +R = einsum("ji,jz->z",P,X) write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum6.R b/src/test/scripts/functions/einsum/einsum6.R index 7d0b7eef07e..902941b9222 100644 --- a/src/test/scripts/functions/einsum/einsum6.R +++ b/src/test/scripts/functions/einsum/einsum6.R @@ -26,9 +26,9 @@ library("matrixStats") library("einsum") X = matrix(seq(1,6000), 600, 10, byrow=TRUE); -P = matrix(seq(1,30), 6, 5, byrow=TRUE); +P = matrix(seq(1,3000), 600, 5, byrow=TRUE); -# R = P * sum(X) -R = einsum("ab,cd->ba",P,X) +# R = rowSums(P) * rowSums(X) +R = einsum("ji,jz->j",P,X) writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum6.dml b/src/test/scripts/functions/einsum/einsum6.dml index fae1e38b5e4..de5fb654e20 100644 --- a/src/test/scripts/functions/einsum/einsum6.dml +++ b/src/test/scripts/functions/einsum/einsum6.dml @@ -20,9 +20,11 @@ #------------------------------------------------------------- X = matrix(seq(1,6000), 600, 10); -P = matrix(seq(1,30), 6, 5) +P = matrix(seq(1,3000), 600, 5) while(FALSE){} -R = einsum("ab,cd->ba",P,X) +#R = colSums(t(P) %*% X) ; + +R = einsum("ji,jz->j",P,X) write(R, $1) From 1c4abc3c5c7ec45fd4ec4e8e10e606a0dbf4064a Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Thu, 26 Jun 2025 17:37:31 +0200 Subject: [PATCH 13/28] remove more unused code, fix bug --- .../instructions/cp/EinsumCPInstruction.java | 136 ++---------------- 1 file changed, 14 insertions(+), 122 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 3f400fc091f..8caa8dc3828 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -54,31 +54,18 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction { public static boolean FORCE_CELL_TPL = false; protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; - private final Class _class; - private final SpoofOperator _op; private final int _numThreads; private final CPOperand[] _in; public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand out, CPOperand... inputs) { super(op, opcode, istr, out, inputs); - _class = null; - _op = null; _numThreads = OptimizerUtils.getConstrainedNumThreads(-1); _in = inputs; this.eqStr = inputs[0].getName(); Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); } - - public SpoofOperator getSpoofOperator() { - return _op; - } - - public Class getOperatorClass() { - return _class; - } - private static final int CONTRACT_LEFT = 1; private static final int CONTRACT_RIGHT = 2; private static final int CONTRACT_BOTH = 3; @@ -262,7 +249,7 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { res = remMbs.get(0).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); } }else{ - throw new RuntimeException("did not expect this!"); + throw new RuntimeException("Einsum runtime error"); // should not happen } ec.setMatrixOutput(output.getName(), res); } @@ -277,20 +264,16 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { chars.add(inputsChars.get(i)); } } - if(chars.size()==1 && chars.get(0).equals(parts[1])){ // maybe result is correct after all... (need to improve logic if cell is needed but for now just check if maybe all is OK) - ec.setMatrixOutput(output.getName(), mbs.get(0)); - }else { - ArrayList summingChars = new ArrayList(); - for (Character c : partsCharactersToIndices.keySet()) { - if (c != outChar1 && c != outChar2) summingChars.add(c); - } + ArrayList summingChars = new ArrayList(); + for (Character c : partsCharactersToIndices.keySet()) { + if (c != outChar1 && c != outChar2) summingChars.add(c); + } - MatrixBlock res = computeCellSummation(mbs, chars, parts[1], einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); + MatrixBlock res = computeCellSummation(mbs, chars, parts[1], einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); - if (einc.outRows == 1 && einc.outCols == 1) - ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); - else ec.setMatrixOutput(output.getName(), res); - } + if (einc.outRows == 1 && einc.outCols == 1) + ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); + else ec.setMatrixOutput(output.getName(), res); } releaseMatrixInputs(ec); @@ -318,15 +301,11 @@ private boolean generatePlanAndExecute(HashMap> pa /* handle situation: ji,ji or ij,ji */ private boolean multiplyTerms(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2 ) { - ArrayList multiplyIdxs = new ArrayList<>(); - ArrayList transposeMultiplyIdxs = new ArrayList<>(); - HashMap> stringToIndex = new HashMap<>(); for(int i = 0; i < inputsChars.size(); i++){ String s = inputsChars.get(i); if(s==null) continue; -// if(s.length() != 2) continue; if (stringToIndex.containsKey(s)) stringToIndex.get(s).add(i); else { ArrayList list = new ArrayList<>(); list.add(i); stringToIndex.put(s, list); } @@ -353,8 +332,7 @@ private boolean multiplyTerms(HashMap> partsCharac sb.append(inputsChars.get(idx)); sb.append(","); } - if(idxsT != null) - for(Integer idx : idxsT){ + if(idxsT != null) for(Integer idx : idxsT){ sb.append(inputsChars.get(idx)); sb.append(","); } @@ -366,8 +344,7 @@ private boolean multiplyTerms(HashMap> partsCharac inputsChars.set(idx, null); } ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - if(idxsT != null) - for(Integer idx : idxsT){ + if(idxsT != null) for(Integer idx : idxsT){ mbs.add(inputs.get(idx).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0)); inputs.set(idx, null); inputsChars.set(idx, null); @@ -377,25 +354,25 @@ private boolean multiplyTerms(HashMap> partsCharac inputs.add(mb); inputsChars.add(s); - for (int i = 0; i < s.length(); i++) { // for each char in string, add pointer to newly created entry char c = s.charAt(i); partsCharactersToIndices.get(c).add(inputs.size() - 1); } - if(idxsT != null) - stringToIndex.remove(sT); + if(idxsT != null) stringToIndex.remove(sT); } return doneAnything; } + // returns true if left with summation with more than 2 inputs private boolean sumCharactersWherePossible(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2) { boolean anyCouldNotDo = false; while (true) { List toSum = null; Character sumC = null; + anyCouldNotDo = false; for (Character c : partsCharactersToIndices.keySet()) { // sum one dim at the time if (c == outChar1 || c == outChar2) continue; @@ -447,51 +424,6 @@ private enum SumOperation { aB_aB, Ba_Ba, Ba_aB, aB_Ba,// mult and sums, something like ij,ij->i } - private enum AggregateAtEnd{ - Left, - Right, - Both, - None, - } - private Pair computRowSummationsOutputCharsOnly(List inputs, List inputsChars, String resString, Double scalar ){ - if(resString.length() == 1){ - // dont expect more than two of these, throw error if happens - if(inputs.size() != 2) throw new RuntimeException("did not expects this, please fix me"); - MatrixBlock mb = getRowCodegenMatrixBlock(inputs.get(0), inputs.get(1), CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG); - return Pair.of(mb, inputsChars.get(0)); - }else{ // resString.length() == 2 - // something like a,a,b,b,ab,ba - // group them - - ArrayList a = new ArrayList<>(); - ArrayList b = new ArrayList<>(); - ArrayList ab = new ArrayList<>(); - ArrayList ba = new ArrayList<>(); - for(int i =0;i< inputs.size(); i++){ - String s = inputsChars.get(i); - if(s.length() == 2){ - if(s.equals(resString)) ab.add(inputs.get(i)); - else ba.add(inputs.get(i)); - }else{ - if(s.charAt(0)==resString.charAt(0)) a.add(inputs.get(i)); - else b.add(inputs.get(i)); - } - } - // mult all a-s: - // mult all b-s: - // check if there is ab or ba - // if no: - // then do outer product axb or bxa - // if there is then: - // mult ba and a - // mult ab and b - // transp ba into ab - // mult 2 ab and ab - return Pair.of( ab.get(0) ,resString); -// throw new NotImplementedException("todo"); -// return null; - } - } private Pair computeRowSummation(List toSum, ArrayList inputs, List inputsChars, Character sumChar) { return computeRowSummation(toSum,inputs,inputsChars, null, sumChar); } @@ -859,7 +791,6 @@ else if (summingChars.contains(inputsChars.get(i).charAt(1))) { src = src.replace("%OUT%", sb.toString()); } -// String src = needsSumming ? cnode.codegenEinsum(false, SpoofCompiler.GeneratorAPI.JAVA, sb.toString(), outVar) : cnode.codegenEinsum(false, SpoofCompiler.GeneratorAPI.JAVA, "", sb.toString()); LOG.trace(src); Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); SpoofOperator op = CodegenUtils.createInstance(cla); @@ -871,46 +802,7 @@ else if (summingChars.contains(inputsChars.get(i).charAt(1))) { return out; } - public CPOperand[] getInputs() { return _in; } - - private static final IDSequence _idSeqfn = new IDSequence(); - - private final static String tmpCell = - "package codegen;\n" + - "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" + - "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n" + - "import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\n" + - "import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;\n" + - "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" + - "import org.apache.commons.math3.util.FastMath;\n" + - "public final class %TMP% extends SpoofCellwise {\n" + - " public %TMP%() {\n" + - " super(CellType.NO_AGG, false, true, null);\n" + - " }\n" + - " protected double genexec(double a, SideInput[] b, double[] scalars, int m, int n, long grix, int rix, int cix) { \n" + - " %BODY_dense%" + - " return %OUT%;\n" + - " }\n" + - "}"; - public static final String JAVA_TEMPLATE = - "package codegen;\n" - + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" - + "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n" - + "import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\n" - + "import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;\n" - + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" - + "import org.apache.commons.math3.util.FastMath;\n" - + "\n" - + "public final class %TMP% extends SpoofCellwise {\n" - + " public %TMP%() {\n" - + " super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, %AGG_OP_NAME%);\n" - + " }\n" - + " protected double genexec(double a, SideInput[] b, double[] scalars, int m, int n, long grix, int rix, int cix) { \n" - + "%BODY_dense%" - + " return %OUT%;\n" - + " }\n" - + "}\n"; } From 0f3c395a4020053fadef0f14a5afc4a0dd03da4d Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Thu, 26 Jun 2025 17:38:17 +0200 Subject: [PATCH 14/28] add einsum to test suite in github workflows --- .github/workflows/javaTests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml index d13b187fb22..6716ddfb4ad 100644 --- a/.github/workflows/javaTests.yml +++ b/.github/workflows/javaTests.yml @@ -83,7 +83,8 @@ jobs: "**.functions.unary.scalar.**,**.functions.updateinplace.**,**.functions.vect.**", "**.functions.reorg.**,**.functions.rewrite.**,**.functions.ternary.**", "**.functions.transform.**","**.functions.unique.**", - "**.functions.unary.matrix.**,**.functions.linearization.**,**.functions.jmlc.**" + "**.functions.unary.matrix.**,**.functions.linearization.**,**.functions.jmlc.**", + "**.functions.einsum.**", ] name: ${{ matrix.tests }} steps: From bf4f910408e41c1cd134cb55f0045fbc653bb222 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Mon, 30 Jun 2025 17:59:28 +0200 Subject: [PATCH 15/28] add tests and code for vector elementwise mult. and outer product --- .../instructions/cp/EinsumCPInstruction.java | 56 ++++++++++++++++--- .../test/functions/einsum/EinsumTest.java | 15 +++-- src/test/scripts/functions/einsum/einsum15.R | 33 +++++++++++ .../scripts/functions/einsum/einsum15.dml | 28 ++++++++++ src/test/scripts/functions/einsum/einsum16.R | 32 +++++++++++ .../scripts/functions/einsum/einsum16.dml | 27 +++++++++ src/test/scripts/functions/einsum/einsum17.R | 33 +++++++++++ .../scripts/functions/einsum/einsum17.dml | 27 +++++++++ 8 files changed, 240 insertions(+), 11 deletions(-) create mode 100644 src/test/scripts/functions/einsum/einsum15.R create mode 100644 src/test/scripts/functions/einsum/einsum15.dml create mode 100644 src/test/scripts/functions/einsum/einsum16.R create mode 100644 src/test/scripts/functions/einsum/einsum16.dml create mode 100644 src/test/scripts/functions/einsum/einsum17.R create mode 100644 src/test/scripts/functions/einsum/einsum17.dml diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 8caa8dc3828..07923436cf2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -34,13 +34,16 @@ import org.apache.sysds.hops.codegen.cplan.CNodeCell; import org.apache.sysds.hops.codegen.cplan.CNodeData; import org.apache.sysds.hops.codegen.cplan.CNodeRow; +import org.apache.sysds.hops.codegen.template.TemplateUtils; import org.apache.sysds.runtime.codegen.*; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.functionobjects.*; import org.apache.sysds.runtime.matrix.data.LibMatrixAgg; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; @@ -129,16 +132,16 @@ else if(parts[1].length() >= 2){ arrCounter++; } else if(!einc.contractDimsSet.contains(c2)){ - if (c2 == c){ - diagMatrices.add(arrCounter); - } + if(!partsCharactersToIndices.containsKey(c2)) partsCharactersToIndices.put(c2, new ArrayList<>()); partsCharactersToIndices.get(c2).add(arrCounter); - s+=c2; + if (c2 == c) + diagMatrices.add(arrCounter); + else + s += c2; } - i++; } newEquationStringSplit.add(s); @@ -191,6 +194,16 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); MatrixBlock mb = inputs.get(idx); inputs.set(idx, mb.reorgOperations(op, new MatrixBlock(),0,0,0)); + inputsChars.set(idx, String.valueOf(inputsChars.get(idx).charAt(0))); + } + + //make all vetors col vectors + for(int i = 0; i < inputs.size(); i++){ + if(inputs.get(i) != null && inputsChars.get(i).length() == 1 && inputs.get(i).getNumColumns() > 1){ + inputs.get(i).setNumRows(inputs.get(i).getNumColumns()); + inputs.get(i).setNumColumns(1); + inputs.get(i).getDenseBlock().resetNoFill(inputs.get(i).getNumColumns(),1); + } } if(LOG.isTraceEnabled()) for(Character c :partsCharactersToIndices.keySet()){ @@ -249,7 +262,36 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { res = remMbs.get(0).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); } }else{ - throw new RuntimeException("Einsum runtime error"); // should not happen + // maybe the leftovers are i,j and result should be ij or ji -> outer multp. + if(remStrings.size() == 2 && remStrings.get(0).length()==1 && remStrings.get(1).length()==1){ + MatrixBlock first; + MatrixBlock second; + + if(remStrings.get(0).charAt(0) == outChar1 && remStrings.get(1).charAt(0) == outChar2){ + first = remMbs.get(0); + second = remMbs.get(1); + }else if(remStrings.get(0).charAt(0) == outChar2 && remStrings.get(1).charAt(0) == outChar1){ + first = remMbs.get(1); + second = remMbs.get(0); + }else{ + throw new RuntimeException("Einsum runtime error: left with 2 vectors that cannot produce final result "+remStrings.get(0)+" , "+remStrings.get(1)); // should not happen + } + if(first.getNumColumns() > 1){ + int r = first.getNumColumns(); + first.setNumRows(r); + first.setNumColumns(1); + first.getDenseBlock().resetNoFill(r,1); + } + if(second.getNumRows() > 1){ + int c = second.getNumRows(); + second.setNumRows(1); + second.setNumColumns(c); + second.getDenseBlock().resetNoFill(1,c); + } + res = LibMatrixMult.matrixMult(first,second, _numThreads); + }else { + throw new RuntimeException("Einsum runtime error, reductions and multiplications finished but the did not produce one result"); // should not happen + } } ec.setMatrixOutput(output.getName(), res); } @@ -299,7 +341,7 @@ private boolean generatePlanAndExecute(HashMap> pa return anyCouldNotDo; } - /* handle situation: ji,ji or ij,ji */ + /* handle situation: ji,ji or ij,ji, j,j */ private boolean multiplyTerms(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2 ) { HashMap> stringToIndex = new HashMap<>(); diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 69452a0859e..4f755f8b560 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -52,6 +52,9 @@ public class EinsumTest extends AutomatedTestBase private static final String TEST_EINSUM12 = TEST_NAME_EINSUM+"12"; private static final String TEST_EINSUM13 = TEST_NAME_EINSUM+"13"; private static final String TEST_EINSUM14 = TEST_NAME_EINSUM+"14"; + private static final String TEST_EINSUM15 = TEST_NAME_EINSUM+"15"; + private static final String TEST_EINSUM16 = TEST_NAME_EINSUM+"16"; + private static final String TEST_EINSUM17 = TEST_NAME_EINSUM+"17"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; private final static String TEST_CONF = "SystemDS-config-codegen.xml"; @@ -62,7 +65,7 @@ public class EinsumTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=14; i++) + for(int i=1; i<=17; i++) addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } @Test @@ -116,9 +119,13 @@ public void testCodegenEinsum13CP() { testCodegenIntegration( TEST_EINSUM13, ExecType.CP, true ); } @Test - public void testCodegenEinsum14CP() { - testCodegenIntegration( TEST_EINSUM14, ExecType.CP); - } + public void testCodegenEinsum14CP() { testCodegenIntegration( TEST_EINSUM14, ExecType.CP); } + @Test + public void testCodegenEinsum15CP() { testCodegenIntegration( TEST_EINSUM15, ExecType.CP); } + @Test + public void testCodegenEinsum16CP() { testCodegenIntegration( TEST_EINSUM16, ExecType.CP); } + @Test + public void testCodegenEinsum17CP() { testCodegenIntegration( TEST_EINSUM17, ExecType.CP); } private void testCodegenIntegration( String testname, ExecType instType) { testCodegenIntegration(testname, instType, false); } private void testCodegenIntegration( String testname, ExecType instType, boolean outputScalar ) { diff --git a/src/test/scripts/functions/einsum/einsum15.R b/src/test/scripts/functions/einsum/einsum15.R new file mode 100644 index 00000000000..da967ef3f1f --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum15.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +P = as.vector(seq(1,30)); +X = as.vector(seq(1,600)); + +R = einsum("a,c->ac",P,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum15.dml b/src/test/scripts/functions/einsum/einsum15.dml new file mode 100644 index 00000000000..128730eef34 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum15.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +P = matrix(seq(1,30), 30, 1) +X = matrix(seq(1,600), 600, 1); + +while(FALSE){} + +R = einsum("a,c->ac",P,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum16.R b/src/test/scripts/functions/einsum/einsum16.R new file mode 100644 index 00000000000..683cd58a7bf --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum16.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +X = as.vector(seq(1,600)); + +R = einsum("a,a->a",X,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum16.dml b/src/test/scripts/functions/einsum/einsum16.dml new file mode 100644 index 00000000000..837db17a018 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum16.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,600), 600, 1); + +while(FALSE){} + +R = einsum("a,a->a",X,X) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum17.R b/src/test/scripts/functions/einsum/einsum17.R new file mode 100644 index 00000000000..48d50e0acf4 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum17.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +X = as.vector(seq(1,600)); + +# R = P * sum(X) +R = einsum("a,a->a",X,X) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum17.dml b/src/test/scripts/functions/einsum/einsum17.dml new file mode 100644 index 00000000000..1053df0539f --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum17.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,600), 1, 600); + +while(FALSE){} + +R = einsum("a,a->a",X,X) +write(R, $1) From 1b765cc8f5aef22767b881cbd059e7b2f1e7d90e Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Mon, 30 Jun 2025 20:14:39 +0200 Subject: [PATCH 16/28] moved einsum information extraction to EinsumContext function --- .../instructions/cp/EinsumCPInstruction.java | 218 ++++++----------- .../instructions/cp/EinsumContext.java | 230 +++++++----------- 2 files changed, 167 insertions(+), 281 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 07923436cf2..d441efb774c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -34,16 +34,12 @@ import org.apache.sysds.hops.codegen.cplan.CNodeCell; import org.apache.sysds.hops.codegen.cplan.CNodeData; import org.apache.sysds.hops.codegen.cplan.CNodeRow; -import org.apache.sysds.hops.codegen.template.TemplateUtils; import org.apache.sysds.runtime.codegen.*; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.functionobjects.*; -import org.apache.sysds.runtime.matrix.data.LibMatrixAgg; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; @@ -69,10 +65,6 @@ public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand ou Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); } - private static final int CONTRACT_LEFT = 1; - private static final int CONTRACT_RIGHT = 2; - private static final int CONTRACT_BOTH = 3; - private EinsumContext einc = null; @Override @@ -89,113 +81,18 @@ public void processInstruction(ExecutionContext ec) { } } - EinsumContext einc = getEinsumContext(eqStr,inputs); - this.einc = einc; + EinsumContext einc = getEinsumContext(eqStr, inputs); - String[] parts = einc.equationString.split("->"); + this.einc = einc; + String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : null; if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols); - Character outChar1 = null; - Character outChar2 = null; + ArrayList inputsChars = einc.newEquationStringSplit; - if(parts.length == 1){ } - else if(parts[1].length() >= 2){ - outChar1 = parts[1].charAt(0); - outChar2 = parts[1].charAt(1); - }else if (parts[1].length()==1){ - outChar1 = parts[1].charAt(0); - } - HashMap> partsCharactersToIndices = new HashMap<>(); - ArrayList newEquationStringSplit = new ArrayList(); - - ArrayList diagMatrices = new ArrayList<>(); - int arrCounter=0; - for(int i=0;i()); - - partsCharactersToIndices.get(c).add(arrCounter); - s+=c; - } - if(i+1()); - - partsCharactersToIndices.get(c2).add(arrCounter); - if (c2 == c) - diagMatrices.add(arrCounter); - else - s += c2; - } - i++; - } - newEquationStringSplit.add(s); - } - ArrayList inputsChars = newEquationStringSplit; - LOG.trace(String.join(",",newEquationStringSplit)); - for(int i=0;i a = partsCharactersToIndices.get(c); + if(LOG.isTraceEnabled()) for(Character c : einc.partsCharactersToIndices.keySet()){ + ArrayList a = einc.partsCharactersToIndices.get(c); LOG.trace(c+" count= "+a.size()); } @@ -239,20 +136,20 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { } } - boolean anyCouldNotDo = FORCE_CELL_TPL ? true : generatePlanAndExecute(partsCharactersToIndices, inputs, inputsChars, outChar1, outChar2); // information to do cell tpl for remaining ones + boolean needToDoCellTemplate = FORCE_CELL_TPL ? true : generatePlanAndExecute(inputs, einc); - if (!anyCouldNotDo){ + if (!needToDoCellTemplate){ //check if any operations to do that were not-output dimension summations: List remStrings = inputsChars.stream() .filter(Objects::nonNull).toList(); List remMbs = inputs.stream() .filter(Objects::nonNull).toList(); MatrixBlock res; - if(remStrings.size() == 1){ + if(remStrings.size() == 1) { String s = remStrings.get(0); - if(s.equals(parts[1])){ + if(s.equals(resultString)){ res=remMbs.get(0); - }else if(s.charAt(0)==s.charAt(1)) { + }else if(s.charAt(0) == s.charAt(1)) { // diagonal needed ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); res= remMbs.get(0).reorgOperations(op, new MatrixBlock(),0,0,0); @@ -261,16 +158,16 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); res = remMbs.get(0).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); } - }else{ + } else{ // maybe the leftovers are i,j and result should be ij or ji -> outer multp. if(remStrings.size() == 2 && remStrings.get(0).length()==1 && remStrings.get(1).length()==1){ MatrixBlock first; MatrixBlock second; - if(remStrings.get(0).charAt(0) == outChar1 && remStrings.get(1).charAt(0) == outChar2){ + if(remStrings.get(0).charAt(0) == einc.outChar1 && remStrings.get(1).charAt(0) == einc.outChar2){ first = remMbs.get(0); second = remMbs.get(1); - }else if(remStrings.get(0).charAt(0) == outChar2 && remStrings.get(1).charAt(0) == outChar1){ + }else if(remStrings.get(0).charAt(0) == einc.outChar2 && remStrings.get(1).charAt(0) == einc.outChar1){ first = remMbs.get(1); second = remMbs.get(0); }else{ @@ -307,11 +204,11 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { } } ArrayList summingChars = new ArrayList(); - for (Character c : partsCharactersToIndices.keySet()) { - if (c != outChar1 && c != outChar2) summingChars.add(c); + for (Character c : einc.partsCharactersToIndices.keySet()) { + if (c != einc.outChar1 && c != einc.outChar2) summingChars.add(c); } - MatrixBlock res = computeCellSummation(mbs, chars, parts[1], einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); + MatrixBlock res = computeCellSummation(mbs, chars, resultString, einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); if (einc.outRows == 1 && einc.outCols == 1) ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); @@ -321,20 +218,56 @@ else if(einc.contractDims[i] == CONTRACT_BOTH) { releaseMatrixInputs(ec); } + private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList inputs) { + for(int i = 0; i< einc.contractDims.length; i++){ + //AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(),Types.CorrectionLocationType.LASTCOLUMN); + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + + if(einc.diagonalInputs[i]){ + ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); + inputs.set(i, inputs.get(i).reorgOperations(op, new MatrixBlock(),0,0,0)); + } + switch (einc.contractDims[i]){ + case EinsumContext.CONTRACT_BOTH: { + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); + MatrixBlock res = new MatrixBlock(1, 1, false); + inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); + inputs.set(i, res); + break; + } + case EinsumContext.CONTRACT_RIGHT: { + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + MatrixBlock res = new MatrixBlock(inputs.get(i).getNumRows(), 1, false); + inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); + inputs.set(i, res); + break; + } + case EinsumContext.CONTRACT_LEFT: { + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); + MatrixBlock res = new MatrixBlock(inputs.get(i).getNumColumns(), 1, false); + inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); + inputs.set(i, res); + break; + } + } + } + } + private void releaseMatrixInputs(ExecutionContext ec){ for (CPOperand input : _in) if(input.getDataType()==DataType.MATRIX) ec.releaseMatrixInput(input.getName()); } - private boolean generatePlanAndExecute(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2) { + // returns true if there are elements that appear more than 2 times and cannot be summed + private boolean generatePlanAndExecute(ArrayList inputs, EinsumContext einc) { boolean anyCouldNotDo; - boolean didAnything = false; + boolean didAnything = false; // maybe multiplication will make it summable do { - anyCouldNotDo = sumCharactersWherePossible(partsCharactersToIndices, inputs, inputsChars, outChar1, outChar2); + anyCouldNotDo = sumCharactersWherePossible(einc.partsCharactersToIndices, inputs, einc.newEquationStringSplit, einc.outChar1, einc.outChar2); didAnything = false; - if(inputsChars.stream().filter(Objects::nonNull).count() > 1) - didAnything = multiplyTerms(partsCharactersToIndices, inputs, inputsChars, outChar1, outChar2); + if(einc.newEquationStringSplit.stream().filter(Objects::nonNull).count() > 1) + didAnything = multiplyTerms(einc.partsCharactersToIndices, inputs, einc.newEquationStringSplit, einc.outChar1, einc.outChar2); } while(didAnything); @@ -448,7 +381,8 @@ private boolean sumCharactersWherePossible(HashMap for (int i = 0; i < newS.length(); i++) { // for each char in string, add pointer to newly created entry char c = newS.charAt(i); - partsCharactersToIndices.get(c).add(inputs.size() - 1); + if(partsCharactersToIndices.containsKey(c)) + partsCharactersToIndices.get(c).add(inputs.size() - 1); } partsCharactersToIndices.remove(sumC); } @@ -467,9 +401,6 @@ private enum SumOperation { } private Pair computeRowSummation(List toSum, ArrayList inputs, List inputsChars, Character sumChar) { - return computeRowSummation(toSum,inputs,inputsChars, null, sumChar); - } - private Pair computeRowSummation(List toSum, ArrayList inputs, List inputsChars, Double scalar, Character sumChar) { if(toSum.size() != 2){ return null; @@ -572,44 +503,44 @@ else if(s1.charAt(1) == s2.charAt(0)){ case Ba_a: throw new NotImplementedException(); case Ba_aC: { - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf( second.getNumColumns()), scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf( second.getNumColumns())); break; } case Ba_Ca: { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf(second.getNumColumns()), scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf(second.getNumColumns())); break; } case aB_a: case aB_aC: { - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_OUTERMULT_ADD, SpoofRowwise.RowType.COL_AGG_B1_T, Long.valueOf( second.getNumColumns()),scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_OUTERMULT_ADD, SpoofRowwise.RowType.COL_AGG_B1_T, Long.valueOf( second.getNumColumns())); break; } case a_a: - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG,null, scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG,null); break; case aB_aB: { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); first = first.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.COL_AGG, null, scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.COL_AGG, null); break; } case Ba_Ba: { - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG, null, scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG, null); break; } case aB_Ba: { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); first = first.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.ROW_AGG,null, scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.ROW_AGG,null); break; } case Ba_aB: { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG,null, scalar); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG,null); break; } @@ -618,9 +549,7 @@ else if(s1.charAt(1) == s2.charAt(0)){ } return Pair.of(out , resS); } - private MatrixBlock getRowCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType){ - return getRowCodegenMatrixBlock(first, second, binaryType,rowType,null, null); - } + private MatrixBlock getCodegenElemwiseMult(ArrayList mbs) { ArrayList cnodeIn = new ArrayList<>(); @@ -646,7 +575,7 @@ private MatrixBlock getCodegenElemwiseMult(ArrayList mbs) { MatrixBlock out = op.execute(mbs, scalars, mb, _numThreads); return out; } - private MatrixBlock getRowCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType, Long secondDim, Double scalar) { + private MatrixBlock getRowCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType, Long secondDim) { ArrayList thisInputs = new ArrayList<>(Arrays.asList(first, second)); ArrayList cnodeIn = new ArrayList<>(); @@ -671,7 +600,6 @@ private MatrixBlock getRowCodegenMatrixBlock(MatrixBlock first, MatrixBlock seco MatrixBlock mb = new MatrixBlock(); ArrayList scalars = new ArrayList<>(); - if(scalar != null) scalars.add(new DoubleObject(scalar)); MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); return out; } @@ -720,7 +648,7 @@ private MatrixBlock computeCellSummation(ArrayList inputs, List x != null); - ; + String itVar0 = cnode.createVarname(); String outVar = itVar0; if (needsSumming) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java index 81f690eccb9..927fb208363 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java @@ -29,185 +29,143 @@ public class EinsumContext { public Integer outRows; public Integer outCols; + public Character outChar1; + public Character outChar2; public HashMap charToDimensionSizeInt; public String equationString; - public Integer[] contractDims; - public Integer[] summingDims; + public boolean[] diagonalInputs; public HashSet summingChars; public HashSet contractDimsSet; - + public static final int CONTRACT_LEFT = 1; + public static final int CONTRACT_RIGHT = 2; + public static final int CONTRACT_BOTH = 3; + public int[] contractDims; + public ArrayList newEquationStringSplit; + public HashMap> partsCharactersToIndices; // for each character, this tells in which inputs it appears + + private EinsumContext(){}; public static EinsumContext getEinsumContext(String eqStr, ArrayList inputs){ EinsumContext res = new EinsumContext(); res.equationString = eqStr; - int i = 0; res.charToDimensionSizeInt = new HashMap(); - Iterator it = inputs.iterator(); - MatrixBlock curArr = it.next(); - int arrSizeIterator=0; HashSet summingChars = new HashSet<>(); - Integer[] contractDims = new Integer[inputs.size()];//0==nothing, 1 = right, 2=left, 3 = both - Integer[] summingDims = new Integer[inputs.size()];//0/null==nothing, 1 = right, 2=left, 3 = both + int[] contractDims = new int[inputs.size()]; + boolean[] diagonalInputs = new boolean[inputs.size()]; // all false by default HashSet contractDimsSet = new HashSet(); + HashMap> partsCharactersToIndices = new HashMap<>(); + ArrayList newEquationStringSplit = new ArrayList(); - - int arrIt = 0; - for (i = 0; true; i++){ + Iterator it = inputs.iterator(); + MatrixBlock curArr = it.next(); + int arrSizeIterator=0; + int arrayIterator = 0; + int i; + for (i = 0; true; i++) { char c = eqStr.charAt(i); - if(c=='-'){ + if(c == '-'){ i+=2; break; } - if(c==','){ - - arrIt++; + if(c == ','){ + arrayIterator++; curArr = it.next(); arrSizeIterator = 0; } - else{ - if (res.charToDimensionSizeInt.containsKey(c)){ - // just check if dims match! - if(arrSizeIterator==0) - assert (res.charToDimensionSizeInt.get(c) == curArr.getNumRows()); - else if(arrSizeIterator==1) - assert (res.charToDimensionSizeInt.get(c) == curArr.getNumColumns()); - + if (res.charToDimensionSizeInt.containsKey(c)) { // sanity check if dims match, this is already checked at validation + if(arrSizeIterator == 0 && res.charToDimensionSizeInt.get(c) != curArr.getNumRows()) + throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); + else if(arrSizeIterator == 1 && res.charToDimensionSizeInt.get(c) != curArr.getNumColumns()) + throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); summingChars.add(c); - - }else{ - if(arrSizeIterator==0) + } else { + if(arrSizeIterator == 0) res.charToDimensionSizeInt.put(c, curArr.getNumRows()); - else if(arrSizeIterator==1) + else if(arrSizeIterator == 1) res.charToDimensionSizeInt.put(c, curArr.getNumColumns()); } + arrSizeIterator++; } - - //Process char } - int rem = eqStr.length() - i; - arrSizeIterator = 0; - if (rem ==0){ - res.outRows=1; - res.outCols=1; - - arrIt=0; - for (i = 0; true; i++) { - char c = eqStr.charAt(i); - if (c == '-') { - break; - } - if (c == ',') { - arrIt++; - arrSizeIterator = 0; - continue; - } + int numOfRemainingChars = eqStr.length() - i; - if(summingChars.contains(c)){ + if (numOfRemainingChars > 2) + throw new RuntimeException("Einsum: dim > 2 not supported"); - }else{ - contractDimsSet.add(c); - if(contractDims[arrIt]==null){ - contractDims[arrIt] = arrSizeIterator +1; + arrSizeIterator = 0; - }else { - contractDims[arrIt] += arrSizeIterator + 1; - } - } - arrSizeIterator++; + Character outChar1 = numOfRemainingChars > 0 ? eqStr.charAt(i) : null; + Character outChar2 = numOfRemainingChars > 1 ? eqStr.charAt(i+1) : null; + res.outRows=(numOfRemainingChars > 0 ? res.charToDimensionSizeInt.get(outChar1) : 1); + res.outCols=(numOfRemainingChars > 1 ? res.charToDimensionSizeInt.get(outChar2) : 1); + arrayIterator=0; + for (i = 0; true; i++) { + char c = eqStr.charAt(i); + if (c == '-') { + break; } - }else if (rem == 1){ - char c1= eqStr.charAt(i); - res.outRows=(res.charToDimensionSizeInt.get(c1)); - - res.outCols=1; - arrIt=0; - for (i = 0; true; i++) { - char c = eqStr.charAt(i); - if (c == '-') { - break; - } - if (c == ',') { - arrIt++; - arrSizeIterator = 0; - continue; - } - - if(summingChars.contains(c)){ - - if(summingDims[arrIt] == null){ - summingDims[arrIt]=arrSizeIterator +1; // it=0->add 1, it==1->add 2 - }else{ - summingDims[arrIt]+=arrSizeIterator +1; // it=0->add 1, it==1->add 2 - - } - }else if(c==c1){ - // this dim is remaining - }else{ - contractDimsSet.add(c); - - if(contractDims[arrIt]==null){ - contractDims[arrIt]=arrSizeIterator +1; - - }else { - contractDims[arrIt] += arrSizeIterator + 1; - } - - } - arrSizeIterator++; - + if (c == ',') { + arrayIterator++; + arrSizeIterator = 0; + continue; } - }else if (rem==2){ - char c1= eqStr.charAt(i); - char c2= eqStr.charAt(i+1); - res.outRows=(res.charToDimensionSizeInt.get(c1)); - res.outCols=(res.charToDimensionSizeInt.get(c2)); - - arrIt=0; - for (i = 0; true; i++) { - char c = eqStr.charAt(i); - if (c == '-') { - break; - } - if (c == ',') { - arrIt++; - arrSizeIterator = 0; - continue; + String s = ""; - } + if(summingChars.contains(c)) { + s+=c; + if(!partsCharactersToIndices.containsKey(c)) + partsCharactersToIndices.put(c, new ArrayList<>()); + partsCharactersToIndices.get(c).add(arrayIterator); + } + else if((outChar1 != null && c == outChar1) || (outChar2 != null && c == outChar2)) { + s+=c; + } + else { + contractDimsSet.add(c); + contractDims[arrayIterator] = CONTRACT_LEFT; + } - if(summingChars.contains(c)){ - if(summingDims[arrIt] == null){ - summingDims[arrIt]=arrSizeIterator +1; // it=0->add 1, it==1->add 2 - }else{ - summingDims[arrIt]+=arrSizeIterator +1; // it=0->add 1, it==1->add 2 + if(i + 1 < eqStr.length()) { // process next character together + char c2 = eqStr.charAt(i + 1); + i++; + if (c2 == '-') { newEquationStringSplit.add(s); break;} + if (c2 == ',') { arrayIterator++; newEquationStringSplit.add(s); continue; } + if (c2 == c){ + diagonalInputs[arrayIterator] = true; + if (contractDims[arrayIterator] == CONTRACT_LEFT) contractDims[arrayIterator] = CONTRACT_BOTH; + } + else{ + if(summingChars.contains(c2)) { + s+=c2; + if(!partsCharactersToIndices.containsKey(c2)) + partsCharactersToIndices.put(c2, new ArrayList<>()); + partsCharactersToIndices.get(c2).add(arrayIterator); } - }else if(c==c1 || c==c2){ - // this dim is remaining - }else{ - contractDimsSet.add(c); - - if(contractDims[arrIt]==null){ - contractDims[arrIt]=arrSizeIterator +1; - - }else { - contractDims[arrIt] += arrSizeIterator + 1; + else if((outChar1 != null && c2 == outChar1) || (outChar2 != null && c2 == outChar2)) { + s+=c2; + } + else { + contractDimsSet.add(c2); + contractDims[arrayIterator] += CONTRACT_RIGHT; } } - arrSizeIterator++; - } - }else{ - throw new RuntimeException("output dim > 2 not supported for now"); + newEquationStringSplit.add(s); + arrSizeIterator++; } - res.contractDims=contractDims; - res.contractDimsSet = contractDimsSet; - res.summingDims=summingDims; + res.contractDims = contractDims; + res.contractDimsSet = contractDimsSet; + res.diagonalInputs = diagonalInputs; res.summingChars = summingChars; + res.outChar1 = outChar1; + res.outChar2 = outChar2; + res.newEquationStringSplit = newEquationStringSplit; + res.partsCharactersToIndices = partsCharactersToIndices; return res; } } From 12564d7168ea0daaf9a3c6d7ee92bdc7c99986bf Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Mon, 30 Jun 2025 20:23:16 +0200 Subject: [PATCH 17/28] add comment --- .../apache/sysds/runtime/instructions/cp/EinsumContext.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java index 927fb208363..bd4f6332405 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java @@ -61,6 +61,7 @@ public static EinsumContext getEinsumContext(String eqStr, ArrayList 2) @@ -102,6 +104,7 @@ else if(arrSizeIterator == 1) res.outCols=(numOfRemainingChars > 1 ? res.charToDimensionSizeInt.get(outChar2) : 1); arrayIterator=0; + // second iteration through string: collect remaining information for (i = 0; true; i++) { char c = eqStr.charAt(i); if (c == '-') { From a0b160846a815a51cfd8d72979122a5b4c592347 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sat, 5 Jul 2025 00:33:29 +0200 Subject: [PATCH 18/28] extracted einsumEquationValidation to separate class, it is now called also by refresh size method --- .../java/org/apache/sysds/hops/NaryOp.java | 22 +-- .../parser/BuiltinFunctionExpression.java | 119 ++--------------- .../cp => einsum}/EinsumContext.java | 2 +- .../einsum/EinsumEquationValidator.java | 125 ++++++++++++++++++ .../instructions/cp/EinsumCPInstruction.java | 5 +- 5 files changed, 148 insertions(+), 125 deletions(-) rename src/main/java/org/apache/sysds/runtime/{instructions/cp => einsum}/EinsumContext.java (99%) create mode 100644 src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java index aff2572a7e8..6962beadcbc 100644 --- a/src/main/java/org/apache/sysds/hops/NaryOp.java +++ b/src/main/java/org/apache/sysds/hops/NaryOp.java @@ -26,6 +26,7 @@ import org.apache.sysds.lops.Lop; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.lops.Nary; +import org.apache.sysds.runtime.einsum.EinsumEquationValidator; import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.meta.MatrixCharacteristics; @@ -236,21 +237,12 @@ public void refreshSizeInformation() { setDim2(1); break; case EINSUM: - String eqString = ((LiteralOp) _input.get(0)).getStringValue(); - if (eqString.charAt(eqString.length()-1)=='>'){ - setDataType(DataType.SCALAR); - setDim1(0); - setDim2(0); - break; - } - String outStr = eqString.split("->")[1]; - int count = 0; - for (int i = 0; i < outStr.length(); i++){ - if(outStr.charAt(i) != ' ') count++; - } - // not true: todo later - set correct out size - setDim1( HopRewriteUtils.getMaxInputDim(this, true)); - setDim2(count==1 ? 1 : HopRewriteUtils.getMaxInputDim(this, false)); + String equationString = ((LiteralOp) _input.get(0)).getStringValue(); + var dims = EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString, this.getInput().subList(1, this.getInput().size())); + + setDim1(dims.getLeft()); + setDim2(dims.getMiddle()); + setDataType(dims.getRight()); break; case PRINTF: case EVAL: diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 166fad4bc09..4f269092724 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -24,7 +24,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; -import java.util.Iterator; import org.antlr.v4.runtime.ParserRuleContext; import org.apache.commons.lang3.ArrayUtils; @@ -37,6 +36,7 @@ import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.parser.LanguageException.LanguageErrorCodes; +import org.apache.sysds.runtime.einsum.EinsumEquationValidator; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.util.DnnUtils; import org.apache.sysds.runtime.util.UtilFunctions; @@ -2110,13 +2110,10 @@ private void validateEinsum(DataIdentifier output){ raiseValidateError("Einsum: first argument has to be equation str", false, LanguageErrorCodes.INVALID_PARAMETERS); - String eq_string = ((StringIdentifier)getFirstExpr()).getValue(); + String equationString = ((StringIdentifier)getFirstExpr()).getValue(); - if (eq_string.length() == 0) raiseValidateError("Einsum: equation str too short", false, LanguageErrorCodes.INVALID_PARAMETERS); - if (eq_string.charAt(0) == '-' || eq_string.charAt(0) == ',') raiseValidateError("Einsum: equation str invalid", false, LanguageErrorCodes.INVALID_PARAMETERS); - - String[] eqStringParts = eq_string.split("->"); // length 2 if "...->..." , length 1 if "...->" - boolean isResultScalar = eqStringParts.length == 1; + if (equationString.length() == 0) raiseValidateError("Einsum: equation str too short", false, LanguageErrorCodes.INVALID_PARAMETERS); + if (equationString.charAt(0) == '-' || equationString.charAt(0) == ',') raiseValidateError("Einsum: equation str invalid", false, LanguageErrorCodes.INVALID_PARAMETERS); Expression[] expressions = getAllExpr(); boolean allDimsKnown = true; @@ -2131,110 +2128,20 @@ private void validateEinsum(DataIdentifier output){ matrixBlocks.add((expressions[i].getOutput())); } - StringBuilder newEqString = new StringBuilder(); - - if(allDimsKnown) { // validate dimension sizes as well - HashMap charToDimensionSize = new HashMap<>(); - Iterator it = matrixBlocks.iterator(); - Identifier currArr = it.next(); - int arrSizeIterator = 0; - int numberOfMatrices = 1; - for (int i = 0; i < eqStringParts[0].length(); i++) { - char c = eq_string.charAt(i); - if(c==' ') continue; - newEqString.append(c); - if(c==','){ - if(!it.hasNext()) - raiseValidateError("Einsum: Provided less operands than specified in equation str", - false, LanguageErrorCodes.INVALID_PARAMETERS); - currArr = it.next(); - arrSizeIterator = 0; - numberOfMatrices++; - } else{ - long thisCharDimension = arrSizeIterator == 0 ? currArr.getDim1() : currArr.getDim2(); - if (charToDimensionSize.containsKey(c)){ - if (charToDimensionSize.get(c) != thisCharDimension) - raiseValidateError("Einsum: Character '" + c + "' expected to be dim " + charToDimensionSize.get(c) + ", but found " + thisCharDimension, - false, LanguageErrorCodes.INVALID_PARAMETERS); - }else{ - charToDimensionSize.put(c, thisCharDimension); - } - arrSizeIterator++; - } - } - if (getAllExpr().length - 1 > numberOfMatrices) - raiseValidateError("Einsum: Provided more operands than specified in equation str", - false, LanguageErrorCodes.INVALID_PARAMETERS); - newEqString.append("->"); + if(allDimsKnown){ + var dims = EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString, matrixBlocks); - if (isResultScalar){ - output.setDataType(DataType.SCALAR); - output.setDimensions(-1, -1); - }else { - int numberOfOutDimensions = 0; - Character dim1Char = null; - long dim1 = 1; - long dim2 = 1; - for (int i = 0; i < eqStringParts[1].length(); i++) { - char c = eqStringParts[1].charAt(i); - if (c == ' ') continue; - newEqString.append(c); - if (numberOfOutDimensions == 0) { - dim1Char = c; - dim1 = charToDimensionSize.get(c); - } else { - if(c==dim1Char) raiseValidateError("Einsum: output character "+c+" provided multiple times",false, LanguageErrorCodes.INVALID_PARAMETERS); - dim2 = charToDimensionSize.get(c); - } - numberOfOutDimensions++; - } - if (numberOfOutDimensions > 2) { - raiseValidateError("Einsum: output matrices with with no. dims > 2 not supported",false, LanguageErrorCodes.INVALID_PARAMETERS); - } else { - output.setDataType(DataType.MATRIX); - output.setDimensions(dim1, dim2); - } - } - } else { // dimensions unknown - int numberOfMatrices = 1; - for (int i = 0; i < eqStringParts[0].length(); i++) { - char c = eqStringParts[0].charAt(i); - if(c == ' ') continue; - newEqString.append(c); - if(c == ',') - numberOfMatrices++; - } - checkNumParameters(numberOfMatrices+1); - newEqString.append("->"); - - if(isResultScalar){ - output.setDataType(DataType.SCALAR); - output.setDimensions(-1, -1); - }else { - int numberOfDimensions = 0; - Character dim1Char = null; - for (int i = 0; i < eqStringParts[1].length(); i++) { - char c = eqStringParts[i].charAt(i); - if(c == ' ') continue; - newEqString.append(c); - numberOfDimensions++; - if (numberOfDimensions == 1 && c == dim1Char) - raiseValidateError("Einsum: output character "+c+" provided multiple times",false, LanguageErrorCodes.INVALID_PARAMETERS); - dim1Char = c; - } + output.setDataType(dims.getRight()); + output.setDimensions(dims.getLeft(), dims.getMiddle()); + }else{ + DataType dataType = EinsumEquationValidator.validateEinsumEquationNoDimensions(equationString, _args.length - 1); - if (numberOfDimensions > 2) { - raiseValidateError("Einsum: output matrices with with no. dims > 2 not supported", - false, LanguageErrorCodes.INVALID_PARAMETERS); - } else { - output.setDataType(DataType.MATRIX); - output.setDimensions(-1, -1); - } - } + output.setDataType(dataType); + output.setDimensions(-1l, -1l); } + output.setValueType(ValueType.FP64); output.setBlocksize(getSecondExpr().getOutput().getBlocksize()); - ((StringIdentifier) getFirstExpr()).setValue(newEqString.toString()); } private void setBinaryOutputProperties(DataIdentifier output) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java similarity index 99% rename from src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java rename to src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java index bd4f6332405..723dd2fa297 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumContext.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.sysds.runtime.instructions.cp; +package org.apache.sysds.runtime.einsum; import org.apache.sysds.runtime.matrix.data.MatrixBlock; diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java new file mode 100644 index 00000000000..7a2661c10c6 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java @@ -0,0 +1,125 @@ +package org.apache.sysds.runtime.einsum; + +import org.apache.commons.lang3.tuple.Triple; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.parser.Identifier; +import org.apache.sysds.parser.LanguageException; +import org.apache.sysds.parser.ParseInfo; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; + +public class EinsumEquationValidator { + + public static Triple validateEinsumEquationAndReturnDimensions(String equationString, List expressionsOrIdentifiers) throws LanguageException { + String[] eqStringParts = equationString.split("->"); // length 2 if "...->..." , length 1 if "...->" + boolean isResultScalar = eqStringParts.length == 1; + + if(expressionsOrIdentifiers == null) + throw new RuntimeException("Einsum: called validateEinsumAndReturnDimensions with null list"); + + HashMap charToDimensionSize = new HashMap<>(); + Iterator it = expressionsOrIdentifiers.iterator(); + HopOrIdentifier currArr = it.next(); + int arrSizeIterator = 0; + int numberOfMatrices = 1; + for (int i = 0; i < eqStringParts[0].length(); i++) { + char c = equationString.charAt(i); + if(c==' ') continue; + if(c==','){ + if(!it.hasNext()) + throw new LanguageException("Einsum: Provided less operands than specified in equation str"); + currArr = it.next(); + arrSizeIterator = 0; + numberOfMatrices++; + } else{ + long thisCharDimension = getThisCharDimension(currArr, arrSizeIterator); + if (charToDimensionSize.containsKey(c)){ + if (charToDimensionSize.get(c) != thisCharDimension) + throw new LanguageException("Einsum: Character '" + c + "' expected to be dim " + charToDimensionSize.get(c) + ", but found " + thisCharDimension); + }else{ + charToDimensionSize.put(c, thisCharDimension); + } + arrSizeIterator++; + } + } + if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices) + throw new LanguageException("Einsum: Provided more operands than specified in equation str"); + + if (isResultScalar) + return Triple.of(-1l,-1l, Types.DataType.SCALAR); + + int numberOfOutDimensions = 0; + Character dim1Char = null; + long dim1 = 1; + long dim2 = 1; + for (int i = 0; i < eqStringParts[1].length(); i++) { + char c = eqStringParts[1].charAt(i); + if (c == ' ') continue; + if (numberOfOutDimensions == 0) { + dim1Char = c; + dim1 = charToDimensionSize.get(c); + } else { + if(c==dim1Char) throw new LanguageException("Einsum: output character "+c+" provided multiple times"); + dim2 = charToDimensionSize.get(c); + } + numberOfOutDimensions++; + } + if (numberOfOutDimensions > 2) { + throw new LanguageException("Einsum: output matrices with with no. dims > 2 not supported"); + } else { + return Triple.of(dim1, dim2, Types.DataType.MATRIX); + } + } + + public static Types.DataType validateEinsumEquationNoDimensions(String equationString, int numberOfMatrixInputs) throws LanguageException { + String[] eqStringParts = equationString.split("->"); // length 2 if "...->..." , length 1 if "...->" + boolean isResultScalar = eqStringParts.length == 1; + + int numberOfMatrices = 1; + for (int i = 0; i < eqStringParts[0].length(); i++) { + char c = eqStringParts[0].charAt(i); + if(c == ' ') continue; + if(c == ',') + numberOfMatrices++; + } + if(numberOfMatrixInputs != numberOfMatrices){ + throw new LanguageException("Einsum: Invalid number of parameters, given: " + numberOfMatrixInputs + ", expected: " + numberOfMatrices); + } + + if(isResultScalar){ + return Types.DataType.SCALAR; + }else { + int numberOfDimensions = 0; + Character dim1Char = null; + for (int i = 0; i < eqStringParts[1].length(); i++) { + char c = eqStringParts[i].charAt(i); + if(c == ' ') continue; + numberOfDimensions++; + if (numberOfDimensions == 1 && c == dim1Char) + throw new LanguageException("Einsum: output character "+c+" provided multiple times"); + dim1Char = c; + } + + if (numberOfDimensions > 2) { + throw new LanguageException("Einsum: output matrices with with no. dims > 2 not supported"); + } else { + return Types.DataType.MATRIX; + } + } + } + + private static long getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) { + long thisCharDimension; + if(currArr instanceof Hop){ + thisCharDimension = arrSizeIterator == 0 ? ((Hop) currArr).getDim1() : ((Hop) currArr).getDim2(); + } else if(currArr instanceof Identifier){ + thisCharDimension = arrSizeIterator == 0 ? ((Identifier) currArr).getDim1() : ((Identifier) currArr).getDim2(); + } else { + throw new RuntimeException("validateEinsumAndReturnDimensions called with expressions that are not Hop or Identifier: "+ currArr == null ? "null" : currArr.getClass().toString()); + } + return thisCharDimension; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index d441efb774c..e4dd5a9da7b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -37,6 +37,7 @@ import org.apache.sysds.runtime.codegen.*; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.einsum.EinsumContext; import org.apache.sysds.runtime.functionobjects.*; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -47,8 +48,6 @@ import java.util.*; -import static org.apache.sysds.runtime.instructions.cp.EinsumContext.getEinsumContext; - public class EinsumCPInstruction extends BuiltinNaryCPInstruction { public static boolean FORCE_CELL_TPL = false; protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); @@ -81,7 +80,7 @@ public void processInstruction(ExecutionContext ec) { } } - EinsumContext einc = getEinsumContext(eqStr, inputs); + EinsumContext einc = EinsumContext.getEinsumContext(eqStr, inputs); this.einc = einc; String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : null; From acaa2aebb2c8c6a1028fe4bb7eea87fbff1fccb3 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sat, 5 Jul 2025 00:44:43 +0200 Subject: [PATCH 19/28] missing licence notice, better more restricting condition for binary outermultpl. change --- .../java/org/apache/sysds/common/Opcodes.java | 1 - .../sysds/hops/codegen/cplan/CNodeBinary.java | 4 ++-- .../einsum/EinsumEquationValidator.java | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index cb15299559a..e60a21fcf3d 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -192,7 +192,6 @@ public enum Opcodes { TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin), TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin), - //Ternary instruction opcodes PM("+*", InstructionType.Ternary), MINUSMULT("-*", InstructionType.Ternary), diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java index c80242f2547..4a7125532fc 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java @@ -188,8 +188,8 @@ public String codegen(boolean sparse, GeneratorAPI api) { (_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) : _inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj); - if(_type == BinType.VECT_OUTERMULT_ADD && (_inputs.get(j) instanceof CNodeData && _inputs.get(j).getDataType().isMatrix()) && - (varj.startsWith("b"))) + if (_type == BinType.VECT_OUTERMULT_ADD && (_inputs.get(j) instanceof CNodeData && _inputs.get(j).getDataType().isMatrix()) && (varj.startsWith("b") + && j > 0 && TemplateUtils.isMatrix(_inputs.get(j-1)))) tmp = tmp.replace("%POS"+(j+1)+"%",varj + ".pos(rix)"); else //replace start position of main input diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java index 7a2661c10c6..5643159ef9a 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.sysds.runtime.einsum; import org.apache.commons.lang3.tuple.Triple; From f694b7da3f96ebddae08619120e080caa7534823 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sat, 5 Jul 2025 01:08:12 +0200 Subject: [PATCH 20/28] better, clearer EinsumContext properties names --- .../sysds/runtime/einsum/EinsumContext.java | 47 ++++++++++--------- .../instructions/cp/EinsumCPInstruction.java | 31 ++++++------ 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java index 723dd2fa297..c336c9610ed 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java @@ -26,31 +26,34 @@ import java.util.HashSet; import java.util.Iterator; + public class EinsumContext { + public enum ContractDimensions { + CONTRACT_LEFT, + CONTRACT_RIGHT, + CONTRACT_BOTH, + } public Integer outRows; public Integer outCols; public Character outChar1; public Character outChar2; - public HashMap charToDimensionSizeInt; + public HashMap charToDimensionSize; public String equationString; public boolean[] diagonalInputs; public HashSet summingChars; public HashSet contractDimsSet; - public static final int CONTRACT_LEFT = 1; - public static final int CONTRACT_RIGHT = 2; - public static final int CONTRACT_BOTH = 3; - public int[] contractDims; - public ArrayList newEquationStringSplit; - public HashMap> partsCharactersToIndices; // for each character, this tells in which inputs it appears + public ContractDimensions[] contractDims; + public ArrayList newEquationStringInputsSplit; + public HashMap> characterAppearanceIndexes; // for each character, this tells in which inputs it appears private EinsumContext(){}; public static EinsumContext getEinsumContext(String eqStr, ArrayList inputs){ EinsumContext res = new EinsumContext(); res.equationString = eqStr; - res.charToDimensionSizeInt = new HashMap(); + res.charToDimensionSize = new HashMap(); HashSet summingChars = new HashSet<>(); - int[] contractDims = new int[inputs.size()]; + ContractDimensions[] contractDims = new ContractDimensions[inputs.size()]; boolean[] diagonalInputs = new boolean[inputs.size()]; // all false by default HashSet contractDimsSet = new HashSet(); HashMap> partsCharactersToIndices = new HashMap<>(); @@ -58,7 +61,7 @@ public static EinsumContext getEinsumContext(String eqStr, ArrayList it = inputs.iterator(); MatrixBlock curArr = it.next(); - int arrSizeIterator=0; + int arrSizeIterator = 0; int arrayIterator = 0; int i; // first iteration through string: collect information on character-size and what characters are summing characters @@ -74,17 +77,17 @@ public static EinsumContext getEinsumContext(String eqStr, ArrayList 0 ? eqStr.charAt(i) : null; Character outChar2 = numOfRemainingChars > 1 ? eqStr.charAt(i+1) : null; - res.outRows=(numOfRemainingChars > 0 ? res.charToDimensionSizeInt.get(outChar1) : 1); - res.outCols=(numOfRemainingChars > 1 ? res.charToDimensionSizeInt.get(outChar2) : 1); + res.outRows=(numOfRemainingChars > 0 ? res.charToDimensionSize.get(outChar1) : 1); + res.outCols=(numOfRemainingChars > 1 ? res.charToDimensionSize.get(outChar2) : 1); arrayIterator=0; // second iteration through string: collect remaining information @@ -128,7 +131,7 @@ else if((outChar1 != null && c == outChar1) || (outChar2 != null && c == outChar } else { contractDimsSet.add(c); - contractDims[arrayIterator] = CONTRACT_LEFT; + contractDims[arrayIterator] = ContractDimensions.CONTRACT_LEFT; } if(i + 1 < eqStr.length()) { // process next character together @@ -139,7 +142,7 @@ else if((outChar1 != null && c == outChar1) || (outChar2 != null && c == outChar if (c2 == c){ diagonalInputs[arrayIterator] = true; - if (contractDims[arrayIterator] == CONTRACT_LEFT) contractDims[arrayIterator] = CONTRACT_BOTH; + if (contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT) contractDims[arrayIterator] = ContractDimensions.CONTRACT_BOTH; } else{ if(summingChars.contains(c2)) { @@ -153,7 +156,7 @@ else if((outChar1 != null && c2 == outChar1) || (outChar2 != null && c2 == outCh } else { contractDimsSet.add(c2); - contractDims[arrayIterator] += CONTRACT_RIGHT; + contractDims[arrayIterator] = contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT ? ContractDimensions.CONTRACT_BOTH : ContractDimensions.CONTRACT_RIGHT; } } } @@ -167,8 +170,8 @@ else if((outChar1 != null && c2 == outChar1) || (outChar2 != null && c2 == outCh res.summingChars = summingChars; res.outChar1 = outChar1; res.outChar2 = outChar2; - res.newEquationStringSplit = newEquationStringSplit; - res.partsCharactersToIndices = partsCharactersToIndices; + res.newEquationStringInputsSplit = newEquationStringSplit; + res.characterAppearanceIndexes = partsCharactersToIndices; return res; } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index e4dd5a9da7b..3aca1b7f21d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -87,9 +87,9 @@ public void processInstruction(ExecutionContext ec) { if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols); - ArrayList inputsChars = einc.newEquationStringSplit; + ArrayList inputsChars = einc.newEquationStringInputsSplit; - if(LOG.isTraceEnabled()) LOG.trace(String.join(",",einc.newEquationStringSplit)); + if(LOG.isTraceEnabled()) LOG.trace(String.join(",",einc.newEquationStringInputsSplit)); contractDimensionsAndComputeDiagonals(einc, inputs); @@ -102,8 +102,8 @@ public void processInstruction(ExecutionContext ec) { } } - if(LOG.isTraceEnabled()) for(Character c : einc.partsCharactersToIndices.keySet()){ - ArrayList a = einc.partsCharactersToIndices.get(c); + if(LOG.isTraceEnabled()) for(Character c : einc.characterAppearanceIndexes.keySet()){ + ArrayList a = einc.characterAppearanceIndexes.get(c); LOG.trace(c+" count= "+a.size()); } @@ -192,7 +192,7 @@ public void processInstruction(ExecutionContext ec) { ec.setMatrixOutput(output.getName(), res); } - else { + else { // if (needToDoCellTemplate) ArrayList mbs = new ArrayList<>(); ArrayList chars = new ArrayList<>(); for (int i = 0; i < inputs.size(); i++) { @@ -203,11 +203,11 @@ public void processInstruction(ExecutionContext ec) { } } ArrayList summingChars = new ArrayList(); - for (Character c : einc.partsCharactersToIndices.keySet()) { + for (Character c : einc.characterAppearanceIndexes.keySet()) { if (c != einc.outChar1 && c != einc.outChar2) summingChars.add(c); } - MatrixBlock res = computeCellSummation(mbs, chars, resultString, einc.charToDimensionSizeInt, summingChars, einc.outRows, einc.outCols); + MatrixBlock res = computeCellSummation(mbs, chars, resultString, einc.charToDimensionSize, summingChars, einc.outRows, einc.outCols); if (einc.outRows == 1 && einc.outCols == 1) ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); @@ -226,28 +226,31 @@ private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); inputs.set(i, inputs.get(i).reorgOperations(op, new MatrixBlock(),0,0,0)); } + if (einc.contractDims[i] == null) continue; switch (einc.contractDims[i]){ - case EinsumContext.CONTRACT_BOTH: { + case CONTRACT_BOTH: { AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); MatrixBlock res = new MatrixBlock(1, 1, false); inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); inputs.set(i, res); break; } - case EinsumContext.CONTRACT_RIGHT: { + case CONTRACT_RIGHT: { AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); MatrixBlock res = new MatrixBlock(inputs.get(i).getNumRows(), 1, false); inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); inputs.set(i, res); break; } - case EinsumContext.CONTRACT_LEFT: { + case CONTRACT_LEFT: { AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); MatrixBlock res = new MatrixBlock(inputs.get(i).getNumColumns(), 1, false); inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); inputs.set(i, res); break; } + default: + break; } } } @@ -263,10 +266,10 @@ private boolean generatePlanAndExecute(ArrayList inputs, EinsumCont boolean anyCouldNotDo; boolean didAnything = false; // maybe multiplication will make it summable do { - anyCouldNotDo = sumCharactersWherePossible(einc.partsCharactersToIndices, inputs, einc.newEquationStringSplit, einc.outChar1, einc.outChar2); + anyCouldNotDo = sumCharactersWherePossible(einc.characterAppearanceIndexes, inputs, einc.newEquationStringInputsSplit, einc.outChar1, einc.outChar2); didAnything = false; - if(einc.newEquationStringSplit.stream().filter(Objects::nonNull).count() > 1) - didAnything = multiplyTerms(einc.partsCharactersToIndices, inputs, einc.newEquationStringSplit, einc.outChar1, einc.outChar2); + if(einc.newEquationStringInputsSplit.stream().filter(Objects::nonNull).count() > 1) + didAnything = multiplyTerms(einc.characterAppearanceIndexes, inputs, einc.newEquationStringInputsSplit, einc.outChar1, einc.outChar2); } while(didAnything); @@ -341,7 +344,7 @@ private boolean multiplyTerms(HashMap> partsCharac // returns true if left with summation with more than 2 inputs private boolean sumCharactersWherePossible(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2) { - boolean anyCouldNotDo = false; + boolean anyCouldNotDo; while (true) { List toSum = null; From f622d4b0584057a15e2b5fe30d662042dbcc2cfe Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Mon, 21 Jul 2025 01:29:20 +0200 Subject: [PATCH 21/28] aB_a contraction works, by transposing matrix, no longer change to codegen is needed, elementwise mutliplication between matrix and vector works --- .../sysds/hops/codegen/cplan/CNodeBinary.java | 18 +- .../instructions/cp/EinsumCPInstruction.java | 193 +++++++++++++++--- .../test/functions/einsum/EinsumTest.java | 11 +- src/test/scripts/functions/einsum/einsum18.R | 38 ++++ .../scripts/functions/einsum/einsum18.dml | 33 +++ src/test/scripts/functions/einsum/einsum19.R | 33 +++ .../scripts/functions/einsum/einsum19.dml | 29 +++ src/test/scripts/functions/einsum/einsum20.R | 33 +++ .../scripts/functions/einsum/einsum20.dml | 29 +++ 9 files changed, 377 insertions(+), 40 deletions(-) create mode 100644 src/test/scripts/functions/einsum/einsum18.R create mode 100644 src/test/scripts/functions/einsum/einsum18.dml create mode 100644 src/test/scripts/functions/einsum/einsum19.R create mode 100644 src/test/scripts/functions/einsum/einsum19.dml create mode 100644 src/test/scripts/functions/einsum/einsum20.R create mode 100644 src/test/scripts/functions/einsum/einsum20.dml diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java index 4a7125532fc..66b6122c475 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java @@ -188,16 +188,16 @@ public String codegen(boolean sparse, GeneratorAPI api) { (_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) : _inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj); - if (_type == BinType.VECT_OUTERMULT_ADD && (_inputs.get(j) instanceof CNodeData && _inputs.get(j).getDataType().isMatrix()) && (varj.startsWith("b") - && j > 0 && TemplateUtils.isMatrix(_inputs.get(j-1)))) - tmp = tmp.replace("%POS"+(j+1)+"%",varj + ".pos(rix)"); - else +// if (_type == BinType.VECT_OUTERMULT_ADD && (_inputs.get(j) instanceof CNodeData && _inputs.get(j).getDataType().isMatrix()) && (varj.startsWith("b") +// && j > 0 && TemplateUtils.isMatrix(_inputs.get(j-1)))) +// tmp = tmp.replace("%POS"+(j+1)+"%",varj + ".pos(rix)"); +// else //replace start position of main input - tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData - && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : - ((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise() - && TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ? - varj + ".pos(rix)" : "0" : "0"); + tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData + && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : + ((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise() + && TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ? + varj + ".pos(rix)" : "0" : "0"); } //replace length information (e.g., after matrix mult) if( _type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector) ) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 3aca1b7f21d..e3addee50a7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -19,7 +19,6 @@ package org.apache.sysds.runtime.instructions.cp; -import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -95,11 +94,7 @@ public void processInstruction(ExecutionContext ec) { //make all vetors col vectors for(int i = 0; i < inputs.size(); i++){ - if(inputs.get(i) != null && inputsChars.get(i).length() == 1 && inputs.get(i).getNumColumns() > 1){ - inputs.get(i).setNumRows(inputs.get(i).getNumColumns()); - inputs.get(i).setNumColumns(1); - inputs.get(i).getDenseBlock().resetNoFill(inputs.get(i).getNumColumns(),1); - } + if(inputs.get(i) != null && inputsChars.get(i).length() == 1) EnsureMatrixBlockColumnVector(inputs.get(i)); } if(LOG.isTraceEnabled()) for(Character c : einc.characterAppearanceIndexes.keySet()){ @@ -206,6 +201,7 @@ public void processInstruction(ExecutionContext ec) { for (Character c : einc.characterAppearanceIndexes.keySet()) { if (c != einc.outChar1 && c != einc.outChar2) summingChars.add(c); } + if(LOG.isTraceEnabled()) LOG.trace("finishing with cell tpl: "+String.join(",", chars)); MatrixBlock res = computeCellSummation(mbs, chars, resultString, einc.charToDimensionSize, summingChars, einc.outRows, einc.outCols); @@ -214,6 +210,8 @@ public void processInstruction(ExecutionContext ec) { else ec.setMatrixOutput(output.getName(), res); } + if(LOG.isTraceEnabled()) LOG.trace("EinsumCPInstruction Finished"); + releaseMatrixInputs(ec); } @@ -273,36 +271,112 @@ private boolean generatePlanAndExecute(ArrayList inputs, EinsumCont } while(didAnything); + if(LOG.isTraceEnabled()) LOG.trace("generatePlanAndExecute() finished"); + return anyCouldNotDo; } + private static void EnsureMatrixBlockColumnVector(MatrixBlock mb){ + if(mb.getNumColumns() > 1){ + mb.setNumRows(mb.getNumColumns()); + mb.setNumColumns(1); + mb.getDenseBlock().resetNoFill(mb.getNumRows(),1); + } + } + private static void EnsureMatrixBlockRowVector(MatrixBlock mb){ + if(mb.getNumRows() > 1){ + mb.setNumColumns(mb.getNumRows()); + mb.setNumRows(1); + mb.getDenseBlock().resetNoFill(1,mb.getNumColumns()); + } + } + /* handle situation: ji,ji or ij,ji, j,j */ private boolean multiplyTerms(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2 ) { - HashMap> stringToIndex = new HashMap<>(); + HashMap> matrixStringToIndex = new HashMap<>(); + HashMap> vectorStringToIndex = new HashMap<>(); for(int i = 0; i < inputsChars.size(); i++){ String s = inputsChars.get(i); if(s==null) continue; - if (stringToIndex.containsKey(s)) stringToIndex.get(s).add(i); - else { ArrayList list = new ArrayList<>(); list.add(i); stringToIndex.put(s, list); } + if(s.length() == 1){ + if (vectorStringToIndex.containsKey(s)) vectorStringToIndex.get(s).add(i); + else { LinkedList list = new LinkedList<>(); list.add(i); vectorStringToIndex.put(s, list); } + }else{ + if (matrixStringToIndex.containsKey(s)) matrixStringToIndex.get(s).add(i); + else { LinkedList list = new LinkedList<>(); list.add(i); matrixStringToIndex.put(s, list); } + } } boolean doneAnything = false; - for(var s : stringToIndex.keySet()){ - if(!stringToIndex.containsKey(s)) continue; // entries can be removed + // first do the vector-vector multipl. + for(var s : vectorStringToIndex.keySet()){ + if(vectorStringToIndex.get(s).size() <= 1) continue; + + doneAnything = true; + + Integer vectorIdx0 = vectorStringToIndex.get(s).removeFirst(); + + if(LOG.isTraceEnabled()){ + StringBuilder sb = new StringBuilder(); + + LOG.trace("Element wise multiplying: "+s); + } + MatrixBlock mb = inputs.get(vectorIdx0); + EnsureMatrixBlockColumnVector(mb); + + do{ // do all vectors with same char + Integer vectorIdx1 = vectorStringToIndex.get(s).removeFirst(); + + ArrayList mbs = new ArrayList<>(); + + mbs.add(mb); + mbs.add(inputs.get(vectorIdx1)); + EnsureMatrixBlockColumnVector(mbs.get(1)); + + mb = getCodegenElemwiseMult(mbs); + + inputs.set(vectorIdx1, null); + inputsChars.set(vectorIdx1, null); + } while (vectorStringToIndex.get(s).size() > 1); + + inputs.set(vectorIdx0, null); + inputsChars.set(vectorIdx0, null); + + inputs.add(mb); + inputsChars.add(s); + if (partsCharactersToIndices.containsKey(s.charAt(0))) partsCharactersToIndices.get(s.charAt(0)).add(inputs.size() - 1); + vectorStringToIndex.get(s).add(inputs.size() - 1); + } + + for(var s : matrixStringToIndex.keySet()){ + if(!matrixStringToIndex.containsKey(s)) continue; // entries can be removed String sT = s.length() == 2 ? String.valueOf(s.charAt(1)) + s.charAt(0) : null; - ArrayList idxs = stringToIndex.get(s); - ArrayList idxsT = sT != null ? stringToIndex.containsKey(sT) ? stringToIndex.get(sT) : null : null; + LinkedList idxs = matrixStringToIndex.get(s); + LinkedList idxsT = sT != null ? matrixStringToIndex.containsKey(sT) ? matrixStringToIndex.get(sT) : null : null; + + Integer vectorIdx0 = null; + Integer vectorIdx1 = null; + String char0 = String.valueOf(s.charAt(0)); + if(vectorStringToIndex.containsKey(char0)){ + // ab,a-> ab + vectorIdx0 = vectorStringToIndex.get(char0).removeFirst(); // only one should left + } + + String char1 = String.valueOf(s.charAt(1)); + if(vectorStringToIndex.containsKey(char1)){ + // ab,b -> ab + vectorIdx1 = vectorStringToIndex.get(char1).removeFirst(); // only one should left + } - if(idxs.size() <= 1 && idxsT == null) continue; + if(idxs.size() <= 1 && idxsT == null && vectorIdx0 == null && vectorIdx1 == null) continue; doneAnything = true; - // do decision if transpose idxs or idxsT: right now just alway transpose second - ArrayList mbs = new ArrayList<>(); + // do decision if transpose idxs or idxsT: right now just always transpose second if(LOG.isTraceEnabled()){ StringBuilder sb = new StringBuilder(); for(Integer idx : idxs){ @@ -313,32 +387,83 @@ private boolean multiplyTerms(HashMap> partsCharac sb.append(inputsChars.get(idx)); sb.append(","); } + if(vectorIdx0 != null) { sb.append(s.charAt(0)).append(","); } + if(vectorIdx1 != null) { sb.append(s.charAt(1)).append(","); } + LOG.trace("Element wise multiplying: "+sb.toString()); } - for(Integer idx : idxs){ - mbs.add(inputs.get(idx)); - inputs.set(idx, null); - inputsChars.set(idx, null); + Integer matrixIdx0 = idxs.removeFirst(); + + MatrixBlock mb = inputs.get(matrixIdx0); + + inputs.set(matrixIdx0, null); + inputsChars.set(matrixIdx0, null); + + while (!idxs.isEmpty()){ + Integer matrixIdx1 = idxs.removeFirst(); + + ArrayList mbs = new ArrayList<>(); + mbs.add(mb); + mbs.add(inputs.get(matrixIdx1)); + + mb = getCodegenElemwiseMult(mbs); + + inputs.set(matrixIdx1, null); + inputsChars.set(matrixIdx1, null); } - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - if(idxsT != null) for(Integer idx : idxsT){ - mbs.add(inputs.get(idx).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0)); - inputs.set(idx, null); - inputsChars.set(idx, null); + + if(idxsT != null) while(!idxsT.isEmpty()) { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + + Integer matrixIdx1 = idxsT.removeFirst(); + + ArrayList mbs = new ArrayList<>(); + mbs.add(mb); + mbs.add(inputs.get(matrixIdx1).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0)); + + mb = getCodegenElemwiseMult(mbs); + + inputs.set(matrixIdx1, null); + inputsChars.set(matrixIdx1, null); } - MatrixBlock mb = getCodegenElemwiseMult(mbs); + if(vectorIdx1 != null){ // ab,b->ab + EnsureMatrixBlockRowVector(inputs.get(vectorIdx1)); + + mb = getRowCodegenMatrixBlock(mb, inputs.get(vectorIdx1), CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG, null); + + inputs.set(vectorIdx1, null); + inputsChars.set(vectorIdx1, null); + } + + if(vectorIdx0 != null){ // ab,a->ab + EnsureMatrixBlockRowVector(inputs.get(vectorIdx0)); + +// mb = getRowCodegenMatrixBlock(mb, inputs.get(vectorIdx0), CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG,null); +// mb = getRowCodegenMatrixBlock(mb, inputs.get(vectorIdx0), CNodeBinary.BinType.VECT_MULT_SCALAR, SpoofRowwise.RowType.NO_AGG, Long.valueOf( inputs.get(vectorIdx0).getNumColumns())); + { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + mb = mb.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + mb = getRowCodegenMatrixBlock(mb, inputs.get(vectorIdx0), CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG, null); + } + + inputs.set(vectorIdx0, null); + inputsChars.set(vectorIdx0, null); + s = String.valueOf(s.charAt(1))+String.valueOf(s.charAt(0)); + } inputs.add(mb); inputsChars.add(s); for (int i = 0; i < s.length(); i++) { // for each char in string, add pointer to newly created entry char c = s.charAt(i); - partsCharactersToIndices.get(c).add(inputs.size() - 1); + if (partsCharactersToIndices.containsKey(c)) partsCharactersToIndices.get(c).add(inputs.size() - 1); } - if(idxsT != null) stringToIndex.remove(sT); + if(idxsT != null) matrixStringToIndex.remove(sT); } + + return doneAnything; } @@ -500,10 +625,18 @@ else if(s1.charAt(1) == s2.charAt(0)){ } MatrixBlock out; + if(LOG.isTraceEnabled()) LOG.trace("remaining: "+String.join(",",inputsChars.stream().filter(Objects::nonNull).toList())); if(LOG.isTraceEnabled()) LOG.trace("Summing: "+s1+","+s2+"->"+resS); switch (sumOp) { + case aB_a:{ + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + first = first.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG, null); + break; + } case Ba_a: - throw new NotImplementedException(); + out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG, null); + break; case Ba_aC: { out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf( second.getNumColumns())); break; @@ -514,7 +647,7 @@ else if(s1.charAt(1) == s2.charAt(0)){ out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf(second.getNumColumns())); break; } - case aB_a: +// case aB_a: case aB_aC: { out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_OUTERMULT_ADD, SpoofRowwise.RowType.COL_AGG_B1_T, Long.valueOf( second.getNumColumns())); break; diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 4f755f8b560..90660f9f0b1 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -55,6 +55,9 @@ public class EinsumTest extends AutomatedTestBase private static final String TEST_EINSUM15 = TEST_NAME_EINSUM+"15"; private static final String TEST_EINSUM16 = TEST_NAME_EINSUM+"16"; private static final String TEST_EINSUM17 = TEST_NAME_EINSUM+"17"; + private static final String TEST_EINSUM18 = TEST_NAME_EINSUM+"18"; + private static final String TEST_EINSUM19 = TEST_NAME_EINSUM+"19"; + private static final String TEST_EINSUM20 = TEST_NAME_EINSUM+"20"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; private final static String TEST_CONF = "SystemDS-config-codegen.xml"; @@ -65,7 +68,7 @@ public class EinsumTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=17; i++) + for(int i=1; i<=20; i++) addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } @Test @@ -126,6 +129,12 @@ public void testCodegenEinsum13CP() { public void testCodegenEinsum16CP() { testCodegenIntegration( TEST_EINSUM16, ExecType.CP); } @Test public void testCodegenEinsum17CP() { testCodegenIntegration( TEST_EINSUM17, ExecType.CP); } + @Test + public void testCodegenEinsum18CP() { testCodegenIntegration( TEST_EINSUM18, ExecType.CP); } + @Test + public void testCodegenEinsum19CP() { testCodegenIntegration( TEST_EINSUM19, ExecType.CP); } + @Test + public void testCodegenEinsum20CP() { testCodegenIntegration( TEST_EINSUM20, ExecType.CP); } private void testCodegenIntegration( String testname, ExecType instType) { testCodegenIntegration(testname, instType, false); } private void testCodegenIntegration( String testname, ExecType instType, boolean outputScalar ) { diff --git a/src/test/scripts/functions/einsum/einsum18.R b/src/test/scripts/functions/einsum/einsum18.R new file mode 100644 index 00000000000..7fae3a7e621 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum18.R @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +A0 = matrix(seq(1,500), 10, 50, byrow=TRUE) * 0.0001 +A1 = matrix(seq(1,800), 10, 80, byrow=TRUE) * 0.0001 +A2 = matrix(seq(1,30), 10, 3, byrow=TRUE) * 0.0001 +A3 = matrix(seq(1,4000), 50, 80, byrow=TRUE) * 0.0001 +A4 = matrix(seq(1,120), 3, 40, byrow=TRUE) * 0.0001 +A5 = seq(1,40) * 0.0001 +A6 = seq(1,3) * 0.0001 + +R = einsum("fx,fg,fz,xg,pq,q,p->zp", A0, A1, A2, A3, A4, A5, A6) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum18.dml b/src/test/scripts/functions/einsum/einsum18.dml new file mode 100644 index 00000000000..7748e2b91e2 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum18.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A0 = matrix(seq(1,500), 10, 50) * 0.0001 +A1 = matrix(seq(1,800), 10, 80) * 0.0001 +A2 = matrix(seq(1,30), 10, 3) * 0.0001 +A3 = matrix(seq(1,4000), 50, 80) * 0.0001 +A4 = matrix(seq(1,120), 3, 40) * 0.0001 +A5 = seq(1,40) * 0.0001 +A6 = seq(1,3) * 0.0001 + +while(FALSE){} + +R = einsum("fx,fg,fz,xg,pq,q,p->zp", A0, A1, A2, A3, A4, A5, A6) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum19.R b/src/test/scripts/functions/einsum/einsum19.R new file mode 100644 index 00000000000..fced04caa92 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum19.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +A0 = matrix(seq(1,50000), 1000, 50, byrow=TRUE) * 0.0001 +A1 = seq(1,50) * 0.0001 + +R = einsum("ij,j->ij", A0, A1) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum19.dml b/src/test/scripts/functions/einsum/einsum19.dml new file mode 100644 index 00000000000..c5743714a26 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum19.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +A0 = matrix(seq(1,50000), 1000, 50) * 0.0001 +A1 = seq(1,50) * 0.0001 + +while(FALSE){} + +R = einsum("ij,j->ij", A0, A1) +write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum20.R b/src/test/scripts/functions/einsum/einsum20.R new file mode 100644 index 00000000000..2b3c556bda3 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum20.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") +library("einsum") + +A0 = matrix(seq(1,50000), 1000, 50, byrow=TRUE) * 0.0001 +A1 = seq(1,1000) * 0.0001 + +R = einsum("ij,i->ij", A0, A1) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum20.dml b/src/test/scripts/functions/einsum/einsum20.dml new file mode 100644 index 00000000000..aeb591beff9 --- /dev/null +++ b/src/test/scripts/functions/einsum/einsum20.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +A0 = matrix(seq(1,50000), 1000, 50) * 0.0001 +A1 = seq(1,1000) * 0.0001 + +while(FALSE){} + +R = einsum("ij,i->ij", A0, A1) +write(R, $1) From 1e5e0fae9fd6f28783d96df829a9ef384e2b5a45 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Fri, 25 Jul 2025 11:42:38 +0200 Subject: [PATCH 22/28] remove changes to CNodeBinary --- .../apache/sysds/hops/codegen/cplan/CNodeBinary.java | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java index 66b6122c475..8cca98dfef2 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java @@ -188,16 +188,12 @@ public String codegen(boolean sparse, GeneratorAPI api) { (_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) : _inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj); -// if (_type == BinType.VECT_OUTERMULT_ADD && (_inputs.get(j) instanceof CNodeData && _inputs.get(j).getDataType().isMatrix()) && (varj.startsWith("b") -// && j > 0 && TemplateUtils.isMatrix(_inputs.get(j-1)))) -// tmp = tmp.replace("%POS"+(j+1)+"%",varj + ".pos(rix)"); -// else //replace start position of main input tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData - && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : - ((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise() - && TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ? - varj + ".pos(rix)" : "0" : "0"); + && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : + ((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise() + && TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ? + varj + ".pos(rix)" : "0" : "0"); } //replace length information (e.g., after matrix mult) if( _type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector) ) { From 63683c9f308b97309946383fe24a74a7fe789480 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Fri, 25 Jul 2025 11:46:55 +0200 Subject: [PATCH 23/28] remove changes to CNodeBinary 2 --- .../apache/sysds/hops/codegen/cplan/CNodeBinary.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java index 8cca98dfef2..b29d586c38a 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java @@ -187,13 +187,13 @@ public String codegen(boolean sparse, GeneratorAPI api) { varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj + ".values(rix)" : (_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) : _inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj); - + //replace start position of main input - tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData - && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : + tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData + && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : ((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise() - && TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ? - varj + ".pos(rix)" : "0" : "0"); + && TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ? + varj + ".pos(rix)" : "0" : "0"); } //replace length information (e.g., after matrix mult) if( _type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector) ) { From ae4bb82334180e3883dd0b2bce787bd25063ea13 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Fri, 25 Jul 2025 11:54:08 +0200 Subject: [PATCH 24/28] reduce size of einsum test18 --- src/test/scripts/functions/einsum/einsum18.R | 10 +++++----- src/test/scripts/functions/einsum/einsum18.dml | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/test/scripts/functions/einsum/einsum18.R b/src/test/scripts/functions/einsum/einsum18.R index 7fae3a7e621..fa25980405b 100644 --- a/src/test/scripts/functions/einsum/einsum18.R +++ b/src/test/scripts/functions/einsum/einsum18.R @@ -25,12 +25,12 @@ library("Matrix") library("matrixStats") library("einsum") -A0 = matrix(seq(1,500), 10, 50, byrow=TRUE) * 0.0001 -A1 = matrix(seq(1,800), 10, 80, byrow=TRUE) * 0.0001 +A0 = matrix(seq(1,250), 10, 25, byrow=TRUE) * 0.0001 +A1 = matrix(seq(1,200), 10, 20, byrow=TRUE) * 0.0001 A2 = matrix(seq(1,30), 10, 3, byrow=TRUE) * 0.0001 -A3 = matrix(seq(1,4000), 50, 80, byrow=TRUE) * 0.0001 -A4 = matrix(seq(1,120), 3, 40, byrow=TRUE) * 0.0001 -A5 = seq(1,40) * 0.0001 +A3 = matrix(seq(1,500), 25, 20, byrow=TRUE) * 0.0001 +A4 = matrix(seq(1,33), 3, 11, byrow=TRUE) * 0.0001 +A5 = seq(1,11) * 0.0001 A6 = seq(1,3) * 0.0001 R = einsum("fx,fg,fz,xg,pq,q,p->zp", A0, A1, A2, A3, A4, A5, A6) diff --git a/src/test/scripts/functions/einsum/einsum18.dml b/src/test/scripts/functions/einsum/einsum18.dml index 7748e2b91e2..c71654bf01d 100644 --- a/src/test/scripts/functions/einsum/einsum18.dml +++ b/src/test/scripts/functions/einsum/einsum18.dml @@ -19,12 +19,12 @@ # #------------------------------------------------------------- -A0 = matrix(seq(1,500), 10, 50) * 0.0001 -A1 = matrix(seq(1,800), 10, 80) * 0.0001 +A0 = matrix(seq(1,250), 10, 25) * 0.0001 +A1 = matrix(seq(1,200), 10, 20) * 0.0001 A2 = matrix(seq(1,30), 10, 3) * 0.0001 -A3 = matrix(seq(1,4000), 50, 80) * 0.0001 -A4 = matrix(seq(1,120), 3, 40) * 0.0001 -A5 = seq(1,40) * 0.0001 +A3 = matrix(seq(1,500), 25, 20) * 0.0001 +A4 = matrix(seq(1,33), 3, 11) * 0.0001 +A5 = seq(1,11) * 0.0001 A6 = seq(1,3) * 0.0001 while(FALSE){} From 7eb950c46130fe45917d9cc7f200f0c366476982 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sun, 3 Aug 2025 22:56:11 +0200 Subject: [PATCH 25/28] create einsum test files from configuration list, fixes in code --- .../instructions/cp/EinsumCPInstruction.java | 15 +- .../test/functions/einsum/EinsumTest.java | 344 +++++++++++++----- src/test/scripts/functions/einsum/einsum1.R | 34 -- src/test/scripts/functions/einsum/einsum1.dml | 30 -- src/test/scripts/functions/einsum/einsum10.R | 35 -- .../scripts/functions/einsum/einsum10.dml | 31 -- src/test/scripts/functions/einsum/einsum11.R | 38 -- .../scripts/functions/einsum/einsum11.dml | 33 -- src/test/scripts/functions/einsum/einsum12.R | 35 -- .../scripts/functions/einsum/einsum12.dml | 31 -- src/test/scripts/functions/einsum/einsum13.R | 33 -- .../scripts/functions/einsum/einsum13.dml | 28 -- src/test/scripts/functions/einsum/einsum14.R | 34 -- .../scripts/functions/einsum/einsum14.dml | 28 -- src/test/scripts/functions/einsum/einsum15.R | 33 -- .../scripts/functions/einsum/einsum15.dml | 28 -- src/test/scripts/functions/einsum/einsum16.R | 32 -- .../scripts/functions/einsum/einsum16.dml | 27 -- src/test/scripts/functions/einsum/einsum17.R | 33 -- .../scripts/functions/einsum/einsum17.dml | 27 -- src/test/scripts/functions/einsum/einsum18.R | 38 -- .../scripts/functions/einsum/einsum18.dml | 33 -- src/test/scripts/functions/einsum/einsum19.R | 33 -- .../scripts/functions/einsum/einsum19.dml | 29 -- src/test/scripts/functions/einsum/einsum2.R | 34 -- src/test/scripts/functions/einsum/einsum2.dml | 30 -- src/test/scripts/functions/einsum/einsum20.R | 33 -- .../scripts/functions/einsum/einsum20.dml | 29 -- src/test/scripts/functions/einsum/einsum3.R | 34 -- src/test/scripts/functions/einsum/einsum3.dml | 30 -- src/test/scripts/functions/einsum/einsum4.R | 33 -- src/test/scripts/functions/einsum/einsum4.dml | 28 -- src/test/scripts/functions/einsum/einsum5.R | 34 -- src/test/scripts/functions/einsum/einsum5.dml | 29 -- src/test/scripts/functions/einsum/einsum6.R | 34 -- src/test/scripts/functions/einsum/einsum6.dml | 30 -- src/test/scripts/functions/einsum/einsum7.R | 35 -- src/test/scripts/functions/einsum/einsum7.dml | 28 -- src/test/scripts/functions/einsum/einsum8.R | 34 -- src/test/scripts/functions/einsum/einsum8.dml | 30 -- src/test/scripts/functions/einsum/einsum9.R | 34 -- src/test/scripts/functions/einsum/einsum9.dml | 30 -- 42 files changed, 264 insertions(+), 1367 deletions(-) delete mode 100644 src/test/scripts/functions/einsum/einsum1.R delete mode 100644 src/test/scripts/functions/einsum/einsum1.dml delete mode 100644 src/test/scripts/functions/einsum/einsum10.R delete mode 100644 src/test/scripts/functions/einsum/einsum10.dml delete mode 100644 src/test/scripts/functions/einsum/einsum11.R delete mode 100644 src/test/scripts/functions/einsum/einsum11.dml delete mode 100644 src/test/scripts/functions/einsum/einsum12.R delete mode 100644 src/test/scripts/functions/einsum/einsum12.dml delete mode 100644 src/test/scripts/functions/einsum/einsum13.R delete mode 100644 src/test/scripts/functions/einsum/einsum13.dml delete mode 100644 src/test/scripts/functions/einsum/einsum14.R delete mode 100644 src/test/scripts/functions/einsum/einsum14.dml delete mode 100644 src/test/scripts/functions/einsum/einsum15.R delete mode 100644 src/test/scripts/functions/einsum/einsum15.dml delete mode 100644 src/test/scripts/functions/einsum/einsum16.R delete mode 100644 src/test/scripts/functions/einsum/einsum16.dml delete mode 100644 src/test/scripts/functions/einsum/einsum17.R delete mode 100644 src/test/scripts/functions/einsum/einsum17.dml delete mode 100644 src/test/scripts/functions/einsum/einsum18.R delete mode 100644 src/test/scripts/functions/einsum/einsum18.dml delete mode 100644 src/test/scripts/functions/einsum/einsum19.R delete mode 100644 src/test/scripts/functions/einsum/einsum19.dml delete mode 100644 src/test/scripts/functions/einsum/einsum2.R delete mode 100644 src/test/scripts/functions/einsum/einsum2.dml delete mode 100644 src/test/scripts/functions/einsum/einsum20.R delete mode 100644 src/test/scripts/functions/einsum/einsum20.dml delete mode 100644 src/test/scripts/functions/einsum/einsum3.R delete mode 100644 src/test/scripts/functions/einsum/einsum3.dml delete mode 100644 src/test/scripts/functions/einsum/einsum4.R delete mode 100644 src/test/scripts/functions/einsum/einsum4.dml delete mode 100644 src/test/scripts/functions/einsum/einsum5.R delete mode 100644 src/test/scripts/functions/einsum/einsum5.dml delete mode 100644 src/test/scripts/functions/einsum/einsum6.R delete mode 100644 src/test/scripts/functions/einsum/einsum6.dml delete mode 100644 src/test/scripts/functions/einsum/einsum7.R delete mode 100644 src/test/scripts/functions/einsum/einsum7.dml delete mode 100644 src/test/scripts/functions/einsum/einsum8.R delete mode 100644 src/test/scripts/functions/einsum/einsum8.dml delete mode 100644 src/test/scripts/functions/einsum/einsum9.R delete mode 100644 src/test/scripts/functions/einsum/einsum9.dml diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index e3addee50a7..80002084c63 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -82,7 +82,7 @@ public void processInstruction(ExecutionContext ec) { EinsumContext einc = EinsumContext.getEinsumContext(eqStr, inputs); this.einc = einc; - String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : null; + String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : ""; if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols); @@ -184,7 +184,9 @@ public void processInstruction(ExecutionContext ec) { throw new RuntimeException("Einsum runtime error, reductions and multiplications finished but the did not produce one result"); // should not happen } } - ec.setMatrixOutput(output.getName(), res); + if (einc.outRows == 1 && einc.outCols == 1) + ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); + else ec.setMatrixOutput(output.getName(), res); } else { // if (needToDoCellTemplate) @@ -294,6 +296,7 @@ private static void EnsureMatrixBlockRowVector(MatrixBlock mb){ /* handle situation: ji,ji or ij,ji, j,j */ private boolean multiplyTerms(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2 ) { HashMap> matrixStringToIndex = new HashMap<>(); + HashSet matrixStringToIndexSkip = new HashSet<>(); HashMap> vectorStringToIndex = new HashMap<>(); for(int i = 0; i < inputsChars.size(); i++){ @@ -352,7 +355,7 @@ private boolean multiplyTerms(HashMap> partsCharac } for(var s : matrixStringToIndex.keySet()){ - if(!matrixStringToIndex.containsKey(s)) continue; // entries can be removed + if(matrixStringToIndexSkip.contains(s)) continue; String sT = s.length() == 2 ? String.valueOf(s.charAt(1)) + s.charAt(0) : null; LinkedList idxs = matrixStringToIndex.get(s); @@ -459,7 +462,7 @@ private boolean multiplyTerms(HashMap> partsCharac if (partsCharactersToIndices.containsKey(c)) partsCharactersToIndices.get(c).add(inputs.size() - 1); } - if(idxsT != null) matrixStringToIndex.remove(sT); + if(idxsT != null) matrixStringToIndexSkip.add(sT); } @@ -542,9 +545,11 @@ private Pair computeRowSummation(List toSum, Array String resS; SumOperation sumOp; - if(s1.length()==1 && s2.length() == 1){ //remove never happening here + if(s1.length()==1 && s2.length() == 1){ sumOp = SumOperation.a_a; resS = ""; + first = inputs.get(toSum.get(0)); + second = inputs.get(toSum.get(1)); } else if(s2.length() == 1 || s1.length() == 1){ if(s1.length() == 1){ diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 90660f9f0b1..97166a8cbef 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -28,36 +28,261 @@ import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.junit.After; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; +import java.util.List; +@RunWith(Parameterized.class) public class EinsumTest extends AutomatedTestBase { + final private static List TEST_CONFIGS = List.of( + new Config("ij,jk->ik", List.of(shape(5, 600), shape(600, 10))), // mm + new Config("ji,jk->ik", List.of(shape(600, 5), shape(600, 10))), + new Config("ji,kj->ik", List.of(shape(600, 5), shape(10, 600))), + new Config("ij,kj->ik", List.of(shape(5, 600), shape(10, 600))), + + new Config("ji,jk->i", List.of(shape(600, 5), shape(600, 10))), + new Config("ij,jk->i", List.of(shape(5, 600), shape(600, 10))), + + new Config("ji,jk->k", List.of(shape(600, 5), shape(600, 10))), + new Config("ij,jk->k", List.of(shape(5, 600), shape(600, 10))), + + new Config("ji,jk->j", List.of(shape(600, 5), shape(600, 10))), + + new Config("ji,ji->ji", List.of(shape(600, 5), shape(600, 5))), // elemwise mult + new Config("ji,ij->ji", List.of(shape(600, 5), shape(5, 600))), // elemwise mult + + + new Config("ij,i->ij", List.of(shape(1000, 50), shape(1000))), // col mult + new Config("ji,i->ij", List.of(shape(50, 1000), shape(1000))), // row mult + new Config("ij,i->i", List.of(shape(1000, 50), shape(1000))), + new Config("ij,i->j", List.of(shape(1000, 50), shape(1000))), + + new Config("i,i->", List.of(shape(500), shape(500))), + new Config("i,j->", List.of(shape(500), shape(800))), + new Config("i,j->ij", List.of(shape(500), shape(800))), // outer vect mult + new Config("i,j->ji", List.of(shape(500), shape(800))), // outer vect mult + + new Config("ij->", List.of(shape(1000, 50))), // sum + new Config("ij->i", List.of(shape(1000, 50))), // sum(1) + new Config("ij->j", List.of(shape(1000, 50))), // sum(0) + new Config("ij->ji", List.of(shape(1000, 50))), // T + + new Config("ab,cd->ba", List.of(shape( 600, 10), shape(6, 5))), + + new Config("ab,bc,cd,de->ae", List.of(shape(5, 600), shape(600, 10), shape(10, 5), shape(5, 4))), // chain of mm + + new Config("ji,jz,zx->ix", List.of(shape(600, 5), shape( 600, 10), shape(10, 2))), + new Config("fx,fg,fz,xg->z", List.of(shape(600, 5), shape( 600, 10), shape(600, 6), shape(5, 10))), + new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl) + List.of(shape(5, 60), shape(5, 30), shape(5, 100), shape(60, 30), shape(100, 60), shape(100, 30))), + + new Config("i->", List.of(shape(1000))), + new Config("i->i", List.of(shape(1000))) + ); + + private final int id; + private final String einsumStr; + private final List shapes; + private final File dmlFile; + private final File rFile; + private final boolean outputScalar; + + public EinsumTest(String einsumStr, List shapes, File dmlFile, File rFile, boolean outputScalar, int id){ + this.id = id; + this.einsumStr = einsumStr; + this.shapes = shapes; + this.dmlFile = dmlFile; + this.rFile = rFile; + this.outputScalar = outputScalar; + } + + @Parameterized.Parameters(name = "{index}: einsum={0}") + public static Collection data() throws IOException { + List parameters = new ArrayList<>(); + + int counter = 1; + + for (Config config : TEST_CONFIGS) { + List files = new ArrayList<>(); + String fullDMLScriptName = "SystemDS_einsum_test" + counter; + + File dmlFile = File.createTempFile(fullDMLScriptName, ".dml"); + dmlFile.deleteOnExit(); + + boolean outputScalar = config.einsumStr.trim().endsWith("->"); + + StringBuilder sb = createDmlFile(config, outputScalar); + + Files.writeString(dmlFile.toPath(), sb.toString()); + + File rFile = File.createTempFile(fullDMLScriptName, ".R"); + rFile.deleteOnExit(); + + sb = createRFile(config, outputScalar); + + Files.writeString(rFile.toPath(), sb.toString()); + + parameters.add(new Object[]{config.einsumStr, config.shapes, dmlFile, rFile, outputScalar, counter}); + + counter++; + } + + return parameters; + } + + private static StringBuilder createDmlFile(Config config, boolean outputScalar) { + StringBuilder sb = new StringBuilder(); + + for (int i = 0; i < config.shapes.size(); i++) { + int[] dims = config.shapes.get(i); + + double factor = 0.0001; + sb.append("A"); + sb.append(i); + + if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001 + sb.append(" = seq(1,"); + sb.append(dims[0]); + sb.append(") * "); + sb.append(factor); + } else { // A0 = matrix(seq(1,50000), 1000, 50) * 0.0001 + sb.append(" = matrix(seq(1, "); + sb.append(dims[0]*dims[1]); + sb.append("), "); + sb.append(dims[0]); + sb.append(", "); + sb.append(dims[1]); + + sb.append(") * "); + sb.append(factor); + } + sb.append("\n"); + } + sb.append("\n"); + + sb.append("R = einsum(\""); + sb.append(config.einsumStr); + sb.append("\", "); + + for (int i = 0; i < config.shapes.size()-1; i++) { + sb.append("A"); + sb.append(i); + sb.append(", "); + } + sb.append("A"); + sb.append(config.shapes.size()-1); + sb.append(")"); + + sb.append("\n\n"); + sb.append("write(R, $1)\n"); + return sb; + } + + private static StringBuilder createRFile(Config config, boolean outputScalar) { + StringBuilder sb = new StringBuilder(); + sb.append("args<-commandArgs(TRUE)\n"); + sb.append("options(digits=22)\n"); + sb.append("library(\"Matrix\")\n"); + sb.append("library(\"matrixStats\")\n"); + sb.append("library(\"einsum\")\n\n"); + + + for (int i = 0; i < config.shapes.size(); i++) { + int[] dims = config.shapes.get(i); + + double factor = 0.0001; + sb.append("A"); + sb.append(i); + + if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001 + sb.append(" = seq(1,"); + sb.append(dims[0]); + sb.append(") * "); + sb.append(factor); + } else { // A0 = matrix(seq(1,50000), 1000, 50, byrow=TRUE) * 0.0001 + sb.append(" = matrix(seq(1, "); + sb.append(dims[0]*dims[1]); + sb.append("), "); + sb.append(dims[0]); + sb.append(", "); + sb.append(dims[1]); + + sb.append(", byrow=TRUE) * "); + sb.append(factor); + } + sb.append("\n"); + } + sb.append("\n"); + + sb.append("R = einsum(\""); + sb.append(config.einsumStr); + sb.append("\", "); + + for (int i = 0; i < config.shapes.size()-1; i++) { + sb.append("A"); + sb.append(i); + sb.append(", "); + } + sb.append("A"); + sb.append(config.shapes.size()-1); + sb.append(")"); + + sb.append("\n\n"); + if(outputScalar){ + sb.append("write(R, paste(args[2], \"S\", sep=\"\"))\n"); + }else{ + sb.append("writeMM(as(R, \"CsparseMatrix\"), paste(args[2], \"S\", sep=\"\"))\n"); + } + return sb; + } + + @Test + public void testEinsumWithFiles() { + System.out.println("Testing einsum: " + this.einsumStr); + testCodegenIntegration(TEST_NAME_EINSUM+this.id); + } + @After + public void cleanUp() { + if (this.dmlFile.exists()) { + boolean deleted = this.dmlFile.delete(); + if (!deleted) { + System.err.println("Failed to delete temp file: " + this.dmlFile.getAbsolutePath()); + } + } + if (this.rFile.exists()) { + boolean deleted = this.rFile.delete(); + if (!deleted) { + System.err.println("Failed to delete temp file: " + this.rFile.getAbsolutePath()); + } + } + } + + private static class Config { + String einsumStr; + List shapes; + + Config(String einsum, List shapes) { + this.einsumStr = einsum; + this.shapes = shapes; + } + } + + private static int[] shape(int... dims) { + return dims; + } private static final Log LOG = LogFactory.getLog(EinsumTest.class.getName()); private static final String TEST_NAME_EINSUM = "einsum"; - private static final String TEST_EINSUM1 = TEST_NAME_EINSUM+"1"; - private static final String TEST_EINSUM2 = TEST_NAME_EINSUM+"2"; - private static final String TEST_EINSUM3 = TEST_NAME_EINSUM+"3"; - private static final String TEST_EINSUM4 = TEST_NAME_EINSUM+"4"; - private static final String TEST_EINSUM5 = TEST_NAME_EINSUM+"5"; - private static final String TEST_EINSUM6 = TEST_NAME_EINSUM+"6"; - private static final String TEST_EINSUM7 = TEST_NAME_EINSUM+"7"; - private static final String TEST_EINSUM8 = TEST_NAME_EINSUM+"8"; - private static final String TEST_EINSUM9 = TEST_NAME_EINSUM+"9"; - private static final String TEST_EINSUM10 = TEST_NAME_EINSUM+"10"; - private static final String TEST_EINSUM11 = TEST_NAME_EINSUM+"11"; - private static final String TEST_EINSUM12 = TEST_NAME_EINSUM+"12"; - private static final String TEST_EINSUM13 = TEST_NAME_EINSUM+"13"; - private static final String TEST_EINSUM14 = TEST_NAME_EINSUM+"14"; - private static final String TEST_EINSUM15 = TEST_NAME_EINSUM+"15"; - private static final String TEST_EINSUM16 = TEST_NAME_EINSUM+"16"; - private static final String TEST_EINSUM17 = TEST_NAME_EINSUM+"17"; - private static final String TEST_EINSUM18 = TEST_NAME_EINSUM+"18"; - private static final String TEST_EINSUM19 = TEST_NAME_EINSUM+"19"; - private static final String TEST_EINSUM20 = TEST_NAME_EINSUM+"20"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; private final static String TEST_CONF = "SystemDS-config-codegen.xml"; @@ -68,89 +293,27 @@ public class EinsumTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=20; i++) + for(int i = 1; i<= TEST_CONFIGS.size(); i++) addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } - @Test - public void testCodegenEinsum1CP() { - testCodegenIntegration( TEST_EINSUM1, ExecType.CP ); - } - @Test - public void testCodegenEinsum2CP() { - testCodegenIntegration( TEST_EINSUM2, ExecType.CP ); - } - @Test - public void testCodegenEinsum3CP() { - testCodegenIntegration( TEST_EINSUM3, ExecType.CP ); - } - @Test - public void testCodegenEinsum4CP() { - testCodegenIntegration( TEST_EINSUM4, ExecType.CP ); - } - @Test - public void testCodegenEinsum5CP() { - testCodegenIntegration( TEST_EINSUM5, ExecType.CP ); - } - @Test - public void testCodegenEinsum6CP() { - testCodegenIntegration( TEST_EINSUM6, ExecType.CP ); - } - @Test - public void testCodegenEinsum7CP() { - testCodegenIntegration( TEST_EINSUM7, ExecType.CP ); - } - @Test - public void testCodegenEinsum8CP() { testCodegenIntegration( TEST_EINSUM8, ExecType.CP ); } - @Test - public void testCodegenEinsum9CP() { - testCodegenIntegration( TEST_EINSUM9, ExecType.CP ); - } - @Test - public void testCodegenEinsum10CP() { - testCodegenIntegration( TEST_EINSUM10, ExecType.CP ); - } - @Test - public void testCodegenEinsum11CP() { - testCodegenIntegration( TEST_EINSUM11, ExecType.CP ); - } - @Test - public void testCodegenEinsum12CP() { - testCodegenIntegration( TEST_EINSUM12, ExecType.CP ); - } - @Test - public void testCodegenEinsum13CP() { - testCodegenIntegration( TEST_EINSUM13, ExecType.CP, true ); - } - @Test - public void testCodegenEinsum14CP() { testCodegenIntegration( TEST_EINSUM14, ExecType.CP); } - @Test - public void testCodegenEinsum15CP() { testCodegenIntegration( TEST_EINSUM15, ExecType.CP); } - @Test - public void testCodegenEinsum16CP() { testCodegenIntegration( TEST_EINSUM16, ExecType.CP); } - @Test - public void testCodegenEinsum17CP() { testCodegenIntegration( TEST_EINSUM17, ExecType.CP); } - @Test - public void testCodegenEinsum18CP() { testCodegenIntegration( TEST_EINSUM18, ExecType.CP); } - @Test - public void testCodegenEinsum19CP() { testCodegenIntegration( TEST_EINSUM19, ExecType.CP); } - @Test - public void testCodegenEinsum20CP() { testCodegenIntegration( TEST_EINSUM20, ExecType.CP); } - private void testCodegenIntegration( String testname, ExecType instType) { testCodegenIntegration(testname, instType, false); } - private void testCodegenIntegration( String testname, ExecType instType, boolean outputScalar ) + + private void testCodegenIntegration( String testname) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; - ExecMode platformOld = setExecMode(instType); + ExecMode platformOld = setExecMode(ExecType.CP); + String testnameDml = this.dmlFile.getAbsolutePath(); + String testnameR = this.rFile.getAbsolutePath(); try { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + testname + ".dml"; + fullDMLScriptName = testnameDml; programArgs = new String[]{"-stats", "-explain", "-args", output("S") }; - fullRScriptName = HOME + testname + ".R"; + fullRScriptName = testnameR; rCmd = getRCmd(inputDir(), expectedDir()); OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false; @@ -177,6 +340,7 @@ private void testCodegenIntegration( String testname, ExecType instType, boolean } } + /** * Override default configuration with custom test configuration to ensure * scratch space and local temporary directory locations are also updated. diff --git a/src/test/scripts/functions/einsum/einsum1.R b/src/test/scripts/functions/einsum/einsum1.R deleted file mode 100644 index 98e52486f95..00000000000 --- a/src/test/scripts/functions/einsum/einsum1.R +++ /dev/null @@ -1,34 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -P = matrix(seq(1,3000), 600, 5, byrow=TRUE); -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); - -# R = t(P) %*% X; -R = einsum("ji,jz->iz",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum1.dml b/src/test/scripts/functions/einsum/einsum1.dml deleted file mode 100644 index 1523339e21c..00000000000 --- a/src/test/scripts/functions/einsum/einsum1.dml +++ /dev/null @@ -1,30 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -P = matrix(seq(1,3000), 600,5); -X = matrix(seq(1,6000), 600, 10) - -while(FALSE){} - -#R = t(P) %*% X ; - -R = einsum("ji,jz->iz",P,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum10.R b/src/test/scripts/functions/einsum/einsum10.R deleted file mode 100644 index 9fe7fa58b33..00000000000 --- a/src/test/scripts/functions/einsum/einsum10.R +++ /dev/null @@ -1,35 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -A = matrix(seq(1,3000), 5, 600, byrow=TRUE); -B = matrix(seq(1,6000), 600, 10, byrow=TRUE); -C = matrix(seq(1,50), 10, 5, byrow=TRUE); -D = matrix(seq(1,20), 5, 4, byrow=TRUE); - -R = einsum("ab,bc,cd,de->ae",A,B,C,D) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum10.dml b/src/test/scripts/functions/einsum/einsum10.dml deleted file mode 100644 index b7d2e931da5..00000000000 --- a/src/test/scripts/functions/einsum/einsum10.dml +++ /dev/null @@ -1,31 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -A = matrix(seq(1,3000), 5, 600) -B = matrix(seq(1,6000), 600, 10); -C = matrix(seq(1,50), 10, 5); -D = matrix(seq(1,20), 5, 4); - -while(FALSE){} - -R = einsum("ab,bc,cd,de->ae",A,B,C,D) - -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum11.R b/src/test/scripts/functions/einsum/einsum11.R deleted file mode 100644 index 45c2a647799..00000000000 --- a/src/test/scripts/functions/einsum/einsum11.R +++ /dev/null @@ -1,38 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - - -A = matrix(seq(1,300), 5, 60, byrow=TRUE)/1000; -B = matrix(seq(1,150), 5, 30, byrow=TRUE)/1000; -C = matrix(seq(1,500), 5, 100, byrow=TRUE)/1000; -D = matrix(seq(1,1800), 60, 30, byrow=TRUE)/1000; -E = matrix(seq(1,6000), 100, 60, byrow=TRUE)/1000; -F = matrix(seq(1,3000), 100, 30, byrow=TRUE)/1000; - -R = einsum("fx,fg,fz,xg,zx,zg->g",A,B,C,D,E,F) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum11.dml b/src/test/scripts/functions/einsum/einsum11.dml deleted file mode 100644 index 7c2b3e0e9f2..00000000000 --- a/src/test/scripts/functions/einsum/einsum11.dml +++ /dev/null @@ -1,33 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -A = matrix(seq(1,300), 5, 60)/1000; -B = matrix(seq(1,150), 5, 30)/1000; -C = matrix(seq(1,500), 5, 100)/1000; -D = matrix(seq(1,1800), 60, 30)/1000; -E = matrix(seq(1,6000), 100, 60)/1000; -F = matrix(seq(1,3000), 100, 30)/1000; - -while(FALSE){} - -R = einsum("fx,fg,fz,xg,zx,zg->g",A,B,C,D,E,F) - -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum12.R b/src/test/scripts/functions/einsum/einsum12.R deleted file mode 100644 index b2abb5bf642..00000000000 --- a/src/test/scripts/functions/einsum/einsum12.R +++ /dev/null @@ -1,35 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -A = matrix(seq(1,3000), 600, 5, byrow=TRUE)/1000; -B = matrix(seq(1,6000), 600, 10, byrow=TRUE)/1000; -C = matrix(seq(1,3600), 600, 6, byrow=TRUE)/1000; -D = matrix(seq(1,50), 5, 10, byrow=TRUE)/1000; - -R = einsum("fx,fg,fz,xg->z",A,B,C,D) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum12.dml b/src/test/scripts/functions/einsum/einsum12.dml deleted file mode 100644 index f8bbd275100..00000000000 --- a/src/test/scripts/functions/einsum/einsum12.dml +++ /dev/null @@ -1,31 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -A = matrix(seq(1,3000), 600, 5)/1000; -B = matrix(seq(1,6000), 600, 10)/1000; -C = matrix(seq(1,3600), 600, 6)/1000; -D = matrix(seq(1,50), 5, 10)/1000; - -while(FALSE){} - -R = einsum("fx,fg,fz,xg->z",A,B,C,D) - -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum13.R b/src/test/scripts/functions/einsum/einsum13.R deleted file mode 100644 index e71b18a5bc9..00000000000 --- a/src/test/scripts/functions/einsum/einsum13.R +++ /dev/null @@ -1,33 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); -P = matrix(seq(1,30), 6, 5, byrow=TRUE); - -R = einsum("ab,cd->",P,X) - -write(R, paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum13.dml b/src/test/scripts/functions/einsum/einsum13.dml deleted file mode 100644 index 337a5cc5573..00000000000 --- a/src/test/scripts/functions/einsum/einsum13.dml +++ /dev/null @@ -1,28 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -X = matrix(seq(1,6000), 600, 10); -P = matrix(seq(1,30), 6, 5) - -while(FALSE){} - -R = einsum("ab,cd->",P,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum14.R b/src/test/scripts/functions/einsum/einsum14.R deleted file mode 100644 index 7d0b7eef07e..00000000000 --- a/src/test/scripts/functions/einsum/einsum14.R +++ /dev/null @@ -1,34 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); -P = matrix(seq(1,30), 6, 5, byrow=TRUE); - -# R = P * sum(X) -R = einsum("ab,cd->ba",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum14.dml b/src/test/scripts/functions/einsum/einsum14.dml deleted file mode 100644 index fae1e38b5e4..00000000000 --- a/src/test/scripts/functions/einsum/einsum14.dml +++ /dev/null @@ -1,28 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -X = matrix(seq(1,6000), 600, 10); -P = matrix(seq(1,30), 6, 5) - -while(FALSE){} - -R = einsum("ab,cd->ba",P,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum15.R b/src/test/scripts/functions/einsum/einsum15.R deleted file mode 100644 index da967ef3f1f..00000000000 --- a/src/test/scripts/functions/einsum/einsum15.R +++ /dev/null @@ -1,33 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -P = as.vector(seq(1,30)); -X = as.vector(seq(1,600)); - -R = einsum("a,c->ac",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum15.dml b/src/test/scripts/functions/einsum/einsum15.dml deleted file mode 100644 index 128730eef34..00000000000 --- a/src/test/scripts/functions/einsum/einsum15.dml +++ /dev/null @@ -1,28 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -P = matrix(seq(1,30), 30, 1) -X = matrix(seq(1,600), 600, 1); - -while(FALSE){} - -R = einsum("a,c->ac",P,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum16.R b/src/test/scripts/functions/einsum/einsum16.R deleted file mode 100644 index 683cd58a7bf..00000000000 --- a/src/test/scripts/functions/einsum/einsum16.R +++ /dev/null @@ -1,32 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -X = as.vector(seq(1,600)); - -R = einsum("a,a->a",X,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum16.dml b/src/test/scripts/functions/einsum/einsum16.dml deleted file mode 100644 index 837db17a018..00000000000 --- a/src/test/scripts/functions/einsum/einsum16.dml +++ /dev/null @@ -1,27 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -X = matrix(seq(1,600), 600, 1); - -while(FALSE){} - -R = einsum("a,a->a",X,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum17.R b/src/test/scripts/functions/einsum/einsum17.R deleted file mode 100644 index 48d50e0acf4..00000000000 --- a/src/test/scripts/functions/einsum/einsum17.R +++ /dev/null @@ -1,33 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -X = as.vector(seq(1,600)); - -# R = P * sum(X) -R = einsum("a,a->a",X,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum17.dml b/src/test/scripts/functions/einsum/einsum17.dml deleted file mode 100644 index 1053df0539f..00000000000 --- a/src/test/scripts/functions/einsum/einsum17.dml +++ /dev/null @@ -1,27 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -X = matrix(seq(1,600), 1, 600); - -while(FALSE){} - -R = einsum("a,a->a",X,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum18.R b/src/test/scripts/functions/einsum/einsum18.R deleted file mode 100644 index fa25980405b..00000000000 --- a/src/test/scripts/functions/einsum/einsum18.R +++ /dev/null @@ -1,38 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -A0 = matrix(seq(1,250), 10, 25, byrow=TRUE) * 0.0001 -A1 = matrix(seq(1,200), 10, 20, byrow=TRUE) * 0.0001 -A2 = matrix(seq(1,30), 10, 3, byrow=TRUE) * 0.0001 -A3 = matrix(seq(1,500), 25, 20, byrow=TRUE) * 0.0001 -A4 = matrix(seq(1,33), 3, 11, byrow=TRUE) * 0.0001 -A5 = seq(1,11) * 0.0001 -A6 = seq(1,3) * 0.0001 - -R = einsum("fx,fg,fz,xg,pq,q,p->zp", A0, A1, A2, A3, A4, A5, A6) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum18.dml b/src/test/scripts/functions/einsum/einsum18.dml deleted file mode 100644 index c71654bf01d..00000000000 --- a/src/test/scripts/functions/einsum/einsum18.dml +++ /dev/null @@ -1,33 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -A0 = matrix(seq(1,250), 10, 25) * 0.0001 -A1 = matrix(seq(1,200), 10, 20) * 0.0001 -A2 = matrix(seq(1,30), 10, 3) * 0.0001 -A3 = matrix(seq(1,500), 25, 20) * 0.0001 -A4 = matrix(seq(1,33), 3, 11) * 0.0001 -A5 = seq(1,11) * 0.0001 -A6 = seq(1,3) * 0.0001 - -while(FALSE){} - -R = einsum("fx,fg,fz,xg,pq,q,p->zp", A0, A1, A2, A3, A4, A5, A6) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum19.R b/src/test/scripts/functions/einsum/einsum19.R deleted file mode 100644 index fced04caa92..00000000000 --- a/src/test/scripts/functions/einsum/einsum19.R +++ /dev/null @@ -1,33 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -A0 = matrix(seq(1,50000), 1000, 50, byrow=TRUE) * 0.0001 -A1 = seq(1,50) * 0.0001 - -R = einsum("ij,j->ij", A0, A1) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum19.dml b/src/test/scripts/functions/einsum/einsum19.dml deleted file mode 100644 index c5743714a26..00000000000 --- a/src/test/scripts/functions/einsum/einsum19.dml +++ /dev/null @@ -1,29 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - - -A0 = matrix(seq(1,50000), 1000, 50) * 0.0001 -A1 = seq(1,50) * 0.0001 - -while(FALSE){} - -R = einsum("ij,j->ij", A0, A1) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum2.R b/src/test/scripts/functions/einsum/einsum2.R deleted file mode 100644 index 15be5c772c2..00000000000 --- a/src/test/scripts/functions/einsum/einsum2.R +++ /dev/null @@ -1,34 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -P = matrix(seq(1,3000), 5, 600, byrow=TRUE); -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); - -# R = P %*% X; -R = einsum("ij,jz->iz",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum2.dml b/src/test/scripts/functions/einsum/einsum2.dml deleted file mode 100644 index eb47e19c807..00000000000 --- a/src/test/scripts/functions/einsum/einsum2.dml +++ /dev/null @@ -1,30 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -P = matrix(seq(1,3000), 5, 600) -X = matrix(seq(1,6000), 600, 10); - -while(FALSE){} - -#R = P %*% X ; -R = einsum("ij,jz->iz",P,X) - -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum20.R b/src/test/scripts/functions/einsum/einsum20.R deleted file mode 100644 index 2b3c556bda3..00000000000 --- a/src/test/scripts/functions/einsum/einsum20.R +++ /dev/null @@ -1,33 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -A0 = matrix(seq(1,50000), 1000, 50, byrow=TRUE) * 0.0001 -A1 = seq(1,1000) * 0.0001 - -R = einsum("ij,i->ij", A0, A1) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum20.dml b/src/test/scripts/functions/einsum/einsum20.dml deleted file mode 100644 index aeb591beff9..00000000000 --- a/src/test/scripts/functions/einsum/einsum20.dml +++ /dev/null @@ -1,29 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - - -A0 = matrix(seq(1,50000), 1000, 50) * 0.0001 -A1 = seq(1,1000) * 0.0001 - -while(FALSE){} - -R = einsum("ij,i->ij", A0, A1) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum3.R b/src/test/scripts/functions/einsum/einsum3.R deleted file mode 100644 index ce8a34f51ac..00000000000 --- a/src/test/scripts/functions/einsum/einsum3.R +++ /dev/null @@ -1,34 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -P = matrix(seq(1,3000), 600, 5, byrow=TRUE); -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); - -# R = sum(t(P) %*% X); -R = einsum("ji,jz->i",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum3.dml b/src/test/scripts/functions/einsum/einsum3.dml deleted file mode 100644 index 9e8c96939f2..00000000000 --- a/src/test/scripts/functions/einsum/einsum3.dml +++ /dev/null @@ -1,30 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- -P = matrix(seq(1,3000), 600, 5) -X = matrix(seq(1,6000), 600, 10); - -while(FALSE){} - - -#R = sum(t(P) %*% X) ; - -R = einsum("ji,jz->i",P,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum4.R b/src/test/scripts/functions/einsum/einsum4.R deleted file mode 100644 index 74f5560464d..00000000000 --- a/src/test/scripts/functions/einsum/einsum4.R +++ /dev/null @@ -1,33 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -P = matrix(seq(1,3000), 600, 5, byrow=TRUE); -X = matrix(seq(1,6000), 10, 600, byrow=TRUE); - -R = einsum("ji,zj->i",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum4.dml b/src/test/scripts/functions/einsum/einsum4.dml deleted file mode 100644 index 3806db22ffe..00000000000 --- a/src/test/scripts/functions/einsum/einsum4.dml +++ /dev/null @@ -1,28 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- -P = matrix(seq(1,3000), 600, 5) -X = matrix(seq(1,6000), 10, 600); - -while(FALSE){} - - -R = einsum("ji,zj->i",P,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum5.R b/src/test/scripts/functions/einsum/einsum5.R deleted file mode 100644 index bb9caf31e57..00000000000 --- a/src/test/scripts/functions/einsum/einsum5.R +++ /dev/null @@ -1,34 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); -P = matrix(seq(1,3000), 600, 5, byrow=TRUE); - -# R = colSums(t(P) %*% X); -R = einsum("ji,jz->z",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum5.dml b/src/test/scripts/functions/einsum/einsum5.dml deleted file mode 100644 index 84c3efbdfae..00000000000 --- a/src/test/scripts/functions/einsum/einsum5.dml +++ /dev/null @@ -1,29 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- -P = matrix(seq(1,3000), 600, 5) -X = matrix(seq(1,6000), 600, 10); - -while(FALSE){} - -#R = colSums(t(P) %*% X) ; - -R = einsum("ji,jz->z",P,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum6.R b/src/test/scripts/functions/einsum/einsum6.R deleted file mode 100644 index 902941b9222..00000000000 --- a/src/test/scripts/functions/einsum/einsum6.R +++ /dev/null @@ -1,34 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); -P = matrix(seq(1,3000), 600, 5, byrow=TRUE); - -# R = rowSums(P) * rowSums(X) -R = einsum("ji,jz->j",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum6.dml b/src/test/scripts/functions/einsum/einsum6.dml deleted file mode 100644 index de5fb654e20..00000000000 --- a/src/test/scripts/functions/einsum/einsum6.dml +++ /dev/null @@ -1,30 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -X = matrix(seq(1,6000), 600, 10); -P = matrix(seq(1,3000), 600, 5) - -while(FALSE){} - -#R = colSums(t(P) %*% X) ; - -R = einsum("ji,jz->j",P,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum7.R b/src/test/scripts/functions/einsum/einsum7.R deleted file mode 100644 index f0075f86dc9..00000000000 --- a/src/test/scripts/functions/einsum/einsum7.R +++ /dev/null @@ -1,35 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -P = matrix(seq(1,3000), 600, 5, byrow=TRUE); -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); -Z = matrix(seq(1,20), 10, 2, byrow=TRUE); - -# R = t(P) %*% X; -# RR= R %*% Z -RR = einsum("ji,jz,zx->ix",P,X,Z) -writeMM(as(RR, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum7.dml b/src/test/scripts/functions/einsum/einsum7.dml deleted file mode 100644 index 5a1b889ce29..00000000000 --- a/src/test/scripts/functions/einsum/einsum7.dml +++ /dev/null @@ -1,28 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- -P = matrix(seq(1,3000), 600, 5) -X = matrix(seq(1,6000), 600, 10); -Z = matrix(seq(1,20), 10, 2) - -while(FALSE){} - -R = einsum("ji,jz,zx->ix",P,X,Z) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum8.R b/src/test/scripts/functions/einsum/einsum8.R deleted file mode 100644 index 5587c9d878f..00000000000 --- a/src/test/scripts/functions/einsum/einsum8.R +++ /dev/null @@ -1,34 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -P = matrix(seq(1,3000), 600, 5, byrow=TRUE); -X = matrix(seq(1,6000), 600, 10, byrow=TRUE); - -# R = t(P) %*% X; -R = einsum("ji,jz->zi",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum8.dml b/src/test/scripts/functions/einsum/einsum8.dml deleted file mode 100644 index 2e47614800b..00000000000 --- a/src/test/scripts/functions/einsum/einsum8.dml +++ /dev/null @@ -1,30 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -P = matrix(seq(1,3000), 600,5); -X = matrix(seq(1,6000), 600, 10) - -while(FALSE){} - -#R = t(P) %*% X ; - -R = einsum("ji,jz->zi",P,X) -write(R, $1) diff --git a/src/test/scripts/functions/einsum/einsum9.R b/src/test/scripts/functions/einsum/einsum9.R deleted file mode 100644 index 61cb86a96a7..00000000000 --- a/src/test/scripts/functions/einsum/einsum9.R +++ /dev/null @@ -1,34 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -args<-commandArgs(TRUE) -options(digits=22) -library("Matrix") -library("matrixStats") -library("einsum") - -P = matrix(seq(1,3000), 5, 600, byrow=TRUE); -X = matrix(seq(1,6000), 10, 600, byrow=TRUE); - -# R = t(P) %*% X; -R = einsum("ij,zj->iz",P,X) - -writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); diff --git a/src/test/scripts/functions/einsum/einsum9.dml b/src/test/scripts/functions/einsum/einsum9.dml deleted file mode 100644 index a22c03fdc9b..00000000000 --- a/src/test/scripts/functions/einsum/einsum9.dml +++ /dev/null @@ -1,30 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -P = matrix(seq(1,3000), 5,600); -X = matrix(seq(1,6000), 10, 600) - -while(FALSE){} - -#R = t(P) %*% X ; - -R = einsum("ij,zj->iz",P,X) -write(R, $1) From ba21c30ba85135db17928f1d6432da23967b548f Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sat, 9 Aug 2025 23:12:29 +0200 Subject: [PATCH 26/28] refactored plan generation and execution into separate steps; plan is optimal; by default no codegen is used --- .../instructions/cp/EinsumCPInstruction.java | 967 ++++++++---------- .../test/functions/einsum/EinsumTest.java | 52 +- 2 files changed, 483 insertions(+), 536 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 80002084c63..007c89f7821 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -19,7 +19,9 @@ package org.apache.sysds.runtime.instructions.cp; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; @@ -32,6 +34,7 @@ import org.apache.sysds.hops.codegen.cplan.CNodeBinary; import org.apache.sysds.hops.codegen.cplan.CNodeCell; import org.apache.sysds.hops.codegen.cplan.CNodeData; +import org.apache.sysds.hops.codegen.cplan.CNodeNary; import org.apache.sysds.hops.codegen.cplan.CNodeRow; import org.apache.sysds.runtime.codegen.*; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; @@ -42,10 +45,13 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.runtime.matrix.operators.SimpleOperator; import java.util.*; +import java.util.function.Predicate; public class EinsumCPInstruction extends BuiltinNaryCPInstruction { public static boolean FORCE_CELL_TPL = false; @@ -116,89 +122,70 @@ public void processInstruction(ExecutionContext ec) { } if (scalar != null) { - boolean appliedToSomeMatrix = false; - for(int i = 0; i < inputs.size(); i++){ - if(inputs.get(i) != null){ - inputs.set(i, getScalarMultiplyMatrixBlock(inputs.get(i), scalar)); - appliedToSomeMatrix = true; break; + inputsChars.add(""); + inputs.add(new MatrixBlock(scalar)); + } + + HashMap characterToOccurences = new HashMap<>(); + for (Character key :einc.characterAppearanceIndexes.keySet()) { + characterToOccurences.put(key, einc.characterAppearanceIndexes.get(key).size()); + } + for (Character key :einc.charToDimensionSize.keySet()) { + if(!characterToOccurences.containsKey(key)) + characterToOccurences.put(key, 1); + } + + ArrayList eOpNodes = new ArrayList<>(inputsChars.size()); + for (int i = 0; i < inputsChars.size(); i++) { + if (inputsChars.get(i) == null) continue; + EOpNodeData n = new EOpNodeData(inputsChars.get(i).length() > 0 ? inputsChars.get(i).charAt(0) : null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i); + eOpNodes.add(n); + } + Pair > plan = FORCE_CELL_TPL ? null : generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); + + + ArrayList resMatrices = FORCE_CELL_TPL ? null : executePlan(plan.getRight(), inputs); +// ArrayList resMatrices = executePlan(plan.getRight(), inputs, true); + + if(!FORCE_CELL_TPL && resMatrices.size() == 1){ + EOpNode resNode = plan.getRight().get(0); + if (einc.outChar1 != null && einc.outChar2 != null){ + if(resNode.c1 == einc.outChar1 && resNode.c2 == einc.outChar2){ + ec.setMatrixOutput(output.getName(), resMatrices.get(0)); } - } - if(!appliedToSomeMatrix){ - ec.setScalarOutput(output.getName(), new DoubleObject(scalar)); - releaseMatrixInputs(ec); - return; - } - } - - boolean needToDoCellTemplate = FORCE_CELL_TPL ? true : generatePlanAndExecute(inputs, einc); - - if (!needToDoCellTemplate){ - //check if any operations to do that were not-output dimension summations: - List remStrings = inputsChars.stream() - .filter(Objects::nonNull).toList(); - List remMbs = inputs.stream() - .filter(Objects::nonNull).toList(); - MatrixBlock res; - if(remStrings.size() == 1) { - String s = remStrings.get(0); - if(s.equals(resultString)){ - res=remMbs.get(0); - }else if(s.charAt(0) == s.charAt(1)) { - // diagonal needed - ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); - res= remMbs.get(0).reorgOperations(op, new MatrixBlock(),0,0,0); - }else{ - //it has to be transpose: ab->ba + else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - res = remMbs.get(0).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + MatrixBlock resM = resMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0); + ec.setMatrixOutput(output.getName(), resM); + }else{ + if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2); + throw new RuntimeException("Einsum plan produced different result"); } - } else{ - // maybe the leftovers are i,j and result should be ij or ji -> outer multp. - if(remStrings.size() == 2 && remStrings.get(0).length()==1 && remStrings.get(1).length()==1){ - MatrixBlock first; - MatrixBlock second; - - if(remStrings.get(0).charAt(0) == einc.outChar1 && remStrings.get(1).charAt(0) == einc.outChar2){ - first = remMbs.get(0); - second = remMbs.get(1); - }else if(remStrings.get(0).charAt(0) == einc.outChar2 && remStrings.get(1).charAt(0) == einc.outChar1){ - first = remMbs.get(1); - second = remMbs.get(0); - }else{ - throw new RuntimeException("Einsum runtime error: left with 2 vectors that cannot produce final result "+remStrings.get(0)+" , "+remStrings.get(1)); // should not happen - } - if(first.getNumColumns() > 1){ - int r = first.getNumColumns(); - first.setNumRows(r); - first.setNumColumns(1); - first.getDenseBlock().resetNoFill(r,1); - } - if(second.getNumRows() > 1){ - int c = second.getNumRows(); - second.setNumRows(1); - second.setNumColumns(c); - second.getDenseBlock().resetNoFill(1,c); - } - res = LibMatrixMult.matrixMult(first,second, _numThreads); - }else { - throw new RuntimeException("Einsum runtime error, reductions and multiplications finished but the did not produce one result"); // should not happen + }else if (einc.outChar1 != null){ + if(resNode.c1 == einc.outChar1 && resNode.c2 == null){ + ec.setMatrixOutput(output.getName(), resMatrices.get(0)); + }else{ + if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2); + throw new RuntimeException("Einsum plan produced different result"); + } + }else{ + if(resNode.c1 == null && resNode.c2 == null){ + ec.setScalarOutput(output.getName(), new DoubleObject(resMatrices.get(0).get(0, 0)));; } } - if (einc.outRows == 1 && einc.outCols == 1) - ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); - else ec.setMatrixOutput(output.getName(), res); - } - - else { // if (needToDoCellTemplate) - ArrayList mbs = new ArrayList<>(); + }else{ + // use cell template with loops for remaining + ArrayList mbs = resMatrices; ArrayList chars = new ArrayList<>(); - for (int i = 0; i < inputs.size(); i++) { - MatrixBlock mb = inputs.get(i); - if (mb != null) { - mbs.add(mb); - chars.add(inputsChars.get(i)); - } + + for (int i = 0; i < plan.getRight().size(); i++) { + String s; + if(plan.getRight().get(i).c1 == null) s = ""; + else if(plan.getRight().get(i).c2 == null) s = plan.getRight().get(i).c1.toString(); + else s = plan.getRight().get(i).c1.toString() + plan.getRight().get(i).c2; + chars.add(s); } + ArrayList summingChars = new ArrayList(); for (Character c : einc.characterAppearanceIndexes.keySet()) { if (c != einc.outChar1 && c != einc.outChar2) summingChars.add(c); @@ -211,10 +198,10 @@ public void processInstruction(ExecutionContext ec) { ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); else ec.setMatrixOutput(output.getName(), res); } - if(LOG.isTraceEnabled()) LOG.trace("EinsumCPInstruction Finished"); releaseMatrixInputs(ec); + } private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList inputs) { @@ -255,522 +242,472 @@ private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList } } - private void releaseMatrixInputs(ExecutionContext ec){ - for (CPOperand input : _in) - if(input.getDataType()==DataType.MATRIX) - ec.releaseMatrixInput(input.getName()); + private enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed + ////// summations: ////// + aB_a,// -> B + Ba_a, // -> B + Ba_aC, // mmult -> BC + aB_Ca, + Ba_Ca, // -> BC + aB_aC, // outer mult, possibly with transposing first -> BC + a_a,// dot -> + + ////// elementwisemult and sums, something like ij,ij->i ////// + aB_aB,// elemwise and colsum -> B + Ba_Ba, // elemwise and rowsum ->B + Ba_aB, // elemwise, either colsum or rowsum -> B +// aB_Ba, + + ////// elementwise, no summations: ////// + A_A,// v-elemwise -> A + AB_AB,// M-M elemwise -> AB + AB_BA, // M-M.T elemwise -> AB + AB_A, // M-v colwise -> BA!? + BA_A, // M-v rowwise -> BA + ab_ab,//M-M sum all + ab_ba, //M-M.T sum all + ////// other ////// + A_B, // outer mult -> AB + A_scalar, // v-scalar + AB_scalar, // m-scalar + scalar_scalar } - - // returns true if there are elements that appear more than 2 times and cannot be summed - private boolean generatePlanAndExecute(ArrayList inputs, EinsumContext einc) { - boolean anyCouldNotDo; - boolean didAnything = false; // maybe multiplication will make it summable - do { - anyCouldNotDo = sumCharactersWherePossible(einc.characterAppearanceIndexes, inputs, einc.newEquationStringInputsSplit, einc.outChar1, einc.outChar2); - didAnything = false; - if(einc.newEquationStringInputsSplit.stream().filter(Objects::nonNull).count() > 1) - didAnything = multiplyTerms(einc.characterAppearanceIndexes, inputs, einc.newEquationStringInputsSplit, einc.outChar1, einc.outChar2); + private abstract class EOpNode { + public Character c1; + public Character c2; // nullable + public EOpNode(Character c1, Character c2){ + this.c1 = c1; + this.c2 = c2; } - while(didAnything); - - if(LOG.isTraceEnabled()) LOG.trace("generatePlanAndExecute() finished"); - - return anyCouldNotDo; } - - private static void EnsureMatrixBlockColumnVector(MatrixBlock mb){ - if(mb.getNumColumns() > 1){ - mb.setNumRows(mb.getNumColumns()); - mb.setNumColumns(1); - mb.getDenseBlock().resetNoFill(mb.getNumRows(),1); + private class EOpNodeBinary extends EOpNode { + public EOpNode left; + public EOpNode right; + public EBinaryOperand operand; + public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ + super(c1,c2); + this.left = left; + this.right = right; + this.operand = operand; } } - private static void EnsureMatrixBlockRowVector(MatrixBlock mb){ - if(mb.getNumRows() > 1){ - mb.setNumColumns(mb.getNumRows()); - mb.setNumRows(1); - mb.getDenseBlock().resetNoFill(1,mb.getNumColumns()); + private class EOpNodeData extends EOpNode { + public int matrixIdx; + public EOpNodeData(Character c1, Character c2, int matrixIdx){ + super(c1,c2); + this.matrixIdx = matrixIdx; } } - /* handle situation: ji,ji or ij,ji, j,j */ - private boolean multiplyTerms(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2 ) { - HashMap> matrixStringToIndex = new HashMap<>(); - HashSet matrixStringToIndexSkip = new HashSet<>(); - HashMap> vectorStringToIndex = new HashMap<>(); - - for(int i = 0; i < inputsChars.size(); i++){ - String s = inputsChars.get(i); - if(s==null) continue; - - if(s.length() == 1){ - if (vectorStringToIndex.containsKey(s)) vectorStringToIndex.get(s).add(i); - else { LinkedList list = new LinkedList<>(); list.add(i); vectorStringToIndex.put(s, list); } - }else{ - if (matrixStringToIndex.containsKey(s)) matrixStringToIndex.get(s).add(i); - else { LinkedList list = new LinkedList<>(); list.add(i); matrixStringToIndex.put(s, list); } - } + private Pair /* ideally with one element */> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { + Integer minCost = cost; + List minNodes = operands; + + if (operands.size() == 2){ + boolean swap = (operands.get(0).c2 == null && operands.get(1).c2 != null) || operands.get(0).c1 == null; + EOpNode n1 = operands.get(!swap ? 0 : 1); + EOpNode n2 = operands.get(!swap ? 1 : 0); + Triple> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); + if (t != null) { + EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); + int thisCost = cost + t.getLeft(); + return Pair.of(thisCost, Arrays.asList(newNode)); + } + return Pair.of(cost, operands); + } + else if (operands.size() == 1){ + // check for transpose + return Pair.of(cost, operands); } - boolean doneAnything = false; - - // first do the vector-vector multipl. - for(var s : vectorStringToIndex.keySet()){ - if(vectorStringToIndex.get(s).size() <= 1) continue; - - doneAnything = true; + for(int i = 0; i < operands.size()-1; i++){ + for (int j = i+1; j < operands.size(); j++){ + boolean swap = (operands.get(i).c2 == null && operands.get(j).c2 != null) || operands.get(i).c1 == null; + EOpNode n1 = operands.get(!swap ? i : j); + EOpNode n2 = operands.get(!swap ? j : i); - Integer vectorIdx0 = vectorStringToIndex.get(s).removeFirst(); - if(LOG.isTraceEnabled()){ - StringBuilder sb = new StringBuilder(); + Triple> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); + if (t != null){ + EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); + int thisCost = cost + t.getLeft(); - LOG.trace("Element wise multiplying: "+s); - } - MatrixBlock mb = inputs.get(vectorIdx0); - EnsureMatrixBlockColumnVector(mb); + if(n1.c1 != null) charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)-1); + if(n1.c2 != null) charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)-1); + if(n2.c1 != null) charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)-1); + if(n2.c2 != null) charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)-1); - do{ // do all vectors with same char - Integer vectorIdx1 = vectorStringToIndex.get(s).removeFirst(); + if(newNode.c1 != null) charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)+1); + if(newNode.c2 != null) charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)+1); - ArrayList mbs = new ArrayList<>(); + ArrayList newOperands = new ArrayList<>(operands.size()-1); + for(int z = 0; z < operands.size(); z++){ + if(z != i && z != j) newOperands.add(operands.get(z)); + } + newOperands.add(newNode); - mbs.add(mb); - mbs.add(inputs.get(vectorIdx1)); - EnsureMatrixBlockColumnVector(mbs.get(1)); + Pair> furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); + if(furtherPlan.getRight().size() < (minNodes.size()) || furtherPlan.getLeft() < minCost){ + minCost = furtherPlan.getLeft(); + minNodes = furtherPlan.getRight(); + } - mb = getCodegenElemwiseMult(mbs); + if(n1.c1 != null) charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)+1); + if(n1.c2 != null) charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)+1); + if(n2.c1 != null) charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)+1); + if(n2.c2 != null) charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)+1); + if(newNode.c1 != null) charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)-1); + if(newNode.c2 != null) charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)-1); + } + } + } - inputs.set(vectorIdx1, null); - inputsChars.set(vectorIdx1, null); - } while (vectorStringToIndex.get(s).size() > 1); + return Pair.of(minCost, minNodes); + } - inputs.set(vectorIdx0, null); - inputsChars.set(vectorIdx0, null); + private static Triple> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2){ + Predicate cannotBeSummed = (c) -> + c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2; - inputs.add(mb); - inputsChars.add(s); - if (partsCharactersToIndices.containsKey(s.charAt(0))) partsCharactersToIndices.get(s.charAt(0)).add(inputs.size() - 1); - vectorStringToIndex.get(s).add(inputs.size() - 1); + if(n1.c1 == null) { + // n2.c1 also has to be null + return Triple.of(1, EBinaryOperand.scalar_scalar, Pair.of(null, null)); } - for(var s : matrixStringToIndex.keySet()){ - if(matrixStringToIndexSkip.contains(s)) continue; - - String sT = s.length() == 2 ? String.valueOf(s.charAt(1)) + s.charAt(0) : null; - LinkedList idxs = matrixStringToIndex.get(s); - LinkedList idxsT = sT != null ? matrixStringToIndex.containsKey(sT) ? matrixStringToIndex.get(sT) : null : null; + if(n2.c1 == null) { + if(n1.c2 == null) + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_scalar, Pair.of(n1.c1, null)); + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2)); + } - Integer vectorIdx0 = null; - Integer vectorIdx1 = null; - String char0 = String.valueOf(s.charAt(0)); - if(vectorStringToIndex.containsKey(char0)){ - // ab,a-> ab - vectorIdx0 = vectorStringToIndex.get(char0).removeFirst(); // only one should left - } + if(n1.c1 == n2.c1){ + if(n1.c2 != null){ + if ( n1.c2 == n2.c2){ + if( cannotBeSummed.test(n1.c1)){ + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null)); + } - String char1 = String.valueOf(s.charAt(1)); - if(vectorStringToIndex.containsKey(char1)){ - // ab,b -> ab - vectorIdx1 = vectorStringToIndex.get(char1).removeFirst(); // only one should left - } + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_aB, Pair.of(n1.c2, null)); + } - if(idxs.size() <= 1 && idxsT == null && vectorIdx0 == null && vectorIdx1 == null) continue; + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ab, Pair.of(null, null)); - doneAnything = true; + } - // do decision if transpose idxs or idxsT: right now just always transpose second - if(LOG.isTraceEnabled()){ - StringBuilder sb = new StringBuilder(); - for(Integer idx : idxs){ - sb.append(inputsChars.get(idx)); - sb.append(","); + else if(n2.c2 == null){ + if(cannotBeSummed.test(n1.c1)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, n1.c2) } - if(idxsT != null) for(Integer idx : idxsT){ - sb.append(inputsChars.get(idx)); - sb.append(","); + else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ + return null;// AB,AC } - if(vectorIdx0 != null) { sb.append(s.charAt(0)).append(","); } - if(vectorIdx1 != null) { sb.append(s.charAt(1)).append(","); } - - LOG.trace("Element wise multiplying: "+sb.toString()); - } - Integer matrixIdx0 = idxs.removeFirst(); - - MatrixBlock mb = inputs.get(matrixIdx0); - - inputs.set(matrixIdx0, null); - inputsChars.set(matrixIdx0, null); - - while (!idxs.isEmpty()){ - Integer matrixIdx1 = idxs.removeFirst(); - - ArrayList mbs = new ArrayList<>(); - mbs.add(mb); - mbs.add(inputs.get(matrixIdx1)); - - mb = getCodegenElemwiseMult(mbs); - - inputs.set(matrixIdx1, null); - inputsChars.set(matrixIdx1, null); - } - - if(idxsT != null) while(!idxsT.isEmpty()) { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - - Integer matrixIdx1 = idxsT.removeFirst(); - - ArrayList mbs = new ArrayList<>(); - mbs.add(mb); - mbs.add(inputs.get(matrixIdx1).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0)); - - mb = getCodegenElemwiseMult(mbs); - - inputs.set(matrixIdx1, null); - inputsChars.set(matrixIdx1, null); + else { + return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2 + } + }else{ // n1.c2 = null -> c2.c2 = null + if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){ + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null)); + } + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.a_a, Pair.of(null, null)); } - if(vectorIdx1 != null){ // ab,b->ab - EnsureMatrixBlockRowVector(inputs.get(vectorIdx1)); - - mb = getRowCodegenMatrixBlock(mb, inputs.get(vectorIdx1), CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG, null); - inputs.set(vectorIdx1, null); - inputsChars.set(vectorIdx1, null); + }else{ // n1.c1 != n2.c1 + if(n1.c2 == null) { + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1)); } - - if(vectorIdx0 != null){ // ab,a->ab - EnsureMatrixBlockRowVector(inputs.get(vectorIdx0)); - -// mb = getRowCodegenMatrixBlock(mb, inputs.get(vectorIdx0), CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG,null); -// mb = getRowCodegenMatrixBlock(mb, inputs.get(vectorIdx0), CNodeBinary.BinType.VECT_MULT_SCALAR, SpoofRowwise.RowType.NO_AGG, Long.valueOf( inputs.get(vectorIdx0).getNumColumns())); - { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - mb = mb.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - mb = getRowCodegenMatrixBlock(mb, inputs.get(vectorIdx0), CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG, null); + else if(n2.c2 == null) { // ab,c + if (n1.c2 == n2.c1) { + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); } - - inputs.set(vectorIdx0, null); - inputsChars.set(vectorIdx0, null); - s = String.valueOf(s.charAt(1))+String.valueOf(s.charAt(0)); - } - - inputs.add(mb); - inputsChars.add(s); - for (int i = 0; i < s.length(); i++) { // for each char in string, add pointer to newly created entry - char c = s.charAt(i); - if (partsCharactersToIndices.containsKey(c)) partsCharactersToIndices.get(c).add(inputs.size() - 1); - } - - if(idxsT != null) matrixStringToIndexSkip.add(sT); - } - - - - return doneAnything; - } - - // returns true if left with summation with more than 2 inputs - private boolean sumCharactersWherePossible(HashMap> partsCharactersToIndices, ArrayList inputs, ArrayList inputsChars, Character outChar1, Character outChar2) { - boolean anyCouldNotDo; - - while (true) { - List toSum = null; - Character sumC = null; - anyCouldNotDo = false; - for (Character c : partsCharactersToIndices.keySet()) { // sum one dim at the time - if (c == outChar1 || c == outChar2) - continue; - toSum = new ArrayList<>(); - for (Integer idx : partsCharactersToIndices.get(c).stream().filter(Objects::nonNull).toList()) { - if (inputs.get(idx) != null) { - toSum.add(idx); + return null; // AB,C + } + else if (n1.c2 == n2.c1) { + if(n1.c1 == n2.c2){ // ab,ba + if(cannotBeSummed.test(n1.c1)){ + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_aB, Pair.of(n1.c1, null)); } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null)); } - if (toSum.size() > 2) { - anyCouldNotDo = true; - continue; + if(cannotBeSummed.test(n1.c2)){ + return null; // AB_B + }else{ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2)); +// if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ +// return null; // AB_B +// } +// return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); } - if (toSum.size() != 2) - continue; - sumC = c; - break; } - - if(sumC == null) break; - - Pair res = computeRowSummation(toSum, inputs, inputsChars, sumC); - String newS = res.getRight(); - - for (Integer idx : toSum) { - inputs.set(idx, null); - inputsChars.set(idx, null); + if(n1.c1 == n2.c2) { + if(cannotBeSummed.test(n1.c1)){ + return null; // AB_B + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult } - inputs.add(res.getLeft()); - inputsChars.add(newS); - - for (int i = 0; i < newS.length(); i++) { // for each char in string, add pointer to newly created entry - char c = newS.charAt(i); - if(partsCharactersToIndices.containsKey(c)) - partsCharactersToIndices.get(c).add(inputs.size() - 1); + else if (n1.c2 == n2.c2) { + if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){ + return null; // BA_CA + }else{ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1 + } + } + else { // we have something like ab,cd + return null; } - partsCharactersToIndices.remove(sumC); } - return anyCouldNotDo; } - private enum SumOperation { - aB_a, - Ba_a, - Ba_aC, // mmult -> BC -// aB_Ca, - Ba_Ca, - aB_aC, // outer mult - a_a, - aB_aB, Ba_Ba, Ba_aB, aB_Ba,// mult and sums, something like ij,ij->i + private ArrayList executePlan(List plan, ArrayList inputs){ + return executePlan(plan, inputs, false); } - - private Pair computeRowSummation(List toSum, ArrayList inputs, List inputsChars, Character sumChar) { - - if(toSum.size() != 2){ - return null; + private ArrayList executePlan(List plan, ArrayList inputs, boolean codegen) { + ArrayList res = new ArrayList<>(plan.size()); + for(EOpNode p : plan){ + if(codegen) res.add(ComputeEOpNodeCodegen(p, inputs)); + else res.add(ComputeEOpNode(p, inputs)); } + return res; + } - String s1 = inputsChars.get(toSum.get(0)); - String s2 = inputsChars.get(toSum.get(1)); - - MatrixBlock first = null; - MatrixBlock second = null; - - String resS; - SumOperation sumOp; - - if(s1.length()==1 && s2.length() == 1){ - sumOp = SumOperation.a_a; - resS = ""; - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); + private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList inputs){ + if(eOpNode instanceof EOpNodeData eOpNodeData){ + return inputs.get(eOpNodeData.matrixIdx); } - else if(s2.length() == 1 || s1.length() == 1){ - if(s1.length() == 1){ - String sTemp = s1; - s1=s2; - s2=sTemp; - - first = inputs.get(toSum.get(1)); - second = inputs.get(toSum.get(0)); - }else{ - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); + EOpNodeBinary bin = (EOpNodeBinary) eOpNode; + MatrixBlock left = ComputeEOpNode(bin.left, inputs); + MatrixBlock right = ComputeEOpNode(bin.right, inputs); + + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + + switch (bin.operand){ + case AB_AB -> { + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + return res; + } + case A_A -> { + EnsureMatrixBlockColumnVector(left); + EnsureMatrixBlockColumnVector(right); + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + return res; + } + case a_a -> { + EnsureMatrixBlockColumnVector(left); + EnsureMatrixBlockColumnVector(right); + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + return res; + } + //////////// + case Ba_Ba -> { + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + return res; + } + case aB_aB -> { + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + EnsureMatrixBlockColumnVector(res); + return res; + } + case ab_ab -> { + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + return res; + } + case ab_ba -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + return res; } - - if(s1.charAt(0) == s2.charAt(0)){ - sumOp = SumOperation.aB_a; - resS = String.valueOf(s1.charAt(1)); - }else{ - sumOp = SumOperation.Ba_a; - resS = String.valueOf(s1.charAt(0)); - } - } else if (s1.equals(s2)) { - if(s1.charAt(0) == sumChar){ - sumOp = SumOperation.aB_aB; - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); - resS = String.valueOf(s1.charAt(1)); - }else{ - sumOp = SumOperation.Ba_Ba; - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); - resS = String.valueOf(s1.charAt(0)); - } - }else if (s1.charAt(0) == s2.charAt(1) && s1.charAt(1) == s2.charAt(0)) { - if(s1.charAt(0) == sumChar){ - sumOp = SumOperation.aB_Ba; - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); - resS = String.valueOf(s1.charAt(1)); - }else{ - sumOp = SumOperation.Ba_aB; - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); - resS = String.valueOf(s1.charAt(0)); + case Ba_aB -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + return res; } - } else if(s1.charAt(0) == s2.charAt(0)){ - sumOp = SumOperation.aB_aC; - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); - resS = String.valueOf(s1.charAt(1))+String.valueOf(s2.charAt(1)); - - } - else if(s1.charAt(1) == s2.charAt(1)){ - sumOp = SumOperation.Ba_Ca; - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); - resS = String.valueOf(s1.charAt(0))+String.valueOf(s2.charAt(0)); - } - else if(s1.charAt(0) == s2.charAt(1)){ - sumOp = SumOperation.Ba_aC; - String sTemp = s1; - s1=s2; - s2=sTemp; - first = inputs.get(toSum.get(1)); - second = inputs.get(toSum.get(0)); - resS = String.valueOf(s1.charAt(0))+String.valueOf(s2.charAt(1)); - - } - else if(s1.charAt(1) == s2.charAt(0)){ - sumOp = SumOperation.Ba_aC; - first = inputs.get(toSum.get(0)); - second = inputs.get(toSum.get(1)); - resS = String.valueOf(s1.charAt(0))+String.valueOf(s2.charAt(1)); - - }else{ - throw new RuntimeException("Error when choosing row multiplication operation"); - } - MatrixBlock out; - - if(LOG.isTraceEnabled()) LOG.trace("remaining: "+String.join(",",inputsChars.stream().filter(Objects::nonNull).toList())); - if(LOG.isTraceEnabled()) LOG.trace("Summing: "+s1+","+s2+"->"+resS); - switch (sumOp) { - case aB_a:{ + ///////// + case AB_BA -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - first = first.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG, null); - break; - } - case Ba_a: - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG, null); - break; - case Ba_aC: { - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf( second.getNumColumns())); - break; - } - case Ba_Ca: { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.NO_AGG_B1, Long.valueOf(second.getNumColumns())); - break; - } -// case aB_a: - case aB_aC: { - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_OUTERMULT_ADD, SpoofRowwise.RowType.COL_AGG_B1_T, Long.valueOf( second.getNumColumns())); - break; - } - case a_a: - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MULT, SpoofRowwise.RowType.NO_AGG,null); - break; - case aB_aB: { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - first = first.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.COL_AGG, null); - break; + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + return res; + } + case Ba_aC -> { + var res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); + return res; } - case Ba_Ba: { - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG, null); - break; + case aB_Ca -> { + var res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), _numThreads); + return res; } - case aB_Ba: { + case Ba_Ca -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - first = first.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.VECT_MATRIXMULT, SpoofRowwise.RowType.ROW_AGG,null); - break; + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + var res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); + return res; } - case Ba_aB: { + case aB_aC -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - second = second.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - out = getRowCodegenMatrixBlock(first, second, CNodeBinary.BinType.DOT_PRODUCT, SpoofRowwise.RowType.ROW_AGG,null); - break; + left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + var res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); + return res; + } + case A_scalar, AB_scalar -> { + var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); + return res; + } + case BA_A -> { + EnsureMatrixBlockRowVector(right); + var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + return res; + } + case Ba_a -> { + EnsureMatrixBlockRowVector(right); + var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + return res; + } + + case AB_A -> { + EnsureMatrixBlockColumnVector(right); + var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + return res; + } + case aB_a -> { + EnsureMatrixBlockColumnVector(right); + var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + EnsureMatrixBlockColumnVector(res); + return res; + } + + case A_B -> { + EnsureMatrixBlockColumnVector(left); + EnsureMatrixBlockRowVector(right); + var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + return res; + } + case scalar_scalar -> { + return new MatrixBlock(left.get(0,0)*right.get(0,0)); + } + default -> { + throw new IllegalArgumentException("Unexpected value: " + bin.operand.toString()); } - default: - throw new IllegalStateException("Unexpected value: " + sumOp); } - return Pair.of(out , resS); } - private MatrixBlock getCodegenElemwiseMult(ArrayList mbs) { - - ArrayList cnodeIn = new ArrayList<>(); - for(int i=0;i inputs){ + return rComputeEOpNodeCodegen(eOpNode, inputs); +// throw new NotImplementedException(); + } + private static CNodeData MatrixBlockToCNodeData(MatrixBlock mb, int id){ + return new CNodeData("ce"+id, id, mb.getNumRows(), mb.getNumColumns(), DataType.MATRIX); + } + private static MatrixBlock rComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs) { + if (eOpNode instanceof EOpNodeData eOpNodeData){ + return inputs.get(eOpNodeData.matrixIdx); +// return new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); } - CNode cnodeOut = new CNodeBinary(cnodeIn.get(0), cnodeIn.get(1), CNodeBinary.BinType.VECT_MULT); - CNodeRow cnode = new CNodeRow(cnodeIn, cnodeOut); + EOpNodeBinary bin = (EOpNodeBinary) eOpNode; +// CNodeData dataLeft = null; +// if (bin.left instanceof EOpNodeData eOpNodeData) dataLeft = new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); +// CNodeData dataRight = null; +// if (bin.right instanceof EOpNodeData eOpNodeData) dataRight = new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); - cnode.setRowType(SpoofRowwise.RowType.NO_AGG); - cnode.renameInputs(); + if(bin.operand == EBinaryOperand.AB_AB){ + if (bin.right instanceof EOpNodeBinary rBinary && rBinary.operand == EBinaryOperand.AB_AB){ + MatrixBlock left = rComputeEOpNodeCodegen(bin.left, inputs); - String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); - if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + MatrixBlock right1 = rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).left, inputs); + MatrixBlock right2 = rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).right, inputs); - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); + CNodeData d0 = MatrixBlockToCNodeData(left, 0); + CNodeData d1 = MatrixBlockToCNodeData(right1, 1); + CNodeData d2 = MatrixBlockToCNodeData(right2, 2); +// CNodeNary nary = new CNodeNary(cnodeIn, CNodeNary.NaryType.) + CNodeBinary rightBinary = new CNodeBinary(d1, d2, CNodeBinary.BinType.VECT_MULT); + CNodeBinary cNodeBinary = new CNodeBinary(d0, rightBinary, CNodeBinary.BinType.VECT_MULT); + ArrayList cnodeIn = new ArrayList<>(); + cnodeIn.add(d0); + cnodeIn.add(d1); + cnodeIn.add(d2); - ArrayList scalars = new ArrayList<>(); - MatrixBlock out = op.execute(mbs, scalars, mb, _numThreads); - return out; - } - private MatrixBlock getRowCodegenMatrixBlock(MatrixBlock first, MatrixBlock second, CNodeBinary.BinType binaryType, SpoofRowwise.RowType rowType, Long secondDim) { - ArrayList thisInputs = new ArrayList<>(Arrays.asList(first, second)); + CNodeRow cnode = new CNodeRow(cnodeIn, cNodeBinary); - ArrayList cnodeIn = new ArrayList<>(); + cnode.setRowType(SpoofRowwise.RowType.NO_AGG); + cnode.renameInputs(); - CNode c1 = new CNodeData("c1", 1, first.getNumRows(), first.getNumColumns(), DataType.MATRIX); - CNode c2 = new CNodeData("c2", 2, second.getNumRows(), second.getNumColumns(), DataType.MATRIX); - cnodeIn.add(c1); - cnodeIn.add(c2); - CNode cnodeOut = new CNodeBinary(c1, c2, binaryType); - CNodeRow cnode = new CNodeRow(cnodeIn, cnodeOut); - cnode.setRowType(rowType); + String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); + if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - if(secondDim != null) cnode.setConstDim2(secondDim); - cnode.renameInputs(); + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); - String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); - if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + ArrayList scalars = new ArrayList<>(); + ArrayList mbs = new ArrayList<>(3); + mbs.add(left); + mbs.add(right1); + mbs.add(right2); + MatrixBlock out = op.execute(mbs, scalars, mb, 6); - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); + return out; + } + } - ArrayList scalars = new ArrayList<>(); - MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); - return out; + throw new NotImplementedException(); } - private MatrixBlock getScalarMultiplyMatrixBlock(MatrixBlock mbIn, Double scalar){ - ArrayList thisInputs = new ArrayList<>(Arrays.asList(mbIn)); - - ArrayList cnodeIn = new ArrayList<>(); - - CNode c1 = new CNodeData("c1", 1, mbIn.getNumRows(), mbIn.getNumColumns(), DataType.MATRIX); - CNode c2 = new CNodeData(new LiteralOp(scalar), 0, 0, DataType.SCALAR); - cnodeIn.add(c1); - cnodeIn.add(c2); - CNode cnodeOut = new CNodeBinary(c1,c2, CNodeBinary.BinType.MULT); - CNodeCell cnode = new CNodeCell(cnodeIn, cnodeOut); - cnode.setCellType(SpoofCellwise.CellType.NO_AGG); - cnode.renameInputs(); - - String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); - if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); + private void releaseMatrixInputs(ExecutionContext ec){ + for (CPOperand input : _in) + if(input.getDataType()==DataType.MATRIX) + ec.releaseMatrixInput(input.getName()); //todo release other + } - ArrayList scalars = new ArrayList<>(); - if(scalar != null) scalars.add(new DoubleObject(scalar)); - MatrixBlock out = op.execute(thisInputs, scalars, mb, _numThreads); - return out; + private static void EnsureMatrixBlockColumnVector(MatrixBlock mb){ + if(mb.getNumColumns() > 1){ + mb.setNumRows(mb.getNumColumns()); + mb.setNumColumns(1); + mb.getDenseBlock().resetNoFill(mb.getNumRows(),1); + } } + private static void EnsureMatrixBlockRowVector(MatrixBlock mb){ + if(mb.getNumRows() > 1){ + mb.setNumColumns(mb.getNumRows()); + mb.setNumRows(1); + mb.getDenseBlock().resetNoFill(1,mb.getNumColumns()); + } + } + private static void indent(StringBuilder sb, int level) { for (int i = 0; i < level; i++) { sb.append(" "); @@ -901,7 +838,7 @@ else if (summingChars.contains(inputsChars.get(i).charAt(1))) { src = src.replace("%OUT%", sb.toString()); } - LOG.trace(src); + if( LOG.isTraceEnabled()) LOG.trace(src); Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); SpoofOperator op = CodegenUtils.createInstance(cla); MatrixBlock resBlock = new MatrixBlock(); diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 97166a8cbef..f33c1bc5a1d 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -45,7 +45,7 @@ public class EinsumTest extends AutomatedTestBase { final private static List TEST_CONFIGS = List.of( - new Config("ij,jk->ik", List.of(shape(5, 600), shape(600, 10))), // mm + new Config("ij,jk->ik", List.of(shape(50, 600), shape(600, 10))), // mm new Config("ji,jk->ik", List.of(shape(600, 5), shape(600, 10))), new Config("ji,kj->ik", List.of(shape(600, 5), shape(10, 600))), new Config("ij,kj->ik", List.of(shape(5, 600), shape(10, 600))), @@ -59,35 +59,38 @@ public class EinsumTest extends AutomatedTestBase new Config("ji,jk->j", List.of(shape(600, 5), shape(600, 10))), new Config("ji,ji->ji", List.of(shape(600, 5), shape(600, 5))), // elemwise mult + new Config("ji,ji,ji->ji", List.of(shape(600, 5),shape(600, 5), shape(600, 5)), + List.of(0.0001, 0.0005, 0.001)), new Config("ji,ij->ji", List.of(shape(600, 5), shape(5, 600))), // elemwise mult - new Config("ij,i->ij", List.of(shape(1000, 50), shape(1000))), // col mult - new Config("ji,i->ij", List.of(shape(50, 1000), shape(1000))), // row mult - new Config("ij,i->i", List.of(shape(1000, 50), shape(1000))), - new Config("ij,i->j", List.of(shape(1000, 50), shape(1000))), + new Config("ij,i->ij", List.of(shape(100, 50), shape(100))), // col mult + new Config("ji,i->ij", List.of(shape(50, 100), shape(100))), // row mult + new Config("ij,i->i", List.of(shape(100, 50), shape(100))), + new Config("ij,i->j", List.of(shape(100, 50), shape(100))), - new Config("i,i->", List.of(shape(500), shape(500))), - new Config("i,j->", List.of(shape(500), shape(800))), - new Config("i,j->ij", List.of(shape(500), shape(800))), // outer vect mult - new Config("i,j->ji", List.of(shape(500), shape(800))), // outer vect mult + new Config("i,i->", List.of(shape(50), shape(50))), + new Config("i,j->", List.of(shape(50), shape(80))), + new Config("i,j->ij", List.of(shape(50), shape(80))), // outer vect mult + new Config("i,j->ji", List.of(shape(50), shape(80))), // outer vect mult - new Config("ij->", List.of(shape(1000, 50))), // sum - new Config("ij->i", List.of(shape(1000, 50))), // sum(1) - new Config("ij->j", List.of(shape(1000, 50))), // sum(0) - new Config("ij->ji", List.of(shape(1000, 50))), // T + new Config("ij->", List.of(shape(100, 50))), // sum + new Config("ij->i", List.of(shape(100, 50))), // sum(1) + new Config("ij->j", List.of(shape(100, 50))), // sum(0) + new Config("ij->ji", List.of(shape(100, 50))), // T new Config("ab,cd->ba", List.of(shape( 600, 10), shape(6, 5))), + new Config("ab,cd,g->ba", List.of(shape( 600, 10), shape(6, 5), shape(3))), new Config("ab,bc,cd,de->ae", List.of(shape(5, 600), shape(600, 10), shape(10, 5), shape(5, 4))), // chain of mm new Config("ji,jz,zx->ix", List.of(shape(600, 5), shape( 600, 10), shape(10, 2))), new Config("fx,fg,fz,xg->z", List.of(shape(600, 5), shape( 600, 10), shape(600, 6), shape(5, 10))), new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl) - List.of(shape(5, 60), shape(5, 30), shape(5, 100), shape(60, 30), shape(100, 60), shape(100, 30))), + List.of(shape(5, 60), shape(5, 30), shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))), - new Config("i->", List.of(shape(1000))), - new Config("i->i", List.of(shape(1000))) + new Config("i->", List.of(shape(100))), + new Config("i->i", List.of(shape(100))) ); private final int id; @@ -146,7 +149,7 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar) for (int i = 0; i < config.shapes.size(); i++) { int[] dims = config.shapes.get(i); - double factor = 0.0001; + double factor = config.factors != null ? config.factors.get(i) : 0.0001; sb.append("A"); sb.append(i); @@ -174,13 +177,13 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar) sb.append(config.einsumStr); sb.append("\", "); - for (int i = 0; i < config.shapes.size()-1; i++) { + for (int i = 0; i < config.shapes.size() - 1; i++) { sb.append("A"); sb.append(i); sb.append(", "); } sb.append("A"); - sb.append(config.shapes.size()-1); + sb.append(config.shapes.size() - 1); sb.append(")"); sb.append("\n\n"); @@ -199,8 +202,8 @@ private static StringBuilder createRFile(Config config, boolean outputScalar) { for (int i = 0; i < config.shapes.size(); i++) { int[] dims = config.shapes.get(i); - - double factor = 0.0001; + + double factor = config.factors != null ? config.factors.get(i) : 0.0001; sb.append("A"); sb.append(i); @@ -268,12 +271,19 @@ public void cleanUp() { } private static class Config { + public List factors; String einsumStr; List shapes; Config(String einsum, List shapes) { this.einsumStr = einsum; this.shapes = shapes; + this.factors = null; + } + Config(String einsum, List shapes, List factors) { + this.einsumStr = einsum; + this.shapes = shapes; + this.factors = factors; } } From 3c5f86a3c75b90b50768ca60311b1c5e6c1e2472 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sun, 10 Aug 2025 21:56:59 +0200 Subject: [PATCH 27/28] fix testrunner --- .../java/org/apache/sysds/test/functions/einsum/EinsumTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index f33c1bc5a1d..04ea3b35a0e 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -332,7 +332,7 @@ private void testCodegenIntegration( String testname) runRScript(true); if(outputScalar){ - HashMap dmlfile = readDMLScalarFromExpectedDir("S"); + HashMap dmlfile = readDMLScalarFromOutputDir("S"); HashMap rfile = readRScalarFromExpectedDir("S"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); }else { From f6b93e25487ab7707bfc0eedbc85a41bb7b4180c Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sun, 10 Aug 2025 22:16:06 +0200 Subject: [PATCH 28/28] small change to save lines --- .../instructions/cp/EinsumCPInstruction.java | 59 +++++++------------ 1 file changed, 21 insertions(+), 38 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 007c89f7821..87dcf3c6048 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -492,125 +492,107 @@ private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList input AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + MatrixBlock res; switch (bin.operand){ case AB_AB -> { - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - return res; + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); } case A_A -> { EnsureMatrixBlockColumnVector(left); EnsureMatrixBlockColumnVector(right); - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - return res; + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); } case a_a -> { EnsureMatrixBlockColumnVector(left); EnsureMatrixBlockColumnVector(right); - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - return res; } //////////// case Ba_Ba -> { - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - return res; } case aB_aB -> { - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); EnsureMatrixBlockColumnVector(res); - return res; } case ab_ab -> { - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - return res; } case ab_ba -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - return res; } case Ba_aB -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - return res; } ///////// case AB_BA -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - return res; + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); } case Ba_aC -> { - var res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); - return res; + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); } case aB_Ca -> { - var res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), _numThreads); - return res; + res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), _numThreads); } case Ba_Ca -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - var res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); - return res; + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); } case aB_aC -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - var res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); - return res; + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); } case A_scalar, AB_scalar -> { - var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); - return res; + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); } case BA_A -> { EnsureMatrixBlockRowVector(right); - var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - return res; + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); } case Ba_a -> { EnsureMatrixBlockRowVector(right); - var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - return res; } case AB_A -> { EnsureMatrixBlockColumnVector(right); - var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - return res; + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); } case aB_a -> { EnsureMatrixBlockColumnVector(right); - var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); EnsureMatrixBlockColumnVector(res); - return res; } case A_B -> { EnsureMatrixBlockColumnVector(left); EnsureMatrixBlockRowVector(right); - var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - return res; + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); } case scalar_scalar -> { return new MatrixBlock(left.get(0,0)*right.get(0,0)); @@ -620,6 +602,7 @@ private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList input } } + return res; } private static MatrixBlock ComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs){