Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
import org.apache.sysds.runtime.util.LocalFileUtils;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.utils.Explain;
import org.apache.sysds.utils.NativeHelper;
Expand Down Expand Up @@ -519,7 +519,7 @@ public static void cleanupHadoopExecution( DMLConfig config )
FederatedData.clearFederatedWorkers();

//0) shutdown prefetch/broadcast thread pool if necessary
SparkUtils.shutdownPool();
CommonThreadPool.shutdownAsyncRDDPool();

//1) cleanup scratch space (everything for current uuid)
//(required otherwise export to hdfs would skip assumed unnecessary writes if same name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class BroadcastCPInstruction extends UnaryCPInstruction {
private BroadcastCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
Expand All @@ -44,8 +44,8 @@ public static BroadcastCPInstruction parseInstruction (String str) {
public void processInstruction(ExecutionContext ec) {
ec.setVariable(output.getName(), ec.getMatrixObject(input1));

if (SparkUtils.triggerRDDPool == null)
SparkUtils.triggerRDDPool = Executors.newCachedThreadPool();
SparkUtils.triggerRDDPool.submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
if (CommonThreadPool.triggerRDDPool == null)
CommonThreadPool.triggerRDDPool = Executors.newCachedThreadPool();
CommonThreadPool.triggerRDDPool.submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class PrefetchCPInstruction extends UnaryCPInstruction {
private PrefetchCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
Expand All @@ -49,8 +49,8 @@ public void processInstruction(ExecutionContext ec) {
// If the next instruction which takes this output as an input comes before
// the prefetch thread triggers, that instruction will start the operations.
// In that case this Prefetch instruction will act like a NOOP.
if (SparkUtils.triggerRDDPool == null)
SparkUtils.triggerRDDPool = Executors.newCachedThreadPool();
SparkUtils.triggerRDDPool.submit(new TriggerRDDOperationsTask(ec.getMatrixObject(output)));
if (CommonThreadPool.triggerRDDPool == null)
CommonThreadPool.triggerRDDPool = Executors.newCachedThreadPool();
CommonThreadPool.triggerRDDPool.submit(new TriggerRDDOperationsTask(ec.getMatrixObject(output)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,11 @@

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

public class SparkUtils
{
public static ExecutorService triggerRDDPool = null;

//internal configuration
public static final StorageLevel DEFAULT_TMP = Checkpoint.DEFAULT_STORAGE_LEVEL;

Expand Down Expand Up @@ -296,14 +293,6 @@ public static void postprocessUltraSparseOutput(MatrixObject mo, DataCharacteris
mo.acquireReadAndRelease();
}

public static void shutdownPool() {
if (triggerRDDPool != null) {
//shutdown prefetch/broadcast thread pool
triggerRDDPool.shutdown();
triggerRDDPool = null;
}
}

private static class CheckSparsityFunction implements VoidFunction<Tuple2<MatrixIndexes,MatrixBlock>>
{
private static final long serialVersionUID = 4150132775681848807L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public class CommonThreadPool implements ExecutorService
private static final int size = InfrastructureAnalyzer.getLocalParallelism();
private static final ExecutorService shared = ForkJoinPool.commonPool();
private final ExecutorService _pool;
public static ExecutorService triggerRDDPool = null;

public CommonThreadPool(ExecutorService pool) {
_pool = pool;
Expand Down Expand Up @@ -78,6 +79,14 @@ public static void shutdownShared() {
shared.shutdownNow();
}

public static void shutdownAsyncRDDPool() {
if (triggerRDDPool != null) {
//shutdown prefetch/broadcast thread pool
triggerRDDPool.shutdown();
triggerRDDPool = null;
}
}

@Override
public void shutdown() {
if( _pool != shared )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ public void runTest(String testname) {
HashMap<MatrixValue.CellIndex, Double> R_pf = readDMLScalarFromOutputDir("R");

//compare matrices
TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "withPrefetch");
Boolean matchVal = TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "withPrefetch");
if (!matchVal)
System.out.println("Value w/o Prefetch "+R+" w/ Prefetch "+R_pf);
//assert Prefetch instructions and number of success.
long expected_numPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
long expected_successPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
Expand Down