Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/hops/BinaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ op, getDataType(), getValueType(), et,
setLineNumbers(softmax);
setLops(softmax);
}
else if ( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED )
else if ( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED || et == ExecType.OOC )
{
Lop binary = null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysds.runtime.instructions.ooc.ResettableStream;
import org.apache.sysds.runtime.instructions.ooc.OOCStream;
import org.apache.sysds.runtime.instructions.ooc.OOCStreamable;
import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
import org.apache.sysds.runtime.instructions.spark.data.BroadcastObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
Expand Down Expand Up @@ -223,7 +225,7 @@ public enum CacheStatus {
private BroadcastObject<T> _bcHandle = null; //Broadcast handle
protected HashMap<GPUContext, GPUObject> _gpuObjects = null; //Per GPUContext object allocated on GPU
//TODO generalize for frames
private LocalTaskQueue<IndexedMatrixValue> _streamHandle = null;
private OOCStreamable<IndexedMatrixValue> _streamHandle = null;

private LineageItem _lineage = null;

Expand Down Expand Up @@ -469,34 +471,25 @@ public boolean hasBroadcastHandle() {
return _bcHandle != null && _bcHandle.hasBackReference();
}

public LocalTaskQueue<IndexedMatrixValue> getStreamHandle() {
public OOCStream<IndexedMatrixValue> getStreamHandle() {
if( !hasStreamHandle() ) {
_streamHandle = new LocalTaskQueue<>();
final SubscribableTaskQueue<IndexedMatrixValue> _mStream = new SubscribableTaskQueue<>();
_streamHandle = _mStream;
DataCharacteristics dc = getDataCharacteristics();
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
LongStream.range(0, dc.getNumBlocks())
.mapToObj(i -> UtilFunctions.createIndexedMatrixBlock(src, dc, i))
.forEach( blk -> {
try{
_streamHandle.enqueueTask(blk);
_mStream.enqueue(blk);
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
throw ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
}});
_streamHandle.closeInput();
}
else if(_streamHandle != null && _streamHandle.isProcessed()
&& _streamHandle instanceof ResettableStream)
{
try {
((ResettableStream)_streamHandle).reset();
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
_mStream.closeInput();
}

return _streamHandle;
return _streamHandle.getReadStream();
}

/**
Expand Down Expand Up @@ -539,7 +532,7 @@ public synchronized void removeGPUObject(GPUContext gCtx) {
_gpuObjects.remove(gCtx);
}

public synchronized void setStreamHandle(LocalTaskQueue<IndexedMatrixValue> q) {
public synchronized void setStreamHandle(OOCStreamable<IndexedMatrixValue> q) {
_streamHandle = q;
}

Expand Down Expand Up @@ -633,7 +626,7 @@ && getRDDHandle() == null) ) {
_requiresLocalWrite = false;
}
else if( hasStreamHandle() ) {
_data = readBlobFromStream( getStreamHandle() );
_data = readBlobFromStream( getStreamHandle().toLocalTaskQueue() );
}
else if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() ) {
if( DMLScript.STATISTICS )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ protected long writeStreamToHDFS(String fname, String ofmt, int rep, FileFormatP
MetaDataFormat iimd = (MetaDataFormat) _metaData;
FileFormat fmt = (ofmt != null ? FileFormat.safeValueOf(ofmt) : iimd.getFileFormat());
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(fmt, rep, fprop);
return writer.writeMatrixFromStream(fname, getStreamHandle(),
return writer.writeMatrixFromStream(fname, getStreamHandle().toLocalTaskQueue(),
getNumRows(), getNumColumns(), ConfigurationManager.getBlocksize());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public class LocalTaskQueue<T>
public static final int MAX_SIZE = 100000; //main memory constraint
public static final Object NO_MORE_TASKS = null; //object to signal NO_MORE_TASKS

private LinkedList<T> _data = null;
private boolean _closedInput = false;
protected LinkedList<T> _data = null;
protected boolean _closedInput = false;
private DMLRuntimeException _failure = null;
private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName());

Expand All @@ -60,21 +60,19 @@ public LocalTaskQueue()
* @param t task
* @throws InterruptedException if InterruptedException occurs
*/
public synchronized void enqueueTask( T t )
public synchronized void enqueueTask( T t )
throws InterruptedException
{
while( _data.size() + 1 > MAX_SIZE && _failure == null )
{
while(_data.size() + 1 > MAX_SIZE && _failure == null) {
LOG.warn("MAX_SIZE of task queue reached.");
wait(); //max constraint reached, wait for read
}

if ( _failure != null )
if(_failure != null)
throw _failure;

_data.addLast( t );

notify(); //notify waiting readers

_data.addLast(t);
notify();
}

/**
Expand All @@ -97,22 +95,22 @@ public synchronized T dequeueTask()

if ( _failure != null )
throw _failure;

T t = _data.removeFirst();

notify(); // notify waiting writers

return t;
}

/**
* Synchronized (logical) insert of a NO_MORE_TASKS symbol at the end of the FIFO queue in order to
* mark that no more tasks will be inserted into the queue.
*/
public synchronized void closeInput()
{
_closedInput = true;
notifyAll(); //notify all waiting readers
notifyAll();
}

public synchronized boolean isProcessed() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void processInstruction( ExecutionContext ec ) {
//setup operators and input queue
AggregateUnaryOperator aggun = (AggregateUnaryOperator) getOperator();
MatrixObject min = ec.getMatrixObject(input1);
LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle();
OOCStream<IndexedMatrixValue> q = min.getStreamHandle();
int blen = ConfigurationManager.getBlocksize();

if (aggun.isRowAggregate() || aggun.isColAggregate()) {
Expand All @@ -86,13 +86,13 @@ public void processInstruction( ExecutionContext ec ) {
OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emitThreshold);
HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // correction blocks

LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
OOCStream<IndexedMatrixValue> qOut = createWritableStream();
ec.getMatrixObject(output).setStreamHandle(qOut);

submitOOCTask(() -> {
IndexedMatrixValue tmp = null;
try {
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
while((tmp = q.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
long idx = aggun.isRowAggregate() ?
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
MatrixBlock ret = aggTracker.get(idx);
Expand Down Expand Up @@ -139,7 +139,7 @@ public void processInstruction( ExecutionContext ec ) {
new MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);

qOut.enqueueTask(tmpOut);
qOut.enqueue(tmpOut);
// drop intermediate states
aggTracker.remove(idx);
corrs.remove(idx);
Expand All @@ -159,7 +159,7 @@ public void processInstruction( ExecutionContext ec ) {
MatrixBlock ret = new MatrixBlock(1,1+extra,false);
MatrixBlock corr = new MatrixBlock(1,1+extra,false);
try {
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
while((tmp = q.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
//block aggregation
MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue())
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;

Expand All @@ -54,33 +55,46 @@ public static BinaryOOCInstruction parseInstruction(String str) {

@Override
public void processInstruction( ExecutionContext ec ) {
//TODO support all types, currently only binary matrix-scalar

if (input1.isMatrix() && input2.isMatrix())
processMatrixMatrixInstruction(ec);
else
processScalarMatrixInstruction(ec);
}

protected void processMatrixMatrixInstruction(ExecutionContext ec) {
MatrixObject m1 = ec.getMatrixObject(input1);
MatrixObject m2 = ec.getMatrixObject(input2);

OOCStream<IndexedMatrixValue> qIn1 = m1.getStreamHandle();
OOCStream<IndexedMatrixValue> qIn2 = m2.getStreamHandle();
OOCStream<IndexedMatrixValue> qOut = new SubscribableTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);

joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
tmpOut.set(tmp1.getIndexes(),
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue()));
return tmpOut;
}, IndexedMatrixValue::getIndexes);
}

protected void processScalarMatrixInstruction(ExecutionContext ec) {
//get operator and scalar
CPOperand scalar = ( input1.getDataType() == DataType.MATRIX ) ? input2 : input1;
CPOperand scalar = input1.isMatrix() ? input2 : input1;
ScalarObject constant = ec.getScalarInput(scalar);
ScalarOperator sc_op = ((ScalarOperator)_optr).setConstant(constant.getDoubleValue());

//create thread and process binary operation
MatrixObject min = ec.getMatrixObject(input1);
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
MatrixObject min = ec.getMatrixObject(input1.isMatrix() ? input1 : input2);
OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
OOCStream<IndexedMatrixValue> qOut = createWritableStream();
ec.getMatrixObject(output).setStreamHandle(qOut);

submitOOCTask(() -> {
IndexedMatrixValue tmp = null;
try {
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
tmpOut.set(tmp.getIndexes(),
tmp.getValue().scalarOperations(sc_op, new MatrixBlock()));
qOut.enqueueTask(tmpOut);
}
qOut.closeInput();
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
}, qIn, qOut);

mapOOC(qIn, qOut, tmp -> {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
tmpOut.set(tmp.getIndexes(),
tmp.getValue().scalarOperations(sc_op, new MatrixBlock()));
return tmpOut;
});
}
}
Loading
Loading