Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ad3ec9d
initial working version
HubertKrawczyk May 18, 2025
bc69ee3
add einsum R dependency
HubertKrawczyk May 18, 2025
d5df648
Merge remote-tracking branch 'origin/main' into einsum
HubertKrawczyk May 19, 2025
8db99c9
fix merge
HubertKrawczyk May 19, 2025
c9c9947
quick fix
HubertKrawczyk May 19, 2025
3a09dff
more computations done using row tpl
HubertKrawczyk Jun 8, 2025
6b3c67f
removed old code
HubertKrawczyk Jun 8, 2025
786d217
fix bugs
HubertKrawczyk Jun 8, 2025
980d7a8
cleanup validate code
HubertKrawczyk Jun 17, 2025
309f758
add tests
HubertKrawczyk Jun 17, 2025
1094805
better plan generation for einsum, handle more cases, added test that…
HubertKrawczyk Jun 18, 2025
b46f23a
support for scalar multiplication, better validation
HubertKrawczyk Jun 26, 2025
d9d52ef
add test and removing unnecessary changes
HubertKrawczyk Jun 26, 2025
1c4abc3
remove more unused code, fix bug
HubertKrawczyk Jun 26, 2025
0f3c395
add einsum to test suite in github workflows
HubertKrawczyk Jun 26, 2025
bf4f910
add tests and code for vector elementwise mult. and outer product
HubertKrawczyk Jun 30, 2025
1b765cc
moved einsum information extraction to EinsumContext function
HubertKrawczyk Jun 30, 2025
12564d7
add comment
HubertKrawczyk Jun 30, 2025
83b2dfd
Merge branch 'main' into einsum
HubertKrawczyk Jul 4, 2025
a0b1608
extracted einsumEquationValidation to separate class, it is now calle…
HubertKrawczyk Jul 4, 2025
acaa2ae
missing licence notice, better more restricting condition for binary …
HubertKrawczyk Jul 4, 2025
f694b7d
better, clearer EinsumContext properties names
HubertKrawczyk Jul 4, 2025
f622d4b
aB_a contraction works, by transposing matrix, no longer change to co…
HubertKrawczyk Jul 20, 2025
1e5e0fa
remove changes to CNodeBinary
HubertKrawczyk Jul 25, 2025
63683c9
remove changes to CNodeBinary 2
HubertKrawczyk Jul 25, 2025
4af94a3
Merge branch 'main' into einsum
HubertKrawczyk Jul 25, 2025
ae4bb82
reduce size of einsum test18
HubertKrawczyk Jul 25, 2025
b854305
Merge branch 'main' into einsum
HubertKrawczyk Aug 3, 2025
7eb950c
create einsum test files from configuration list, fixes in code
HubertKrawczyk Aug 3, 2025
ba21c30
refactored plan generation and execution into separate steps; plan is…
HubertKrawczyk Aug 9, 2025
3c5f86a
fix testrunner
HubertKrawczyk Aug 10, 2025
f6b93e2
small change to save lines
HubertKrawczyk Aug 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/javaTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ jobs:
"**.functions.unary.scalar.**,**.functions.updateinplace.**,**.functions.vect.**",
"**.functions.reorg.**,**.functions.rewrite.**,**.functions.ternary.**",
"**.functions.transform.**","**.functions.unique.**",
"**.functions.unary.matrix.**,**.functions.linearization.**,**.functions.jmlc.**"
"**.functions.unary.matrix.**,**.functions.linearization.**,**.functions.jmlc.**",
"**.functions.einsum.**",
]
name: ${{ matrix.tests }}
steps:
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ public enum Builtins {
UNDER_SAMPLING("underSampling", true),
UNIQUE("unique", false, true),
UPPER_TRI("upper.tri", false, true),
EINSUM("einsum", false, false),
XDUMMY1("xdummy1", true), //error handling test
XDUMMY2("xdummy2", true); //error handling test

Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/InstructionType.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public enum InstructionType {
PMMJ,
MMChain,
Union,
EINSUM,

//SP Types
MAPMM,
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Opcodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ public enum Opcodes {
RBIND("rbind", InstructionType.BuiltinNary),
EVAL("eval", InstructionType.BuiltinNary),
LIST("list", InstructionType.BuiltinNary),

EINSUM("einsum", InstructionType.BuiltinNary),
//Parametrized builtin functions
AUTODIFF("autoDiff", InstructionType.ParameterizedBuiltin),
CONTAINS("contains", InstructionType.ParameterizedBuiltin),
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -767,8 +767,8 @@ public String toString() {

/** Operations that require a variable number of operands*/
public enum OpOpN {
PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST;
PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST, EINSUM;

public boolean isCellOp() {
return this == MIN || this == MAX || this == PLUS || this == MULT;
}
Expand Down
9 changes: 9 additions & 0 deletions src/main/java/org/apache/sysds/hops/NaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.Nary;
import org.apache.sysds.runtime.einsum.EinsumEquationValidator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

Expand Down Expand Up @@ -235,6 +236,14 @@ public void refreshSizeInformation() {
setDim1(getInput().size());
setDim2(1);
break;
case EINSUM:
String equationString = ((LiteralOp) _input.get(0)).getStringValue();
var dims = EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString, this.getInput().subList(1, this.getInput().size()));

setDim1(dims.getLeft());
setDim2(dims.getMiddle());
setDataType(dims.getRight());
break;
case PRINTF:
case EVAL:
//do nothing:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

public class CNodeCell extends CNodeTpl
{
protected static final String JAVA_TEMPLATE =
public static final String JAVA_TEMPLATE =
"package codegen;\n"
+ "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
+ "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ public CNodeData(CNodeData node, String newName) {
_cols = node.getNumCols();
_dataType = node.getDataType();
}

/**
 * Creates a data node directly from explicitly supplied properties.
 *
 * @param name     value stored as the node's name
 * @param hopID    value stored as the node's hop ID
 * @param rows     value stored as the number of rows
 * @param cols     value stored as the number of columns
 * @param dataType value stored as the node's data type
 */
public CNodeData(String name, long hopID, long rows, long cols, DataType dataType) {
	// size and type information
	this._dataType = dataType;
	this._rows = rows;
	this._cols = cols;
	// identity information
	this._hopID = hopID;
	this._name = name;
}

@Override
public String getVarname() {
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/lops/Nary.java
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ private String getOpcode() {
case RBIND:
case EVAL:
case LIST:
case EINSUM:
return operationType.name().toLowerCase();
case MIN:
case MAX:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;

import org.antlr.v4.runtime.ParserRuleContext;
import org.apache.commons.lang3.ArrayUtils;
Expand All @@ -35,6 +36,7 @@
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
import org.apache.sysds.runtime.einsum.EinsumEquationValidator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
Expand Down Expand Up @@ -751,7 +753,9 @@ else if(((ConstIdentifier) getThirdExpr().getOutput())
else
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
break;

case EINSUM:
validateEinsum((DataIdentifier) getOutputs()[0]);
break;
default: //always unconditional
raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false);
}
Expand Down Expand Up @@ -2063,7 +2067,9 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
output.setValueType(ValueType.INT64);
output.setNnz(id.getDim2());
break;

case EINSUM:
validateEinsum(output);
break;
default:
if( isMathFunction() ) {
checkMathFunctionParam();
Expand Down Expand Up @@ -2096,6 +2102,49 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
}
}

/**
 * Validates a call to the einsum builtin and configures its output.
 * <p>
 * The first argument must be a string literal holding the equation, followed by
 * at least one further operand. Every operand after the equation is checked with
 * {@code checkMatrixParam}, including operands that follow one whose dimensions
 * are not yet known. If all operand dimensions are known, the equation is
 * validated against them and the output dimensions are set accordingly;
 * otherwise only the equation structure is validated and the output dimensions
 * are left unknown (-1).
 *
 * @param output identifier whose data type, dimensions, value type and block size are set
 */
private void validateEinsum(DataIdentifier output) {
	if(getSecondExpr() == null)
		raiseValidateError("Einsum: at least one input matrix required", false,
			LanguageErrorCodes.INVALID_PARAMETERS);

	if(!(getFirstExpr() instanceof StringIdentifier))
		raiseValidateError("Einsum: first argument has to be equation str", false,
			LanguageErrorCodes.INVALID_PARAMETERS);

	String equationString = ((StringIdentifier) getFirstExpr()).getValue();

	if(equationString.length() == 0)
		raiseValidateError("Einsum: equation str too short", false, LanguageErrorCodes.INVALID_PARAMETERS);
	// the equation may neither start with the arrow nor with an operand separator
	if(equationString.charAt(0) == '-' || equationString.charAt(0) == ',')
		raiseValidateError("Einsum: equation str invalid", false, LanguageErrorCodes.INVALID_PARAMETERS);

	Expression[] expressions = getAllExpr();
	boolean allDimsKnown = true;

	LinkedList<Identifier> matrixBlocks = new LinkedList<>();
	// index 0 is the equation string; all remaining expressions are operands
	for(int i = 1; i < expressions.length; i++) {
		// always type-check every operand, even after one with unknown dimensions was seen
		checkMatrixParam(expressions[i]);
		Identifier operand = expressions[i].getOutput();
		if(!operand.dimsKnown())
			allDimsKnown = false;
		else
			matrixBlocks.add(operand); // only consumed below when all dimensions are known
	}

	if(allDimsKnown) {
		var dims = EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString, matrixBlocks);

		output.setDataType(dims.getRight());
		output.setDimensions(dims.getLeft(), dims.getMiddle());
	}
	else {
		DataType dataType = EinsumEquationValidator.validateEinsumEquationNoDimensions(equationString, _args.length - 1);

		output.setDataType(dataType);
		output.setDimensions(-1L, -1L);
	}

	output.setValueType(ValueType.FP64);
	output.setBlocksize(getSecondExpr().getOutput().getBlocksize());
}

private void setBinaryOutputProperties(DataIdentifier output) {
DataType dt1 = getFirstExpr().getOutput().getDataType();
DataType dt2 = getSecondExpr().getOutput().getDataType();
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/apache/sysds/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2447,7 +2447,10 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D
new NaryOp(target.getName(), target.getDataType(), target.getValueType(),
OpOpN.valueOf(source.getOpCode().name()), processAllExpressions(source.getAllExpr(), hops));
break;

case EINSUM:
currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(),
OpOpN.valueOf(source.getOpCode().name()), processAllExpressions(source.getAllExpr(), hops));
break;
case PPRED:
String sop = ((StringIdentifier)source.getThirdExpr()).getValue();
sop = sop.replace("\"", "");
Expand Down
177 changes: 177 additions & 0 deletions src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.einsum;

import org.apache.sysds.runtime.matrix.data.MatrixBlock;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;


/**
 * Holds the information extracted from an einsum equation string (for example
 * {@code "ab,bc->ac"}) together with the sizes of the given input matrices.
 * Instances are created through {@link #getEinsumContext(String, ArrayList)}.
 */
public class EinsumContext {
	/** Which index position(s) of an input hold a character that is contracted away. */
	public enum ContractDimensions {
		CONTRACT_LEFT,
		CONTRACT_RIGHT,
		CONTRACT_BOTH,
	}
	/** Size recorded for the first output character, or 1 if the output has no characters. */
	public Integer outRows;
	/** Size recorded for the second output character, or 1 if the output has fewer than two characters. */
	public Integer outCols;
	/** First output character, or null if the output side is empty. */
	public Character outChar1;
	/** Second output character, or null if the output side has fewer than two characters. */
	public Character outChar2;
	/** Size recorded for each character on its first occurrence on the input side. */
	public HashMap<Character, Integer> charToDimensionSize;
	/** The equation string as passed in. */
	public String equationString;
	/** Per input: true if its two index characters are identical. */
	public boolean[] diagonalInputs;
	/** Characters that occur more than once on the input side. */
	public HashSet<Character> summingChars;
	/** Characters that are neither summing characters nor output characters. */
	public HashSet<Character> contractDimsSet;
	/** Per input: which index positions hold characters of {@link #contractDimsSet}; null if none. */
	public ContractDimensions[] contractDims;
	/** Per-input index strings with the characters of {@link #contractDimsSet} left out. */
	public ArrayList<String> newEquationStringInputsSplit;
	public HashMap<Character, ArrayList<Integer>> characterAppearanceIndexes; // for each character, this tells in which inputs it appears

	private EinsumContext() {}

	/**
	 * Parses the equation string in two passes and collects the resulting information.
	 * The first pass records the size of every character and determines the summing
	 * characters; the second pass classifies each character of each input.
	 * <p>
	 * The equation is expected to contain the arrow {@code "->"}; a string without
	 * {@code '-'} makes {@code charAt} run past the end of the string.
	 *
	 * @param eqStr  einsum equation string, inputs separated by {@code ','} and followed by {@code "->"}
	 * @param inputs input matrices in the order in which they appear in the equation
	 * @return the populated context
	 * @throws RuntimeException if no input is given, if a character is used with two different
	 *                          sizes, or if the output has more than two characters
	 */
	public static EinsumContext getEinsumContext(String eqStr, ArrayList<MatrixBlock> inputs) {
		if(inputs.isEmpty())
			throw new RuntimeException("Einsum: at least one input matrix required");

		EinsumContext res = new EinsumContext();

		res.equationString = eqStr;
		res.charToDimensionSize = new HashMap<>();
		HashSet<Character> summingChars = new HashSet<>();
		ContractDimensions[] contractDims = new ContractDimensions[inputs.size()];
		boolean[] diagonalInputs = new boolean[inputs.size()]; // all false by default
		HashSet<Character> contractDimsSet = new HashSet<>();
		HashMap<Character, ArrayList<Integer>> partsCharactersToIndices = new HashMap<>();
		ArrayList<String> newEquationStringSplit = new ArrayList<>();

		Iterator<MatrixBlock> it = inputs.iterator();
		MatrixBlock curArr = it.next();
		int dimPos = 0; // position of the current character within its input (0 = rows, 1 = columns)
		int i;
		// first iteration through string: collect information on character-size and what characters are summing characters
		for(i = 0; true; i++) {
			char c = eqStr.charAt(i);
			if(c == '-') {
				i += 2; // skip the two characters of "->"
				break;
			}
			if(c == ',') {
				curArr = it.next();
				dimPos = 0;
			}
			else {
				if(res.charToDimensionSize.containsKey(c)) { // sanity check if dims match, this is already checked at validation
					int knownSize = res.charToDimensionSize.get(c);
					if(dimPos == 0 && knownSize != curArr.getNumRows())
						throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes");
					else if(dimPos == 1 && knownSize != curArr.getNumColumns())
						throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes");
					// a character seen for the second time is a summing character
					summingChars.add(c);
				}
				else {
					if(dimPos == 0)
						res.charToDimensionSize.put(c, curArr.getNumRows());
					else if(dimPos == 1)
						res.charToDimensionSize.put(c, curArr.getNumColumns());
				}

				dimPos++;
			}
		}

		// i now points at the first output character (if any)
		int numOfRemainingChars = eqStr.length() - i;

		if(numOfRemainingChars > 2)
			throw new RuntimeException("Einsum: dim > 2 not supported");

		Character outChar1 = numOfRemainingChars > 0 ? eqStr.charAt(i) : null;
		Character outChar2 = numOfRemainingChars > 1 ? eqStr.charAt(i+1) : null;
		res.outRows = (numOfRemainingChars > 0 ? res.charToDimensionSize.get(outChar1) : 1);
		res.outCols = (numOfRemainingChars > 1 ? res.charToDimensionSize.get(outChar2) : 1);

		int arrayIterator = 0; // index of the input currently being processed
		// second iteration through string: collect remaining information
		for(i = 0; true; i++) {
			char c = eqStr.charAt(i);
			if(c == '-') {
				break;
			}
			if(c == ',') {
				arrayIterator++;
				continue;
			}
			String s = "";

			if(summingChars.contains(c)) {
				s += c;
				partsCharactersToIndices.computeIfAbsent(c, k -> new ArrayList<>()).add(arrayIterator);
			}
			else if((outChar1 != null && c == outChar1) || (outChar2 != null && c == outChar2)) {
				s += c;
			}
			else {
				contractDimsSet.add(c);
				contractDims[arrayIterator] = ContractDimensions.CONTRACT_LEFT;
			}

			if(i + 1 < eqStr.length()) { // process next character together
				char c2 = eqStr.charAt(i + 1);
				i++;
				if(c2 == '-') {
					newEquationStringSplit.add(s);
					break;
				}
				if(c2 == ',') {
					arrayIterator++;
					newEquationStringSplit.add(s);
					continue;
				}

				if(c2 == c) {
					// same character on both positions of this input
					diagonalInputs[arrayIterator] = true;
					if(contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT)
						contractDims[arrayIterator] = ContractDimensions.CONTRACT_BOTH;
				}
				else {
					if(summingChars.contains(c2)) {
						s += c2;
						partsCharactersToIndices.computeIfAbsent(c2, k -> new ArrayList<>()).add(arrayIterator);
					}
					else if((outChar1 != null && c2 == outChar1) || (outChar2 != null && c2 == outChar2)) {
						s += c2;
					}
					else {
						contractDimsSet.add(c2);
						contractDims[arrayIterator] = contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT ? ContractDimensions.CONTRACT_BOTH : ContractDimensions.CONTRACT_RIGHT;
					}
				}
			}
			newEquationStringSplit.add(s);
		}

		res.contractDims = contractDims;
		res.contractDimsSet = contractDimsSet;
		res.diagonalInputs = diagonalInputs;
		res.summingChars = summingChars;
		res.outChar1 = outChar1;
		res.outChar2 = outChar2;
		res.newEquationStringInputsSplit = newEquationStringSplit;
		res.characterAppearanceIndexes = partsCharactersToIndices;
		return res;
	}
}
Loading
Loading