-
Notifications
You must be signed in to change notification settings - Fork 541
[SYSTEMDS-3730] Multi-threaded rev reorg operation #2290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c78e751
17fe58a
f930455
91852f2
faab068
4002c22
5b586ae
9eb29fc
5fe6be4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -128,7 +128,10 @@ public static MatrixBlock reorg( MatrixBlock in, MatrixBlock out, ReorgOperator | |
| else | ||
| return transpose(in, out); | ||
| case REV: | ||
| return rev(in, out); | ||
| if (op.getNumThreads() > 1) | ||
| return rev(in, out, op.getNumThreads()); | ||
| else | ||
| return rev(in, out); | ||
| case ROLL: | ||
| RollIndex rix = (RollIndex) op.fn; | ||
| return roll(in, out, rix.getShift()); | ||
|
|
@@ -389,10 +392,72 @@ public static MatrixBlock rev( MatrixBlock in, MatrixBlock out ) { | |
| return out; | ||
| } | ||
|
|
||
| public static MatrixBlock rev(MatrixBlock in, MatrixBlock out, int k) { | ||
| if (k <= 1 || in.isEmptyBlock(false) ) { | ||
| return rev(in, out); // fallback to single-threaded | ||
|
|
||
| } | ||
| final int numRows = in.getNumRows(); | ||
| final int numCols = in.getNumColumns(); | ||
| final boolean sparse = in.isInSparseFormat(); | ||
|
|
||
| // Prepare output block | ||
| out.reset(numRows, numCols, sparse); | ||
|
|
||
| // Before starting threads, ensure the output sparse block is allocated! | ||
| if (sparse) { | ||
| out.allocateSparseRowsBlock(false); | ||
| } | ||
|
|
||
| // Set up thread pool | ||
| ExecutorService pool = CommonThreadPool.get(k); | ||
| try { | ||
| int blklen = (int) Math.ceil((double) numRows / k); | ||
| List<Future<?>> tasks = new ArrayList<>(); | ||
|
|
||
| for (int i = 0; i < k; i++) { | ||
| final int startRow = i * blklen; | ||
| final int endRow = Math.min((i + 1) * blklen, numRows); | ||
|
|
||
| tasks.add(pool.submit(() -> { | ||
| if (!sparse) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. create a static method for this kernel which is called from both the single-threaded implementation as well as the multi-threaded implementation |
||
| // Dense case | ||
| double[] inVals = in.getDenseBlockValues(); | ||
| double[] outVals = out.getDenseBlockValues(); | ||
| for (int r = startRow; r < endRow; r++) { | ||
| int revRow = numRows - r - 1; | ||
| System.arraycopy(inVals, revRow * numCols, outVals, r * numCols, numCols); | ||
| } | ||
| } else { | ||
| // Sparse case | ||
| SparseBlock inBlk = in.getSparseBlock(); | ||
| SparseBlock outBlk = out.getSparseBlock(); | ||
| for (int r = startRow; r < endRow; r++) { | ||
| int revRow = numRows - r - 1; | ||
| if (!inBlk.isEmpty(revRow)) { | ||
| outBlk.set(r, inBlk.get(revRow), true); | ||
| } | ||
| } | ||
| } | ||
| })); | ||
| } | ||
|
|
||
| // Wait for all threads | ||
| for (Future<?> task : tasks) { | ||
| task.get(); | ||
| } | ||
| } catch (Exception ex) { | ||
| throw new DMLRuntimeException(ex); | ||
| } finally { | ||
| pool.shutdown(); | ||
| } | ||
| return out; | ||
| } | ||
|
|
||
| public static void rev( IndexedMatrixValue in, long rlen, int blen, ArrayList<IndexedMatrixValue> out ) { | ||
| //input block reverse | ||
| MatrixIndexes inix = in.getIndexes(); | ||
| MatrixBlock inblk = (MatrixBlock) in.getValue(); | ||
| MatrixBlock inblk = (MatrixBlock) in.getValue(); | ||
| MatrixBlock tmpblk = rev(inblk, new MatrixBlock(inblk.getNumRows(), inblk.getNumColumns(), inblk.isInSparseFormat())); | ||
|
|
||
| //split and expand block if necessary (at most 2 blocks) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| import java.util.HashMap; | ||
|
|
||
| import org.apache.sysds.common.Opcodes; | ||
| import org.apache.sysds.utils.stats.InfrastructureAnalyzer; | ||
| import org.junit.Assert; | ||
| import org.junit.Test; | ||
| import org.apache.sysds.api.DMLScript; | ||
|
|
@@ -44,10 +45,17 @@ public class FullReverseTest extends AutomatedTestBase | |
| private static final String TEST_CLASS_DIR = TEST_DIR + FullReverseTest.class.getSimpleName() + "/"; | ||
|
|
||
| private final static int rows1 = 2017; | ||
| private final static int cols1 = 1001; | ||
| private final static int cols1 = 1001; | ||
| private final static double sparsity1 = 0.7; | ||
| private final static double sparsity2 = 0.1; | ||
|
|
||
| // Multi-threading test parameters | ||
| private final static int rows_mt = 5018; // Larger for multi-threading benefits | ||
| private final static int cols_mt = 1001; // Larger for multi-threading benefits | ||
| private final static int[] threadCounts = {1, 2, 4, 8}; | ||
| // Set global parallelism for SystemDS to enable multi-threading | ||
| private final static int oldPar = InfrastructureAnalyzer.getLocalParallelism(); | ||
|
|
||
| @Override | ||
| public void setUp() { | ||
| TestUtils.clearAssertionInformation(); | ||
|
|
@@ -64,7 +72,22 @@ public void testReverseVectorDenseCP() { | |
| public void testReverseVectorSparseCP() { | ||
| runReverseTest(TEST_NAME1, false, true, ExecType.CP); | ||
| } | ||
|
|
||
|
|
||
| @Test | ||
| public void testReverseVectorDenseCPMultiThread() { | ||
| runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.CP); | ||
| } | ||
|
|
||
| @Test | ||
| public void testReverseVectorSparseCPMultiThread() { | ||
| runReverseTestMultiThread(TEST_NAME1, false, true, ExecType.CP); | ||
| } | ||
|
|
||
| @Test | ||
| public void testReverseVectorDenseSPMultiThread() { | ||
| runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.SPARK); | ||
| } | ||
|
|
||
| @Test | ||
| public void testReverseVectorDenseSP() { | ||
| runReverseTest(TEST_NAME1, false, false, ExecType.SPARK); | ||
|
|
@@ -165,6 +188,78 @@ else if ( instType == ExecType.SPARK ) | |
| DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; | ||
| } | ||
| } | ||
|
|
||
|
|
||
| private void runReverseTestMultiThread(String testname, boolean matrix, boolean sparse, ExecType instType) | ||
| { | ||
| // Compare single-thread vs multi-thread results | ||
| // HashMap<CellIndex, Double> stResult = runReverseWithThreads(testname, matrix, sparse, instType, 1); | ||
| HashMap<CellIndex, Double> mtResult = runReverseWithThreads(testname, matrix, sparse, instType, 8); | ||
|
|
||
| // Compare results to ensure consistency | ||
| // TestUtils.compareMatrices(stResult, mtResult, 0, "ST-Result", "MT-Result"); | ||
| } | ||
|
|
||
| private HashMap<CellIndex, Double> runReverseWithThreads(String testname, boolean matrix, boolean sparse, ExecType instType, int numThreads) | ||
| { | ||
| //rtplatform for MR | ||
| ExecMode platformOld = rtplatform; | ||
| switch( instType ){ | ||
| case SPARK: rtplatform = ExecMode.SPARK; break; | ||
| default: rtplatform = ExecMode.HYBRID; break; | ||
| } | ||
| boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; | ||
| if( rtplatform == ExecMode.SPARK ) | ||
| DMLScript.USE_LOCAL_SPARK_CONFIG = true; | ||
|
|
||
| String TEST_NAME = testname; | ||
|
|
||
| System.out.println("I am trying to run multi-thread"); | ||
|
|
||
| try | ||
| { | ||
| System.setProperty("sysds.parallel.threads", String.valueOf(numThreads)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove (I don't think we use this property internally) |
||
|
|
||
| // int cols = matrix ? cols_mt : 1; | ||
| double sparsity = sparse ? sparsity2 : sparsity1; | ||
| getAndLoadTestConfiguration(TEST_NAME); | ||
|
|
||
| /* This is for running the junit test the new way, i.e., construct the arguments directly */ | ||
| String HOME = SCRIPT_DIR + TEST_DIR; | ||
| fullDMLScriptName = HOME + TEST_NAME + ".dml"; | ||
|
|
||
| // Add thread count to program arguments | ||
| programArgs = new String[]{"-stats","-explain","-args", input("A"), output("B") }; | ||
|
|
||
| fullRScriptName = HOME + TEST_NAME + ".R"; | ||
| rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir(); | ||
|
|
||
| //generate actual dataset | ||
| double[][] A = getRandomMatrix(rows_mt, cols_mt, -1, 1, sparsity, 7); | ||
| writeInputMatrixWithMTD("A", A, true); | ||
|
|
||
| // Run with specified thread count (this is the key part) | ||
| runTest(true, false, null, -1); | ||
|
|
||
| //read and return results | ||
| HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("B"); | ||
|
|
||
| //check generated opcode | ||
| if( instType == ExecType.CP ) | ||
| Assert.assertTrue("Missing opcode: rev", Statistics.getCPHeavyHitterOpCodes().contains(Opcodes.REV.toString())); | ||
| else if ( instType == ExecType.SPARK ) | ||
| Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+Opcodes.REV.toString(), Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+Opcodes.REV)); | ||
|
|
||
| return dmlfile; | ||
| } | ||
| catch(Exception ex) { | ||
| throw new RuntimeException(ex); | ||
| } | ||
| finally { | ||
| //reset flags | ||
| rtplatform = platformOld; | ||
| DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; | ||
| System.setProperty("sysds.parallel.threads", String.valueOf(oldPar)); | ||
| } | ||
| } | ||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would recommend to create smaller tasks (e.g., numRows/k/4) which yields better load balancing.