Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/main/java/org/apache/sysds/hops/ReorgOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ public boolean isGPUEnabled() {
/**
 * Returns true when _op is one of the reorg operation types listed below:
 * TRANS, SORT, or REV (REV is the case added by this diff).
 * NOTE(review): this hunk shows both the removed and the added SORT line
 * because the diff markers were lost; only one of them exists in the real file.
 */
@Override
public boolean isMultiThreadedOpType() {
return _op == ReOrgOp.TRANS
|| _op == ReOrgOp.SORT;
|| _op == ReOrgOp.SORT
|| _op == ReOrgOp.REV;
}

@Override
Expand Down Expand Up @@ -148,11 +149,22 @@ else if( getDim1()==1 && getDim2()==1 )
}
break;
}
case DIAG:
case DIAG: {
Transform transform1 = new Transform(
getInput().get(0).constructLops(),
_op, getDataType(), getValueType(), et);
setOutputDimensions(transform1);
setLineNumbers(transform1);
setLops(transform1);
break;
}
case REV: {
long numel = getDim1() * getDim2();
int k = (numel < 3000_000) ?
1 : OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
Transform transform1 = new Transform(
getInput().get(0).constructLops(),
_op, getDataType(), getValueType(), et);
_op, getDataType(), getValueType(), et, k);
setOutputDimensions(transform1);
setLineNumbers(transform1);
setLops(transform1);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/lops/Transform.java
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private String getInstructions(String input1, int numInputs, String output) {
sb.append( this.prepOutputOperand(output));

if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED)
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.SORT) ) {
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _numThreads );
if ( getExecType()==ExecType.FED ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,13 @@ public static ReorgCPInstruction parseInstruction ( String str ) {
return new ReorgCPInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str);
}
else if ( opcode.equalsIgnoreCase(Opcodes.REV.toString()) ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
InstructionUtils.checkNumFields(str, 2, 3);
in.split(parts[1]);
out.split(parts[2]);
// Safely parse the number of threads 'k' if it exists
int k = (parts.length > 3) ? Integer.parseInt(parts[3]) : 1;
// Create the instruction, passing 'k' to the operator
return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject(), k), in, out, opcode, str);
}
else if (opcode.equalsIgnoreCase(Opcodes.ROLL.toString())) {
InstructionUtils.checkNumFields(str, 3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ public static MatrixBlock reorg( MatrixBlock in, MatrixBlock out, ReorgOperator
else
return transpose(in, out);
case REV:
return rev(in, out);
if (op.getNumThreads() > 1)
return rev(in, out, op.getNumThreads());
else
return rev(in, out);
case ROLL:
RollIndex rix = (RollIndex) op.fn;
return roll(in, out, rix.getShift());
Expand Down Expand Up @@ -389,10 +392,72 @@ public static MatrixBlock rev( MatrixBlock in, MatrixBlock out ) {
return out;
}

public static MatrixBlock rev(MatrixBlock in, MatrixBlock out, int k) {
if (k <= 1 || in.isEmptyBlock(false) ) {
return rev(in, out); // fallback to single-threaded

}
final int numRows = in.getNumRows();
final int numCols = in.getNumColumns();
final boolean sparse = in.isInSparseFormat();

// Prepare output block
out.reset(numRows, numCols, sparse);

// Before starting threads, ensure the output sparse block is allocated!
if (sparse) {
out.allocateSparseRowsBlock(false);
}

// Set up thread pool
ExecutorService pool = CommonThreadPool.get(k);
try {
int blklen = (int) Math.ceil((double) numRows / k);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend to create smaller tasks (e.g., numRows/k/4) which yields better load balancing.

List<Future<?>> tasks = new ArrayList<>();

for (int i = 0; i < k; i++) {
final int startRow = i * blklen;
final int endRow = Math.min((i + 1) * blklen, numRows);

tasks.add(pool.submit(() -> {
if (!sparse) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create a static method for this kernel which is called from both the single-threaded implementation as well as the multi-threaded implementation

// Dense case
double[] inVals = in.getDenseBlockValues();
double[] outVals = out.getDenseBlockValues();
for (int r = startRow; r < endRow; r++) {
int revRow = numRows - r - 1;
System.arraycopy(inVals, revRow * numCols, outVals, r * numCols, numCols);
}
} else {
// Sparse case
SparseBlock inBlk = in.getSparseBlock();
SparseBlock outBlk = out.getSparseBlock();
for (int r = startRow; r < endRow; r++) {
int revRow = numRows - r - 1;
if (!inBlk.isEmpty(revRow)) {
outBlk.set(r, inBlk.get(revRow), true);
}
}
}
}));
}

// Wait for all threads
for (Future<?> task : tasks) {
task.get();
}
} catch (Exception ex) {
throw new DMLRuntimeException(ex);
} finally {
pool.shutdown();
}
return out;
}

public static void rev( IndexedMatrixValue in, long rlen, int blen, ArrayList<IndexedMatrixValue> out ) {
//input block reverse
MatrixIndexes inix = in.getIndexes();
MatrixBlock inblk = (MatrixBlock) in.getValue();
MatrixBlock inblk = (MatrixBlock) in.getValue();
MatrixBlock tmpblk = rev(inblk, new MatrixBlock(inblk.getNumRows(), inblk.getNumColumns(), inblk.isInSparseFormat()));

//split and expand block if necessary (at most 2 blocks)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.HashMap;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.api.DMLScript;
Expand All @@ -44,10 +45,17 @@ public class FullReverseTest extends AutomatedTestBase
private static final String TEST_CLASS_DIR = TEST_DIR + FullReverseTest.class.getSimpleName() + "/";

private final static int rows1 = 2017;
private final static int cols1 = 1001;
private final static int cols1 = 1001;
private final static double sparsity1 = 0.7;
private final static double sparsity2 = 0.1;

// Multi-threading test parameters
private final static int rows_mt = 5018; // Larger for multi-threading benefits
private final static int cols_mt = 1001; // Larger for multi-threading benefits
private final static int[] threadCounts = {1, 2, 4, 8};
// Set global parallelism for SystemDS to enable multi-threading
private final static int oldPar = InfrastructureAnalyzer.getLocalParallelism();

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
Expand All @@ -64,7 +72,22 @@ public void testReverseVectorDenseCP() {
public void testReverseVectorSparseCP() {
runReverseTest(TEST_NAME1, false, true, ExecType.CP);
}


/**
 * Runs the multi-thread reverse harness for script TEST_NAME1 with
 * matrix=false and sparse=false (dense input) on exec type CP.
 */
@Test
public void testReverseVectorDenseCPMultiThread() {
runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.CP);
}

/**
 * Runs the multi-thread reverse harness for script TEST_NAME1 with
 * matrix=false and sparse=true (sparse input) on exec type CP.
 */
@Test
public void testReverseVectorSparseCPMultiThread() {
runReverseTestMultiThread(TEST_NAME1, false, true, ExecType.CP);
}

/**
 * Runs the multi-thread reverse harness for script TEST_NAME1 with
 * matrix=false and sparse=false (dense input) on exec type SPARK.
 */
@Test
public void testReverseVectorDenseSPMultiThread() {
runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.SPARK);
}

@Test
public void testReverseVectorDenseSP() {
runReverseTest(TEST_NAME1, false, false, ExecType.SPARK);
Expand Down Expand Up @@ -165,6 +188,78 @@ else if ( instType == ExecType.SPARK )
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}


/**
 * Runs the reverse script through the multi-thread harness with a thread
 * count of 8 and verifies that a result was actually produced.
 *
 * @param testname name of the test script
 * @param matrix   forwarded to the harness
 * @param sparse   true for the sparse input configuration
 * @param instType execution type under test
 */
private void runReverseTestMultiThread(String testname, boolean matrix, boolean sparse, ExecType instType)
{
	HashMap<CellIndex, Double> mtResult = runReverseWithThreads(testname, matrix, sparse, instType, 8);

	// The harness generates its input with sparsity 0.7 or 0.1, so an empty
	// output means the reverse lost the data (e.g. missing output meta data).
	Assert.assertFalse("Reverse produced an empty result", mtResult.isEmpty());

	// TODO also compare against a single-threaded run (thread count 1) via
	// TestUtils.compareMatrices once the thread count is verifiably applied.
}

private HashMap<CellIndex, Double> runReverseWithThreads(String testname, boolean matrix, boolean sparse, ExecType instType, int numThreads)
{
//rtplatform for MR
ExecMode platformOld = rtplatform;
switch( instType ){
case SPARK: rtplatform = ExecMode.SPARK; break;
default: rtplatform = ExecMode.HYBRID; break;
}
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
if( rtplatform == ExecMode.SPARK )
DMLScript.USE_LOCAL_SPARK_CONFIG = true;

String TEST_NAME = testname;

System.out.println("I am trying to run multi-thread");

try
{
System.setProperty("sysds.parallel.threads", String.valueOf(numThreads));

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove (I don't think we use this property internally)


// int cols = matrix ? cols_mt : 1;
double sparsity = sparse ? sparsity2 : sparsity1;
getAndLoadTestConfiguration(TEST_NAME);

/* This is for running the junit test the new way, i.e., construct the arguments directly */
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";

// Add thread count to program arguments
programArgs = new String[]{"-stats","-explain","-args", input("A"), output("B") };

fullRScriptName = HOME + TEST_NAME + ".R";
rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();

//generate actual dataset
double[][] A = getRandomMatrix(rows_mt, cols_mt, -1, 1, sparsity, 7);
writeInputMatrixWithMTD("A", A, true);

// Run with specified thread count (this is the key part)
runTest(true, false, null, -1);

//read and return results
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("B");

//check generated opcode
if( instType == ExecType.CP )
Assert.assertTrue("Missing opcode: rev", Statistics.getCPHeavyHitterOpCodes().contains(Opcodes.REV.toString()));
else if ( instType == ExecType.SPARK )
Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+Opcodes.REV.toString(), Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+Opcodes.REV));

return dmlfile;
}
catch(Exception ex) {
throw new RuntimeException(ex);
}
finally {
//reset flags
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
System.setProperty("sysds.parallel.threads", String.valueOf(oldPar));
}
}

}
Loading