Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/main/java/org/apache/sysds/api/DMLOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public class DMLOptions {
public boolean federatedCompilation = false; // Compile federated instructions based on input federation state and privacy constraints.
public boolean noFedRuntimeConversion = false; // If activated, no runtime conversion of CP instructions to FED instructions will be performed.
public int seed = -1; // The general seed for the execution, if -1 random (system time).
public boolean sparseIntermediate = false; // whether SparseRowIntermediates should be used for rowwise operations

public final static DMLOptions defaultOptions = new DMLOptions(null);

Expand Down Expand Up @@ -117,7 +118,8 @@ public String toString() {
", w=" + fedWorker +
", federatedCompilation=" + federatedCompilation +
", noFedRuntimeConversion=" + noFedRuntimeConversion +
", seed=" + seed +
", seed=" + seed +
", sparseIntermediate=" + sparseIntermediate +
'}';
}

Expand Down Expand Up @@ -350,6 +352,10 @@ else if (lineageType.equalsIgnoreCase("debugger"))
dmlOptions.seed = Integer.parseInt(line.getOptionValue("seed"));
}

if(line.hasOption("sparseIntermediate")){
dmlOptions.sparseIntermediate = true;
}

return dmlOptions;
}

Expand Down Expand Up @@ -431,7 +437,10 @@ private static Options createCLIOptions() {
Option commandlineSeed = OptionBuilder
.withDescription("A general seed for the execution through the commandline")
.hasArg().create("seed");

Option sparseRowIntermediates = OptionBuilder
.withDescription("If activated, sparseRowVector intermediates will be used to calculate rowwise operations.")
.create("sparseIntermediate");

options.addOption(configOpt);
options.addOption(cleanOpt);
options.addOption(statsOpt);
Expand All @@ -451,6 +460,7 @@ private static Options createCLIOptions() {
options.addOption(federatedCompilation);
options.addOption(noFedRuntimeConversion);
options.addOption(commandlineSeed);
options.addOption(sparseRowIntermediates);

// Either a clean(-clean), a file(-f), a script(-s) or help(-help) needs to be specified
OptionGroup fileOrScriptOpt = new OptionGroup()
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ public class DMLScript
// Global seed
public static int SEED = -1;

// Sparse row flag
public static boolean SPARSE_INTERMEDIATE = false;

public static String MONITORING_ADDRESS = null;

// flag that indicates whether or not to suppress any prints to stdout
Expand Down Expand Up @@ -275,6 +278,7 @@ public static boolean executeScript( String[] args )
LINEAGE_ESTIMATE = dmlOptions.lineage_estimate;
LINEAGE_DEBUGGER = dmlOptions.lineage_debugger;
SEED = dmlOptions.seed;
SPARSE_INTERMEDIATE = dmlOptions.sparseIntermediate;


String fnameOptConfig = dmlOptions.configFile;
Expand Down
20 changes: 18 additions & 2 deletions src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.sysds.hops.codegen.cplan;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
Expand Down Expand Up @@ -77,6 +78,14 @@ public String createVarname() {
_genVar = "TMP"+_seqVar.getNextID();
return _genVar;
}

public String createVarname(boolean sparse) {
if(!sparse) {
return createVarname();
} else {
return _genVar = "S" + createVarname();
}
}

public String getVarname() {
return _genVar;
Expand All @@ -98,6 +107,8 @@ public String getVectorLength(GeneratorAPI api) {
return "len";
if(getVarname().startsWith("b"))
return getVarname() + ".clen";
else if(getVarname().startsWith("STMP"))
return "len";
else if(_dataType == DataType.MATRIX)
return getVarname() + ".length";
}
Expand Down Expand Up @@ -222,8 +233,13 @@ public boolean equals(Object that) {

protected String replaceUnaryPlaceholders(String tmp, String varj, boolean vectIn, GeneratorAPI api) {
//replace sparse and dense inputs
tmp = tmp.replace("%IN1v%", varj+"vals");
tmp = tmp.replace("%IN1i%", varj+"ix");
if(DMLScript.SPARSE_INTERMEDIATE) {
tmp = tmp.replace("%IN1v%", varj.startsWith("STMP") ? varj+".values()" : varj+"vals");
tmp = tmp.replace("%IN1i%", varj.startsWith("STMP") ? varj+".indexes()" :varj+"ix");
} else {
tmp = tmp.replace("%IN1v%", varj+"vals");
tmp = tmp.replace("%IN1i%", varj+"ix");
}
tmp = tmp.replace("%IN1%",
(vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ?
((api == GeneratorAPI.JAVA) ? varj + ".values(rix)" : varj + ".vals(0)" ) :
Expand Down
190 changes: 140 additions & 50 deletions src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Arrays;

import org.apache.commons.lang3.StringUtils;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.common.Types.DataType;
Expand Down Expand Up @@ -157,69 +158,158 @@ public String codegen(boolean sparse, GeneratorAPI api) {
//generate children
sb.append(_inputs.get(0).codegen(sparse, api));
sb.append(_inputs.get(1).codegen(sparse, api));

//generate binary operation (use sparse template, if data input)
boolean lsparseLhs = sparse && _inputs.get(0) instanceof CNodeData
&& _inputs.get(0).getVarname().startsWith("a");
boolean lsparseRhs = sparse && _inputs.get(1) instanceof CNodeData
&& _inputs.get(1).getVarname().startsWith("a");
boolean scalarInput = _inputs.get(0).getDataType().isScalar();
boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
&& _inputs.get(1).getDataType().isMatrix());
boolean vectorVector = _inputs.get(0).getDataType().isMatrix()
&& _inputs.get(1).getDataType().isMatrix();
String var = createVarname();
String tmp = getLanguageTemplateClass(this, api)
.getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput, vectorVector);

tmp = tmp.replace("%TMP%", var);

//replace input references and start indexes
for( int j=0; j<2; j++ ) {
String varj = _inputs.get(j).getVarname(api);

//replace sparse and dense inputs
tmp = tmp.replace("%IN"+(j+1)+"v%", varj+"vals");
tmp = tmp.replace("%IN"+(j+1)+"i%", varj+"ix");
tmp = tmp.replace("%IN"+(j+1)+"%",
varj.startsWith("a") ? (api == GeneratorAPI.JAVA ? varj :
if(DMLScript.SPARSE_INTERMEDIATE) {
//generate binary operation (use sparse template, if data input)
boolean lsparseLhs = sparse ? _inputs.get(0) instanceof CNodeData
&& _inputs.get(0).getVarname().startsWith("a") ||
_inputs.get(0).getVarname().startsWith("STMP") : false;
boolean lsparseRhs = sparse ? _inputs.get(1) instanceof CNodeData
&& _inputs.get(1).getVarname().startsWith("a") ||
_inputs.get(1).getVarname().startsWith("STMP") : false;
boolean scalarInput = _inputs.get(0).getDataType().isScalar();
boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
&& _inputs.get(1).getDataType().isMatrix());
boolean vectorVector = _inputs.get(0).getDataType().isMatrix()
&& _inputs.get(1).getDataType().isMatrix();
String var = createVarname(sparse && getOutputType(scalarVector, lsparseLhs, lsparseRhs));
String tmp = getLanguageTemplateClass(this, api)
.getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput, vectorVector);

tmp = tmp.replace("%TMP%", var);

//replace input references and start indexes
for( int j=0; j<2; j++ ) {
String varj = _inputs.get(j).getVarname(api);
//replace sparse and dense inputs
tmp = tmp.replace("%IN"+(j+1)+"v%", varj.startsWith("STMP") ? varj+".values()" : varj+"vals");
tmp = tmp.replace("%IN"+(j+1)+"i%", varj.startsWith("STMP") ? varj+".indexes()" : varj+"ix");
tmp = tmp.replace("%IN"+(j+1)+"%",
varj.startsWith("a") ? (api == GeneratorAPI.JAVA ? varj :
(_inputs.get(j).getDataType() == DataType.MATRIX ? varj + ".vals(0)" : varj)) :
varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj + ".values(rix)" :
(_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) :
varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj + ".values(rix)" :
(_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) :
_inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj);

//replace start position of main input
tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData
&& _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" :

tmp = tmp.replace("%SLEN"+(j+1)+"%", varj.startsWith("STMP") ? varj+".size()" : varj.startsWith("a") ? "alen" : "blen");

//replace start position of main input
tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData
&& _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" :
((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise()
&& TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ?
varj + ".pos(rix)" : "0" : "0");
}
//replace length information (e.g., after matrix mult)
if( _type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector) ) {
for( int j=0; j<2; j++ )
tmp = tmp.replace("%LEN"+(j+1)+"%", _inputs.get(j).getVectorLength(api));
}
else { //general case
CNode mInput = getIntermediateInputVector();
if( mInput != null )
tmp = tmp.replace("%LEN%", mInput.getVectorLength(api));
varj + ".pos(rix)" : "0" : "0");
}
//replace length information (e.g., after matrix mult)
if( _type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector)) {
for( int j=0; j<2; j++ )
tmp = tmp.replace("%LEN"+(j+1)+"%", _inputs.get(j).getVectorLength(api));
}
else { //general case
CNode mInput = getIntermediateInputVector();
if( mInput != null )
tmp = tmp.replace("%LEN%", mInput.getVectorLength(api));
}

sb.append(tmp);

//mark as generated
_generated = true;

return sb.toString();
} else {
boolean lsparseLhs =
sparse && _inputs.get(0) instanceof CNodeData && _inputs.get(0).getVarname().startsWith("a");
boolean lsparseRhs =
sparse && _inputs.get(1) instanceof CNodeData && _inputs.get(1).getVarname().startsWith("a");
boolean scalarInput = _inputs.get(0).getDataType().isScalar();
boolean scalarVector = (_inputs.get(0).getDataType().isScalar() && _inputs.get(1).getDataType().isMatrix());
boolean vectorVector = _inputs.get(0).getDataType().isMatrix() && _inputs.get(1).getDataType().isMatrix();
String var = createVarname();
String tmp = getLanguageTemplateClass(this, api).getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector,
scalarInput, vectorVector);

tmp = tmp.replace("%TMP%", var);

//replace input references and start indexes
for(int j = 0; j < 2; j++) {
String varj = _inputs.get(j).getVarname(api);

//replace sparse and dense inputs
tmp = tmp.replace("%IN" + (j + 1) + "v%", varj + "vals");
tmp = tmp.replace("%IN" + (j + 1) + "i%", varj + "ix");
tmp = tmp.replace("%IN" + (j + 1) + "%", varj.startsWith("a") ? (
api == GeneratorAPI.JAVA ? varj : (_inputs.get(j).getDataType() == DataType.MATRIX ? varj +
".vals(0)" : varj)) : varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj +
".values(rix)" : (_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) :
_inputs.get(j).getDataType() == DataType.MATRIX ? (
api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj);

//replace start position of main input
tmp = tmp.replace("%POS" + (j + 1) + "%", (_inputs.get(j) instanceof CNodeData &&
_inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj + "i" : (
(TemplateUtils.isMatrix(_inputs.get(j)) ||
(_type.isElementwise() && TemplateUtils.isColVector(_inputs.get(j)))) &&
_type != BinType.VECT_MATRIXMULT) ? varj + ".pos(rix)" : "0" : "0");
}
//replace length information (e.g., after matrix mult)
if(_type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector)) {
for(int j = 0; j < 2; j++)
tmp = tmp.replace("%LEN" + (j + 1) + "%", _inputs.get(j).getVectorLength(api));
}
else { //general case
CNode mInput = getIntermediateInputVector();
if(mInput != null)
tmp = tmp.replace("%LEN%", mInput.getVectorLength(api));
}

sb.append(tmp);

//mark as generated
_generated = true;

return sb.toString();
}

sb.append(tmp);

//mark as generated
_generated = true;

return sb.toString();
}

private CNode getIntermediateInputVector() {
for( int i=0; i<2; i++ )
if( getInput().get(i).getDataType().isMatrix() )
return getInput().get(i);
return null;
}
}

public boolean getOutputType(boolean scalarVector, boolean lsparseLhs, boolean lsparseRhs) {
switch(_type) {
case VECT_POW_SCALAR: return !scalarVector && lsparseLhs;
case VECT_MULT_SCALAR:
case VECT_DIV_SCALAR:
case VECT_XOR_SCALAR:
case VECT_MIN_SCALAR:
case VECT_MAX_SCALAR:
case VECT_EQUAL_SCALAR:
case VECT_NOTEQUAL_SCALAR:
case VECT_LESS_SCALAR:
case VECT_LESSEQUAL_SCALAR:
case VECT_GREATER_SCALAR:
case VECT_GREATEREQUAL_SCALAR:
case VECT_BITWAND_SCALAR: return lsparseLhs || lsparseRhs;
case VECT_MULT:
case VECT_DIV:
case VECT_MINUS:
case VECT_PLUS:
case VECT_XOR:
case VECT_BITWAND:
case VECT_BIASADD:
case VECT_BIASMULT:
case VECT_MIN:
case VECT_MAX:
case VECT_NOTEQUAL:
case VECT_LESS:
case VECT_GREATER: return lsparseLhs && lsparseRhs;
default: return false;
}
}

@Override
public String toString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ public String getTemplate(boolean sparseGen, long len, ArrayList<CNode> inputs,
sb.append( sparseInput ?
" LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, "
+varj+"ix, "+pos+", "+off+", "+input._cols+");\n" :
" LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj)
varj.startsWith("STMP") ?
" LibSpoofPrimitives.vectWrite("+varj+".values(), %TMP%, "
+varj+".indexes(), "+pos+", "+off+", "+varj+".size());\n" :
" LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj)
+", %TMP%, "+pos+", "+off+", "+input._cols+");\n");
off += input._cols;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class CNodeRow extends CNodeTpl
+ "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n"
+ "import org.apache.sysds.runtime.codegen.SpoofRowwise;\n"
+ "import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;\n"
+ "import org.apache.sysds.runtime.data.SparseRowVector;\n"
+ "import org.apache.commons.math3.util.FastMath;\n"
+ "\n"
+ "public final class %TMP% extends SpoofRowwise { \n"
Expand Down Expand Up @@ -162,7 +163,8 @@ private String getOutputStatement(String varName) {
case NO_AGG_B1:
case NO_AGG_CONST:
if(api == GeneratorAPI.JAVA)
return TEMPLATE_NOAGG_OUT.replace("%IN%", varName).replace("%LEN%", _output.getVarname()+".length");
return TEMPLATE_NOAGG_OUT.replace("%IN%", varName.startsWith("STMP")?varName+".values(), "+varName+".indexes()":varName).replace("%LEN%",
varName.startsWith("STMP") ? varName+".size()" : _output.getVarname()+".length");
else
return TEMPLATE_NOAGG_CONST_OUT_CUDA.replace("%IN%", varName + ".vals(0)").replaceAll("%LEN%", _output.getVarname()+".length");
case FULL_AGG:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,13 @@ public String codegen(boolean sparse, GeneratorAPI api) {
String varj = _inputs.get(j-1).getVarname();
//replace sparse and dense inputs
tmp = tmp.replace("%IN"+j+"v%",
varj+(varj.startsWith("a")?"vals":"") );
varj+(varj.startsWith("a")?"vals" : varj.startsWith("STMP") ? ".values()" :"") );
tmp = tmp.replace("%IN"+j+"i%",
varj+(varj.startsWith("a")?"ix":"") );
varj+(varj.startsWith("a")?"ix": varj.startsWith("STMP") ? ".indexes()" :"") );
tmp = tmp.replace("%IN"+j+"%", varj );
tmp = tmp.replace("%POS%", varj.startsWith("a") ? varj+"i" : varj.startsWith("STMP") ? "0" : "");
tmp = tmp.replace("%LEN%",
varj.startsWith("a") ? "alen" : varj.startsWith("STMP") ? varj+".size()" : "");
}
sb.append(tmp);

Expand Down
Loading