diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java index 763ac7b9388..b5b0908432b 100644 --- a/src/main/java/org/apache/sysds/api/DMLOptions.java +++ b/src/main/java/org/apache/sysds/api/DMLOptions.java @@ -85,6 +85,7 @@ public class DMLOptions { public boolean federatedCompilation = false; // Compile federated instructions based on input federation state and privacy constraints. public boolean noFedRuntimeConversion = false; // If activated, no runtime conversion of CP instructions to FED instructions will be performed. public int seed = -1; // The general seed for the execution, if -1 random (system time). + public boolean sparseIntermediate = false; // whether SparseRowIntermediates should be used for rowwise operations public final static DMLOptions defaultOptions = new DMLOptions(null); @@ -117,7 +118,8 @@ public String toString() { ", w=" + fedWorker + ", federatedCompilation=" + federatedCompilation + ", noFedRuntimeConversion=" + noFedRuntimeConversion + - ", seed=" + seed + + ", seed=" + seed + + ", sparseIntermediate=" + sparseIntermediate + '}'; } @@ -350,6 +352,10 @@ else if (lineageType.equalsIgnoreCase("debugger")) dmlOptions.seed = Integer.parseInt(line.getOptionValue("seed")); } + if(line.hasOption("sparseIntermediate")){ + dmlOptions.sparseIntermediate = true; + } + return dmlOptions; } @@ -431,7 +437,10 @@ private static Options createCLIOptions() { Option commandlineSeed = OptionBuilder .withDescription("A general seed for the execution through the commandline") .hasArg().create("seed"); - + Option sparseRowIntermediates = OptionBuilder + .withDescription("If activated, sparseRowVector intermediates will be used to calculate rowwise operations.") + .create("sparseIntermediate"); + options.addOption(configOpt); options.addOption(cleanOpt); options.addOption(statsOpt); @@ -451,6 +460,7 @@ private static Options createCLIOptions() { options.addOption(federatedCompilation); options.addOption(noFedRuntimeConversion); options.addOption(commandlineSeed); + options.addOption(sparseRowIntermediates); // Either a clean(-clean), a file(-f), a script(-s) or help(-help) needs to be specified OptionGroup fileOrScriptOpt = new OptionGroup() diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index d6853891e24..44dc35ee75e 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -153,6 +153,9 @@ public class DMLScript // Global seed public static int SEED = -1; + // Sparse row flag + public static boolean SPARSE_INTERMEDIATE = false; + public static String MONITORING_ADDRESS = null; // flag that indicates whether or not to suppress any prints to stdout @@ -275,6 +278,7 @@ public static boolean executeScript( String[] args ) LINEAGE_ESTIMATE = dmlOptions.lineage_estimate; LINEAGE_DEBUGGER = dmlOptions.lineage_debugger; SEED = dmlOptions.seed; + SPARSE_INTERMEDIATE = dmlOptions.sparseIntermediate; String fnameOptConfig = dmlOptions.configFile; diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java index 36cc8f49799..36ebd238aca 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java @@ -19,6 +19,7 @@ package org.apache.sysds.hops.codegen.cplan; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI; import org.apache.sysds.hops.codegen.template.TemplateUtils; @@ -77,6 +78,14 @@ public String createVarname() { _genVar = "TMP"+_seqVar.getNextID(); return _genVar; } + + public String createVarname(boolean sparse) { + if(!sparse) { + return createVarname(); + } else { + return _genVar = "S" + createVarname(); + } + } public String getVarname() { return _genVar; @@ -98,6 +107,8 @@ public String getVectorLength(GeneratorAPI api) { return "len"; if(getVarname().startsWith("b")) return getVarname() + ".clen"; + else if(getVarname().startsWith("STMP")) + return "len"; else if(_dataType == DataType.MATRIX) return getVarname() + ".length"; } @@ -222,8 +233,13 @@ public boolean equals(Object that) { protected String replaceUnaryPlaceholders(String tmp, String varj, boolean vectIn, GeneratorAPI api) { //replace sparse and dense inputs - tmp = tmp.replace("%IN1v%", varj+"vals"); - tmp = tmp.replace("%IN1i%", varj+"ix"); + if(DMLScript.SPARSE_INTERMEDIATE) { + tmp = tmp.replace("%IN1v%", varj.startsWith("STMP") ? varj+".values()" : varj+"vals"); + tmp = tmp.replace("%IN1i%", varj.startsWith("STMP") ? varj+".indexes()" :varj+"ix"); + } else { + tmp = tmp.replace("%IN1v%", varj+"vals"); + tmp = tmp.replace("%IN1i%", varj+"ix"); + } tmp = tmp.replace("%IN1%", (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? ((api == GeneratorAPI.JAVA) ? varj + ".values(rix)" : varj + ".vals(0)" ) : diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java index b29d586c38a..b71aa8945fa 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java @@ -22,6 +22,7 @@ import java.util.Arrays; import org.apache.commons.lang3.StringUtils; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Opcodes; import org.apache.sysds.hops.codegen.template.TemplateUtils; import org.apache.sysds.common.Types.DataType; @@ -157,61 +158,118 @@ public String codegen(boolean sparse, GeneratorAPI api) { //generate children sb.append(_inputs.get(0).codegen(sparse, api)); sb.append(_inputs.get(1).codegen(sparse, api)); - - //generate binary operation (use sparse template, if data input) - boolean lsparseLhs = sparse && _inputs.get(0) instanceof CNodeData - && _inputs.get(0).getVarname().startsWith("a"); - boolean lsparseRhs = sparse && _inputs.get(1) instanceof CNodeData - && _inputs.get(1).getVarname().startsWith("a"); - boolean scalarInput = _inputs.get(0).getDataType().isScalar(); - boolean scalarVector = (_inputs.get(0).getDataType().isScalar() - && _inputs.get(1).getDataType().isMatrix()); - boolean vectorVector = _inputs.get(0).getDataType().isMatrix() - && _inputs.get(1).getDataType().isMatrix(); - String var = createVarname(); - String tmp = getLanguageTemplateClass(this, api) - .getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput, vectorVector); - tmp = tmp.replace("%TMP%", var); - - //replace input references and start indexes - for( int j=0; j<2; j++ ) { - String varj = _inputs.get(j).getVarname(api); - - //replace sparse and dense inputs - tmp = tmp.replace("%IN"+(j+1)+"v%", varj+"vals"); - tmp = tmp.replace("%IN"+(j+1)+"i%", varj+"ix"); - tmp = tmp.replace("%IN"+(j+1)+"%", - varj.startsWith("a") ? (api == GeneratorAPI.JAVA ? varj : + if(DMLScript.SPARSE_INTERMEDIATE) { + //generate binary operation (use sparse template, if data input) + boolean lsparseLhs = sparse ? _inputs.get(0) instanceof CNodeData + && _inputs.get(0).getVarname().startsWith("a") || + _inputs.get(0).getVarname().startsWith("STMP") : false; + boolean lsparseRhs = sparse ? _inputs.get(1) instanceof CNodeData + && _inputs.get(1).getVarname().startsWith("a") || + _inputs.get(1).getVarname().startsWith("STMP") : false; + boolean scalarInput = _inputs.get(0).getDataType().isScalar(); + boolean scalarVector = (_inputs.get(0).getDataType().isScalar() + && _inputs.get(1).getDataType().isMatrix()); + boolean vectorVector = _inputs.get(0).getDataType().isMatrix() + && _inputs.get(1).getDataType().isMatrix(); + String var = createVarname(sparse && getOutputType(scalarVector, lsparseLhs, lsparseRhs)); + String tmp = getLanguageTemplateClass(this, api) + .getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput, vectorVector); + + tmp = tmp.replace("%TMP%", var); + + //replace input references and start indexes + for( int j=0; j<2; j++ ) { + String varj = _inputs.get(j).getVarname(api); + //replace sparse and dense inputs + tmp = tmp.replace("%IN"+(j+1)+"v%", varj.startsWith("STMP") ? varj+".values()" : varj+"vals"); + tmp = tmp.replace("%IN"+(j+1)+"i%", varj.startsWith("STMP") ? varj+".indexes()" : varj+"ix"); + tmp = tmp.replace("%IN"+(j+1)+"%", + varj.startsWith("a") ? (api == GeneratorAPI.JAVA ? varj : (_inputs.get(j).getDataType() == DataType.MATRIX ? varj + ".vals(0)" : varj)) : - varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj + ".values(rix)" : - (_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) : + varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj + ".values(rix)" : + (_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) : _inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj); - - //replace start position of main input - tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData - && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : + + tmp = tmp.replace("%SLEN"+(j+1)+"%", varj.startsWith("STMP") ? varj+".size()" : varj.startsWith("a") ? "alen" : "blen"); + + //replace start position of main input + tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData + && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : ((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise() && TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ? - varj + ".pos(rix)" : "0" : "0"); - } - //replace length information (e.g., after matrix mult) - if( _type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector) ) { - for( int j=0; j<2; j++ ) - tmp = tmp.replace("%LEN"+(j+1)+"%", _inputs.get(j).getVectorLength(api)); - } - else { //general case - CNode mInput = getIntermediateInputVector(); - if( mInput != null ) - tmp = tmp.replace("%LEN%", mInput.getVectorLength(api)); + varj + ".pos(rix)" : "0" : "0"); + } + //replace length information (e.g., after matrix mult) + if( _type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector)) { + for( int j=0; j<2; j++ ) + tmp = tmp.replace("%LEN"+(j+1)+"%", _inputs.get(j).getVectorLength(api)); + } + else { //general case + CNode mInput = getIntermediateInputVector(); + if( mInput != null ) + tmp = tmp.replace("%LEN%", mInput.getVectorLength(api)); + } + + sb.append(tmp); + + //mark as generated + _generated = true; + + return sb.toString(); + } else { + boolean lsparseLhs = + sparse && _inputs.get(0) instanceof CNodeData && _inputs.get(0).getVarname().startsWith("a"); + boolean lsparseRhs = + sparse && _inputs.get(1) instanceof CNodeData && _inputs.get(1).getVarname().startsWith("a"); + boolean scalarInput = _inputs.get(0).getDataType().isScalar(); + boolean scalarVector = (_inputs.get(0).getDataType().isScalar() && _inputs.get(1).getDataType().isMatrix()); + boolean vectorVector = _inputs.get(0).getDataType().isMatrix() && _inputs.get(1).getDataType().isMatrix(); + String var = createVarname(); + String tmp = getLanguageTemplateClass(this, api).getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, + scalarInput, vectorVector); + + tmp = tmp.replace("%TMP%", var); + + //replace input references and start indexes + for(int j = 0; j < 2; j++) { + String varj = _inputs.get(j).getVarname(api); + + //replace sparse and dense inputs + tmp = tmp.replace("%IN" + (j + 1) + "v%", varj + "vals"); + tmp = tmp.replace("%IN" + (j + 1) + "i%", varj + "ix"); + tmp = tmp.replace("%IN" + (j + 1) + "%", varj.startsWith("a") ? ( + api == GeneratorAPI.JAVA ? varj : (_inputs.get(j).getDataType() == DataType.MATRIX ? varj + + ".vals(0)" : varj)) : varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj + + ".values(rix)" : (_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) : + _inputs.get(j).getDataType() == DataType.MATRIX ? ( + api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj); + + //replace start position of main input + tmp = tmp.replace("%POS" + (j + 1) + "%", (_inputs.get(j) instanceof CNodeData && + _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj + "i" : ( + (TemplateUtils.isMatrix(_inputs.get(j)) || + (_type.isElementwise() && TemplateUtils.isColVector(_inputs.get(j)))) && + _type != BinType.VECT_MATRIXMULT) ? varj + ".pos(rix)" : "0" : "0"); + } + //replace length information (e.g., after matrix mult) + if(_type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector)) { + for(int j = 0; j < 2; j++) + tmp = tmp.replace("%LEN" + (j + 1) + "%", _inputs.get(j).getVectorLength(api)); + } + else { //general case + CNode mInput = getIntermediateInputVector(); + if(mInput != null) + tmp = tmp.replace("%LEN%", mInput.getVectorLength(api)); + } + + sb.append(tmp); + + //mark as generated + _generated = true; + + return sb.toString(); } - - sb.append(tmp); - - //mark as generated - _generated = true; - - return sb.toString(); } private CNode getIntermediateInputVector() { @@ -219,7 +277,39 @@ private CNode getIntermediateInputVector() { if( getInput().get(i).getDataType().isMatrix() ) return getInput().get(i); return null; - } + } + + public boolean getOutputType(boolean scalarVector, boolean lsparseLhs, boolean lsparseRhs) { + switch(_type) { + case VECT_POW_SCALAR: return !scalarVector && lsparseLhs; + case VECT_MULT_SCALAR: + case VECT_DIV_SCALAR: + case VECT_XOR_SCALAR: + case VECT_MIN_SCALAR: + case VECT_MAX_SCALAR: + case VECT_EQUAL_SCALAR: + case VECT_NOTEQUAL_SCALAR: + case VECT_LESS_SCALAR: + case VECT_LESSEQUAL_SCALAR: + case VECT_GREATER_SCALAR: + case VECT_GREATEREQUAL_SCALAR: + case VECT_BITWAND_SCALAR: return lsparseLhs || lsparseRhs; + case VECT_MULT: + case VECT_DIV: + case VECT_MINUS: + case VECT_PLUS: + case VECT_XOR: + case VECT_BITWAND: + case VECT_BIASADD: + case VECT_BIASMULT: + case VECT_MIN: + case VECT_MAX: + case VECT_NOTEQUAL: + case VECT_LESS: + case VECT_GREATER: return lsparseLhs && lsparseRhs; + default: return false; + } + } @Override public String toString() { diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeNary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeNary.java index dcf18ec6569..35c351546da 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeNary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeNary.java @@ -60,7 +60,10 @@ public String getTemplate(boolean sparseGen, long len, ArrayList inputs, sb.append( sparseInput ? " LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, " +varj+"ix, "+pos+", "+off+", "+input._cols+");\n" : - " LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj) + varj.startsWith("STMP") ? + " LibSpoofPrimitives.vectWrite("+varj+".values(), %TMP%, " + +varj+".indexes(), "+pos+", "+off+", "+varj+".size());\n" : + " LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj) +", %TMP%, "+pos+", "+off+", "+input._cols+");\n"); off += input._cols; } diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java index 77dec97cbe1..c0d06b4bcbc 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java @@ -37,6 +37,7 @@ public class CNodeRow extends CNodeTpl + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" + "import org.apache.sysds.runtime.codegen.SpoofRowwise;\n" + "import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;\n" + + "import org.apache.sysds.runtime.data.SparseRowVector;\n" + "import org.apache.commons.math3.util.FastMath;\n" + "\n" + "public final class %TMP% extends SpoofRowwise { \n" @@ -162,7 +163,8 @@ private String getOutputStatement(String varName) { case NO_AGG_B1: case NO_AGG_CONST: if(api == GeneratorAPI.JAVA) - return TEMPLATE_NOAGG_OUT.replace("%IN%", varName).replace("%LEN%", _output.getVarname()+".length"); + return TEMPLATE_NOAGG_OUT.replace("%IN%", varName.startsWith("STMP")?varName+".values(), "+varName+".indexes()":varName).replace("%LEN%", + varName.startsWith("STMP") ? varName+".size()" : _output.getVarname()+".length"); else return TEMPLATE_NOAGG_CONST_OUT_CUDA.replace("%IN%", varName + ".vals(0)").replaceAll("%LEN%", _output.getVarname()+".length"); case FULL_AGG: diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeTernary.java index 5e811092836..c6ff9802b16 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeTernary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeTernary.java @@ -82,10 +82,13 @@ public String codegen(boolean sparse, GeneratorAPI api) { String varj = _inputs.get(j-1).getVarname(); //replace sparse and dense inputs tmp = tmp.replace("%IN"+j+"v%", - varj+(varj.startsWith("a")?"vals":"") ); + varj+(varj.startsWith("a")?"vals" : varj.startsWith("STMP") ? ".values()" :"") ); tmp = tmp.replace("%IN"+j+"i%", - varj+(varj.startsWith("a")?"ix":"") ); + varj+(varj.startsWith("a")?"ix": varj.startsWith("STMP") ? ".indexes()" :"") ); tmp = tmp.replace("%IN"+j+"%", varj ); + tmp = tmp.replace("%POS%", varj.startsWith("a") ? varj+"i" : varj.startsWith("STMP") ? "0" : ""); + tmp = tmp.replace("%LEN%", + varj.startsWith("a") ? "alen" : varj.startsWith("STMP") ? varj+".size()" : ""); } sb.append(tmp); diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java index fe67995b6b5..9088d7c6e93 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java @@ -23,6 +23,7 @@ import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.runtime.util.UtilFunctions; @@ -111,10 +112,12 @@ public String codegen(boolean sparse, GeneratorAPI api) { sb.append(_inputs.get(0).codegen(sparse, api)); //generate unary operation - boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData - && _inputs.get(0).getVarname().startsWith("a") - && !_inputs.get(0).isLiteral()); - String var = createVarname(); + boolean lsparse = sparse && + ((_inputs.get(0) instanceof CNodeData + && _inputs.get(0).getVarname().startsWith("a") + && !_inputs.get(0).isLiteral()) + || (_inputs.get(0).getVarname().startsWith("STMP"))); + String var = createVarname(DMLScript.SPARSE_INTERMEDIATE && lsparse && getOutputType()); String tmp = getLanguageTemplateClass(this, api).getTemplate(_type, lsparse); tmp = tmp.replaceAll("%TMP%", var); @@ -130,6 +133,24 @@ public String codegen(boolean sparse, GeneratorAPI api) { return sb.toString(); } + + public boolean getOutputType() { + switch(_type) { + case VECT_SQRT: + case VECT_ABS: + case VECT_ROUND: + case VECT_CEIL: + case VECT_FLOOR: + case VECT_SIN: + case VECT_TAN: + case VECT_ASIN: + case VECT_ATAN: + case VECT_SINH: + case VECT_TANH: + case VECT_SIGN: return true; + default: return false; + } + } @Override public String toString() { diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java index 40496249e52..8bce035604c 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java @@ -19,6 +19,7 @@ package org.apache.sysds.hops.codegen.cplan.java; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType; import org.apache.sysds.hops.codegen.cplan.CodeTemplate; @@ -68,13 +69,22 @@ public String getTemplate(BinType type, boolean sparseLhs, boolean sparseRhs, } //vector-scalar operations + case VECT_POW_SCALAR: { + String vectName = type.getVectorPrimitiveName(); + if( scalarVector ) + return sparseRhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS2%, %LEN%);\n"; + else if(DMLScript.SPARSE_INTERMEDIATE) { + return sparseLhs ? " SparseRowVector %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%LEN%, %IN1v%, %IN2%, %IN1i%, %POS1%, %SLEN1%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n"; + } else { + return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n"; + } + } case VECT_MULT_SCALAR: case VECT_DIV_SCALAR: - case VECT_MINUS_SCALAR: - case VECT_PLUS_SCALAR: - case VECT_POW_SCALAR: case VECT_XOR_SCALAR: - case VECT_BITWAND_SCALAR: case VECT_MIN_SCALAR: case VECT_MAX_SCALAR: case VECT_EQUAL_SCALAR: @@ -82,7 +92,22 @@ public String getTemplate(BinType type, boolean sparseLhs, boolean sparseRhs, case VECT_LESS_SCALAR: case VECT_LESSEQUAL_SCALAR: case VECT_GREATER_SCALAR: - case VECT_GREATEREQUAL_SCALAR: { + case VECT_GREATEREQUAL_SCALAR: + case VECT_BITWAND_SCALAR: { + String vectName = type.getVectorPrimitiveName(); + if(scalarVector) { + if(sparseRhs) + return DMLScript.SPARSE_INTERMEDIATE ? " SparseRowVector %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%LEN%, %IN1%, %IN2v%, %IN2i%, %POS2%, %SLEN1%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n"; + } else { + if(sparseLhs) + return DMLScript.SPARSE_INTERMEDIATE ? " SparseRowVector %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%LEN%, %IN1v%, %IN2%, %IN1i%, %POS1%, %SLEN1%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n"; + } + return " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n"; + } + case VECT_MINUS_SCALAR: + case VECT_PLUS_SCALAR: { String vectName = type.getVectorPrimitiveName(); if( scalarVector ) return sparseRhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : @@ -115,20 +140,34 @@ else if( !vectorVector ) case VECT_BIASMULT: case VECT_MIN: case VECT_MAX: - case VECT_EQUAL: case VECT_NOTEQUAL: case VECT_LESS: + case VECT_GREATER:{ + String vectName = type.getVectorPrimitiveName(); + if(DMLScript.SPARSE_INTERMEDIATE && sparseLhs && sparseRhs) { + return " SparseRowVector %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%LEN%, %IN1v%, %IN2v%, %IN1i%, %IN2i%, %POS1%, %POS2%, %SLEN1%, %SLEN2%);\n"; + } else { + return sparseLhs ? + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" : + sparseRhs ? + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; + } + } + case VECT_EQUAL: case VECT_LESSEQUAL: - case VECT_GREATER: case VECT_GREATEREQUAL: { String vectName = type.getVectorPrimitiveName(); - return sparseLhs ? + if(DMLScript.SPARSE_INTERMEDIATE && sparseLhs && sparseRhs) { + return " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%LEN%, %IN1v%, %IN2v%, %IN1i%, %IN2i%, %POS1%, %POS2%, %SLEN1%, %SLEN2%);\n"; + } else { + return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" : sparseRhs ? - " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : - " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; + } } - //scalar-scalar operations case MULT: return " double %TMP% = %IN1% * %IN2%;\n"; diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java index a86d51cca86..64d282bbb8d 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java @@ -51,7 +51,7 @@ public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) { case LOOKUP_RC1: return sparse ? - " double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" : + " double %TMP% = getValue(%IN1v%, %IN1i%, %POS%, %LEN%, %IN3%-1);\n" : " double %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n"; case LOOKUP_RVECT1: diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java index d8a1085df58..8e5d28fb4a0 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java @@ -20,6 +20,7 @@ package org.apache.sysds.hops.codegen.cplan.java; import org.apache.commons.lang3.StringUtils; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysds.hops.codegen.cplan.CodeTemplate; @@ -38,25 +39,32 @@ public String getTemplate(UnaryType type, boolean sparse) { return sparse ? " double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n": " double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n"; } - case VECT_EXP: - case VECT_POW2: - case VECT_MULT2: + case VECT_SQRT: - case VECT_LOG: case VECT_ABS: case VECT_ROUND: case VECT_CEIL: case VECT_FLOOR: - case VECT_SIGN: case VECT_SIN: - case VECT_COS: case VECT_TAN: case VECT_ASIN: - case VECT_ACOS: case VECT_ATAN: case VECT_SINH: - case VECT_COSH: case VECT_TANH: + case VECT_SIGN:{ + String vectName = type.getVectorPrimitiveName(); + return sparse ? DMLScript.SPARSE_INTERMEDIATE ? + " SparseRowVector %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(len, %IN1v%, %IN1i%, %POS1%, alen);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n"; + } + case VECT_EXP: + case VECT_POW2: + case VECT_MULT2: + case VECT_LOG: + case VECT_COS: + case VECT_ACOS: + case VECT_COSH: case VECT_CUMSUM: case VECT_CUMMIN: case VECT_CUMMAX: