From e55a6b75bf799c06290cdb3f9a452c113cce18ce Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 27 Oct 2025 18:12:33 +0100 Subject: [PATCH 1/4] Common OOC parallel stream processing utilities and non-blocking queue callback handling Remove unnecessary TODO Merge OOCEvictionManager support into CachingStream Update CachingStream.java Create subscribable abstractions for OOC streams without affecting LocalTaskQueue Remove unnecessary implement --- .../java/org/apache/sysds/hops/BinaryOp.java | 2 +- .../controlprogram/caching/CacheableData.java | 33 ++-- .../controlprogram/caching/MatrixObject.java | 2 +- .../controlprogram/parfor/LocalTaskQueue.java | 24 ++- .../ooc/AggregateUnaryOOCInstruction.java | 10 +- .../ooc/BinaryOOCInstruction.java | 60 ++++--- .../instructions/ooc/CachingStream.java | 150 +++++++++++++++++ .../ooc/CentralMomentOOCInstruction.java | 86 ++-------- .../ooc/CtableOOCInstruction.java | 10 +- .../ooc/MatrixVectorBinaryOOCInstruction.java | 10 +- .../instructions/ooc/OOCInstruction.java | 151 +++++++++++++++++- .../runtime/instructions/ooc/OOCStream.java | 37 +++++ .../instructions/ooc/OOCStreamable.java | 30 ++++ .../instructions/ooc/PlaybackStream.java | 85 ++++++++++ .../ooc/ReblockOOCInstruction.java | 6 +- .../instructions/ooc/ResettableStream.java | 116 -------------- .../ooc/SubscribableTaskQueue.java | 100 ++++++++++++ .../instructions/ooc/TSMMOOCInstruction.java | 4 +- .../instructions/ooc/TeeOOCInstruction.java | 4 +- .../ooc/TransposeOOCInstruction.java | 29 ++-- .../instructions/ooc/UnaryOOCInstruction.java | 28 +--- .../runtime/matrix/operators/CMOperator.java | 7 + .../apache/sysds/runtime/util/OOCJoin.java | 69 ++++++++ .../functions/ooc/BinaryMatrixMatrixTest.java | 136 ++++++++++++++++ .../functions/ooc/BinaryMatrixScalarTest.java | 122 ++++++++++++++ .../functions/ooc/BinaryMatrixMatrix.dml | 29 ++++ .../functions/ooc/BinaryMatrixScalar.dml | 28 ++++ 27 files changed, 1057 insertions(+), 311 deletions(-) create mode 
100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java delete mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java create mode 100644 src/main/java/org/apache/sysds/runtime/util/OOCJoin.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java create mode 100644 src/test/scripts/functions/ooc/BinaryMatrixMatrix.dml create mode 100644 src/test/scripts/functions/ooc/BinaryMatrixScalar.dml diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index a3ddb45ea6d..2b803a053c1 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -478,7 +478,7 @@ op, getDataType(), getValueType(), et, setLineNumbers(softmax); setLops(softmax); } - else if ( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED ) + else if ( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED || et == ExecType.OOC ) { Lop binary = null; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index 7457e56ba5f..34a8aa18631 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -49,7 +49,9 @@ import 
org.apache.sysds.runtime.instructions.fed.InitFEDInstruction; import org.apache.sysds.runtime.instructions.gpu.context.GPUContext; import org.apache.sysds.runtime.instructions.gpu.context.GPUObject; -import org.apache.sysds.runtime.instructions.ooc.ResettableStream; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.ooc.OOCStreamable; +import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue; import org.apache.sysds.runtime.instructions.spark.data.BroadcastObject; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; @@ -223,7 +225,7 @@ public enum CacheStatus { private BroadcastObject _bcHandle = null; //Broadcast handle protected HashMap _gpuObjects = null; //Per GPUContext object allocated on GPU //TODO generalize for frames - private LocalTaskQueue _streamHandle = null; + private OOCStreamable _streamHandle = null; private LineageItem _lineage = null; @@ -469,34 +471,25 @@ public boolean hasBroadcastHandle() { return _bcHandle != null && _bcHandle.hasBackReference(); } - public LocalTaskQueue getStreamHandle() { + public OOCStream getStreamHandle() { if( !hasStreamHandle() ) { - _streamHandle = new LocalTaskQueue<>(); + final SubscribableTaskQueue _mStream = new SubscribableTaskQueue<>(); + _streamHandle = _mStream; DataCharacteristics dc = getDataCharacteristics(); MatrixBlock src = (MatrixBlock)acquireReadAndRelease(); LongStream.range(0, dc.getNumBlocks()) .mapToObj(i -> UtilFunctions.createIndexedMatrixBlock(src, dc, i)) .forEach( blk -> { try{ - _streamHandle.enqueueTask(blk); + _mStream.enqueue(blk); } catch(Exception ex) { - throw new DMLRuntimeException(ex); + throw ex instanceof DMLRuntimeException ? 
(DMLRuntimeException) ex : new DMLRuntimeException(ex); }}); - _streamHandle.closeInput(); - } - else if(_streamHandle != null && _streamHandle.isProcessed() - && _streamHandle instanceof ResettableStream) - { - try { - ((ResettableStream)_streamHandle).reset(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } + _mStream.closeInput(); } - return _streamHandle; + return _streamHandle.getReadStream(); } /** @@ -539,7 +532,7 @@ public synchronized void removeGPUObject(GPUContext gCtx) { _gpuObjects.remove(gCtx); } - public synchronized void setStreamHandle(LocalTaskQueue q) { + public synchronized void setStreamHandle(OOCStreamable q) { _streamHandle = q; } @@ -633,7 +626,7 @@ && getRDDHandle() == null) ) { _requiresLocalWrite = false; } else if( hasStreamHandle() ) { - _data = readBlobFromStream( getStreamHandle() ); + _data = readBlobFromStream( getStreamHandle().toLocalTaskQueue() ); } else if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() ) { if( DMLScript.STATISTICS ) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java index 9f4ca12dd77..496bca87642 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java @@ -611,7 +611,7 @@ protected long writeStreamToHDFS(String fname, String ofmt, int rep, FileFormatP MetaDataFormat iimd = (MetaDataFormat) _metaData; FileFormat fmt = (ofmt != null ? 
FileFormat.safeValueOf(ofmt) : iimd.getFileFormat()); MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(fmt, rep, fprop); - return writer.writeMatrixFromStream(fname, getStreamHandle(), + return writer.writeMatrixFromStream(fname, getStreamHandle().toLocalTaskQueue(), getNumRows(), getNumColumns(), ConfigurationManager.getBlocksize()); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java index 350fc8de3b6..783981e0f12 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java @@ -43,8 +43,8 @@ public class LocalTaskQueue public static final int MAX_SIZE = 100000; //main memory constraint public static final Object NO_MORE_TASKS = null; //object to signal NO_MORE_TASKS - private LinkedList _data = null; - private boolean _closedInput = false; + protected LinkedList _data = null; + protected boolean _closedInput = false; private DMLRuntimeException _failure = null; private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName()); @@ -60,21 +60,19 @@ public LocalTaskQueue() * @param t task * @throws InterruptedException if InterruptedException occurs */ - public synchronized void enqueueTask( T t ) + public synchronized void enqueueTask( T t ) throws InterruptedException { - while( _data.size() + 1 > MAX_SIZE && _failure == null ) - { + while(_data.size() + 1 > MAX_SIZE && _failure == null) { LOG.warn("MAX_SIZE of task queue reached."); wait(); //max constraint reached, wait for read } - if ( _failure != null ) + if(_failure != null) throw _failure; - - _data.addLast( t ); - - notify(); //notify waiting readers + + _data.addLast(t); + notify(); } /** @@ -97,14 +95,14 @@ public synchronized T dequeueTask() if ( _failure != null ) throw _failure; - + T t = _data.removeFirst(); notify(); // notify 
waiting writers return t; } - + /** * Synchronized (logical) insert of a NO_MORE_TASKS symbol at the end of the FIFO queue in order to * mark that no more tasks will be inserted into the queue. @@ -112,7 +110,7 @@ public synchronized T dequeueTask() public synchronized void closeInput() { _closedInput = true; - notifyAll(); //notify all waiting readers + notifyAll(); } public synchronized boolean isProcessed() { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index c87b3c99cf2..2a53c5400ae 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -76,7 +76,7 @@ public void processInstruction( ExecutionContext ec ) { //setup operators and input queue AggregateUnaryOperator aggun = (AggregateUnaryOperator) getOperator(); MatrixObject min = ec.getMatrixObject(input1); - LocalTaskQueue q = min.getStreamHandle(); + OOCStream q = min.getStreamHandle(); int blen = ConfigurationManager.getBlocksize(); if (aggun.isRowAggregate() || aggun.isColAggregate()) { @@ -86,13 +86,13 @@ public void processInstruction( ExecutionContext ec ) { OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emitThreshold); HashMap corrs = new HashMap<>(); // correction blocks - LocalTaskQueue qOut = new LocalTaskQueue<>(); + OOCStream qOut = createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); submitOOCTask(() -> { IndexedMatrixValue tmp = null; try { - while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + while((tmp = q.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { long idx = aggun.isRowAggregate() ? 
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(); MatrixBlock ret = aggTracker.get(idx); @@ -139,7 +139,7 @@ public void processInstruction( ExecutionContext ec ) { new MatrixIndexes(1, tmp.getIndexes().getColumnIndex()); IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret); - qOut.enqueueTask(tmpOut); + qOut.enqueue(tmpOut); // drop intermediate states aggTracker.remove(idx); corrs.remove(idx); @@ -159,7 +159,7 @@ public void processInstruction( ExecutionContext ec ) { MatrixBlock ret = new MatrixBlock(1,1+extra,false); MatrixBlock corr = new MatrixBlock(1,1+extra,false); try { - while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + while((tmp = q.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { //block aggregation MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()) .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java index 1dfc99be811..a02d3bc088c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java @@ -29,6 +29,7 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; @@ -54,33 +55,46 @@ public static BinaryOOCInstruction parseInstruction(String str) { @Override public void processInstruction( ExecutionContext ec ) { - //TODO support all types, currently only binary matrix-scalar - + if (input1.isMatrix() && input2.isMatrix()) + 
processMatrixMatrixInstruction(ec); + else + processScalarMatrixInstruction(ec); + } + + protected void processMatrixMatrixInstruction(ExecutionContext ec) { + MatrixObject m1 = ec.getMatrixObject(input1); + MatrixObject m2 = ec.getMatrixObject(input2); + + OOCStream qIn1 = m1.getStreamHandle(); + OOCStream qIn2 = m2.getStreamHandle(); + OOCStream qOut = new SubscribableTaskQueue<>(); + ec.getMatrixObject(output).setStreamHandle(qOut); + + joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp1.getIndexes(), + tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue())); + return tmpOut; + }, IndexedMatrixValue::getIndexes); + } + + protected void processScalarMatrixInstruction(ExecutionContext ec) { //get operator and scalar - CPOperand scalar = ( input1.getDataType() == DataType.MATRIX ) ? input2 : input1; + CPOperand scalar = input1.isMatrix() ? input2 : input1; ScalarObject constant = ec.getScalarInput(scalar); ScalarOperator sc_op = ((ScalarOperator)_optr).setConstant(constant.getDoubleValue()); - + //create thread and process binary operation - MatrixObject min = ec.getMatrixObject(input1); - LocalTaskQueue qIn = min.getStreamHandle(); - LocalTaskQueue qOut = new LocalTaskQueue<>(); + MatrixObject min = ec.getMatrixObject(input1.isMatrix() ? 
input1 : input2); + OOCStream qIn = min.getStreamHandle(); + OOCStream qOut = createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); - - submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - IndexedMatrixValue tmpOut = new IndexedMatrixValue(); - tmpOut.set(tmp.getIndexes(), - tmp.getValue().scalarOperations(sc_op, new MatrixBlock())); - qOut.enqueueTask(tmpOut); - } - qOut.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - }, qIn, qOut); + + mapOOC(qIn, qOut, tmp -> { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp.getIndexes(), + tmp.getValue().scalarOperations(sc_op, new MatrixBlock())); + return tmpOut; + }); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java new file mode 100644 index 00000000000..668e11581db --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; + +/** + * A wrapper around an OOCStream that consumes the source stream once, caches + * its blocks, and lets other operators replay them via playback streams. + * + */ +public class CachingStream implements OOCStreamable { + + private static final IDSequence _streamSeq = new IDSequence(); + + // original live stream + private final OOCStream _source; + + // stream identifier + private final long _streamId; + + // block counter + private int _numBlocks = 0; + + private Runnable[] _subscribers; + + // state flags + private boolean _cacheInProgress = true; // caching in progress, in the first pass. + + public CachingStream(OOCStream source) { + this(source, _streamSeq.getNextID()); + } + + public CachingStream(OOCStream source, long streamId) { + _source = source; + _streamId = streamId; + source.setSubscriber(() -> { + try { + boolean closed = fetchFromStream(); + Runnable[] mSubscribers = _subscribers; + + if(mSubscribers != null) { + for(Runnable mSubscriber : mSubscribers) + mSubscriber.run(); + + if (closed) { + synchronized (this) { + _subscribers = null; + } + } + } + } catch (InterruptedException e) { + throw new DMLRuntimeException(e); + } + }); + } + + private boolean fetchFromStream() throws InterruptedException { + synchronized (this) { + if(!_cacheInProgress) + throw new DMLRuntimeException("Stream is closed"); + } + + IndexedMatrixValue task = _source.dequeue(); + + synchronized (this) { + if(task != LocalTaskQueue.NO_MORE_TASKS) { + OOCEvictionManager.put(_streamId, _numBlocks, task); + _numBlocks++; + notifyAll(); + return false; + } + else { + _cacheInProgress = false; // caching is complete + notifyAll(); + return true; + } + } + } + + public synchronized 
IndexedMatrixValue get(int idx) throws InterruptedException { + while (true) { + if (idx < _numBlocks) + return OOCEvictionManager.get(_streamId, idx); + else if (!_cacheInProgress) + return (IndexedMatrixValue)LocalTaskQueue.NO_MORE_TASKS; + + wait(); + } + } + + @Override + public OOCStream getReadStream() { + return new PlaybackStream(this); + } + + @Override + public OOCStream getWriteStream() { + return _source.getWriteStream(); + } + + @Override + public boolean isProcessed() { + return false; + } + + @Override + public void setSubscriber(Runnable subscriber) { + int mNumBlocks; + synchronized (this) { + mNumBlocks = _numBlocks; + if (_cacheInProgress) { + int newLen = _subscribers == null ? 1 : _subscribers.length + 1; + Runnable[] newSubscribers = new Runnable[newLen]; + + if(newLen > 1) + System.arraycopy(_subscribers, 0, newSubscribers, 0, newLen - 1); + + newSubscribers[newLen - 1] = subscriber; + _subscribers = newSubscribers; + } + } + + for (int i = 0; i < mNumBlocks; i++) + subscriber.run(); + + if (!_cacheInProgress) + subscriber.run(); // To fetch the NO_MORE_TASK element + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java index 9c122662c2c..7b3346ab6dd 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java @@ -30,17 +30,9 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.data.MatrixIndexes; -import org.apache.sysds.runtime.matrix.data.MatrixValue; import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.meta.DataCharacteristics; -import 
java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - public class CentralMomentOOCInstruction extends AggregateUnaryOOCInstruction { private CentralMomentOOCInstruction(CMOperator cm, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, @@ -70,7 +62,7 @@ public void processInstruction(ExecutionContext ec) { */ MatrixObject matObj = ec.getMatrixObject(input1.getName()); - LocalTaskQueue qIn = matObj.getStreamHandle(); + OOCStream qIn = matObj.getStreamHandle(); CPOperand scalarInput = (input3 == null ? input2 : input3); ScalarObject order = ec.getScalarInput(scalarInput); @@ -81,20 +73,10 @@ public void processInstruction(ExecutionContext ec) { CMOperator finalCm_op = cm_op; - List cmObjs = new ArrayList<>(); + OOCStream cmObjs = createWritableStream(); if(input3 == null) { - try { - IndexedMatrixValue tmp; - - while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - // We only handle MatrixBlock, other types of MatrixValue will fail here - cmObjs.add(((MatrixBlock) tmp.getValue()).cmOperations(cm_op)); - } - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } + mapOOC(qIn, cmObjs, tmp -> ((MatrixBlock) tmp.getValue()).cmOperations(new CMOperator(finalCm_op))); // Need to copy CMOperator as its ValueFunction is stateful } else { // Here we use a hash join approach @@ -107,59 +89,23 @@ public void processInstruction(ExecutionContext ec) { if (dc.getBlocksize() != dcW.getBlocksize()) throw new DMLRuntimeException("Different block sizes are not yet supported"); - LocalTaskQueue wIn = wtObj.getStreamHandle(); - - try { - IndexedMatrixValue tmp = qIn.dequeueTask(); - IndexedMatrixValue tmpW = wIn.dequeueTask(); - Map left = new HashMap<>(); - Map right = new HashMap<>(); - - boolean cont = tmp != LocalTaskQueue.NO_MORE_TASKS || tmpW != LocalTaskQueue.NO_MORE_TASKS; - - while(cont) { - cont = false; - - if(tmp != LocalTaskQueue.NO_MORE_TASKS) { - MatrixValue weights = 
right.remove(tmp.getIndexes()); - - if(weights != null) - cmObjs.add(((MatrixBlock) tmp.getValue()).cmOperations(cm_op, (MatrixBlock) weights)); - else - left.put(tmp.getIndexes(), tmp.getValue()); - - tmp = qIn.dequeueTask(); - cont = tmp != LocalTaskQueue.NO_MORE_TASKS; - } + OOCStream wIn = wtObj.getStreamHandle(); - if(tmpW != LocalTaskQueue.NO_MORE_TASKS) { - MatrixValue q = left.remove(tmpW.getIndexes()); - - if(q != null) - cmObjs.add(((MatrixBlock) q).cmOperations(cm_op, (MatrixBlock) tmpW.getValue())); - else - right.put(tmpW.getIndexes(), tmpW.getValue()); - - tmpW = wIn.dequeueTask(); - cont |= tmpW != LocalTaskQueue.NO_MORE_TASKS; - } - } - - if (!left.isEmpty() || !right.isEmpty()) - throw new DMLRuntimeException("Unmatched blocks: values=" + left.size() + ", weights=" + right.size()); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } + joinOOC(qIn, wIn, cmObjs, + (tmp, weights) -> + ((MatrixBlock) tmp.getValue()).cmOperations(new CMOperator(finalCm_op), (MatrixBlock) weights.getValue()), + IndexedMatrixValue::getIndexes); } - Optional res = cmObjs.stream() - .reduce((arg0, arg1) -> (CM_COV_Object) finalCm_op.fn.execute(arg0, arg1)); - try { - ec.setScalarOutput(output_name, new DoubleObject(res.get().getRequiredResult(finalCm_op))); - } - catch(Exception ex) { + CM_COV_Object agg = cmObjs.dequeue(); + CM_COV_Object next; + + while ((next = cmObjs.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) + agg = (CM_COV_Object) finalCm_op.fn.execute(agg, next); + + ec.setScalarOutput(output_name, new DoubleObject(agg.getRequiredResult(finalCm_op))); + } catch (Exception ex) { throw new DMLRuntimeException(ex); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java index c4c668ab6b9..01fd348d101 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java +++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java @@ -79,7 +79,7 @@ public static CtableOOCInstruction parseInstruction(String str) { public void processInstruction( ExecutionContext ec ) { MatrixObject in1 = ec.getMatrixObject(input1); // stream - LocalTaskQueue qIn1 = in1.getStreamHandle(); + OOCStream qIn1 = in1.getStreamHandle(); IndexedMatrixValue tmp1 = null; long outputDim1 = ec.getScalarInput(_outDim1).getLongValue(); @@ -90,7 +90,7 @@ public void processInstruction( ExecutionContext ec ) { Ctable.OperationTypes ctableOp = findCtableOperation(); MatrixObject in2 = null, in3 = null; - LocalTaskQueue qIn2 = null, qIn3 = null; + OOCStream qIn2 = null, qIn3 = null; double cst2 = 0, cst3 = 0; // init vars based on ctableOp @@ -121,7 +121,7 @@ public void processInstruction( ExecutionContext ec ) { } try { - while((tmp1 = qIn1.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + while((tmp1 = qIn1.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { MatrixBlock block1 = (MatrixBlock) tmp1.getValue(); long r = tmp1.getIndexes().getRowIndex(); @@ -172,13 +172,13 @@ public void processInstruction( ExecutionContext ec ) { } private MatrixBlock getOrDequeueBlock(long key, long cols, HashMap blocks, - LocalTaskQueue queue) throws InterruptedException + OOCStream queue) throws InterruptedException { MatrixBlock block = blocks.get(key); if (block == null) { IndexedMatrixValue tmp; // corresponding block still in queue, dequeue until found - while ((tmp = queue.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + while ((tmp = queue.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { block = (MatrixBlock) tmp.getValue(); long r = tmp.getIndexes().getRowIndex(); long c = tmp.getIndexes().getColumnIndex(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java index aa215e83e90..38586428e1e 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java @@ -83,15 +83,15 @@ public void processInstruction( ExecutionContext ec ) { long emitThreshold = min.getDataCharacteristics().getNumColBlocks(); OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emitThreshold); - LocalTaskQueue qIn = min.getStreamHandle(); - LocalTaskQueue qOut = new LocalTaskQueue<>(); + OOCStream qIn = min.getStreamHandle(); + OOCStream qOut = createWritableStream(); BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); ec.getMatrixObject(output).setStreamHandle(qOut); submitOOCTask(() -> { IndexedMatrixValue tmp = null; try { - while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + while((tmp = qIn.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); long rowIndex = tmp.getIndexes().getRowIndex(); long colIndex = tmp.getIndexes().getColumnIndex(); @@ -103,7 +103,7 @@ public void processInstruction( ExecutionContext ec ) { // for single column block, no aggregation neeeded if(emitThreshold == 1) { - qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); + qOut.enqueue(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); } else { // aggregation @@ -116,7 +116,7 @@ public void processInstruction( ExecutionContext ec ) { if (aggTracker.putAndIncrementCount(rowIndex, currAgg)){ // early block output: emit aggregated block MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L); - qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg)); + qOut.enqueue(new IndexedMatrixValue(idx, currAgg)); aggTracker.remove(rowIndex); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 0d159492891..4ceabc0802c 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -24,17 +24,29 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.runtime.util.OOCJoin; +import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); + private static final AtomicInteger nextStreamId = new AtomicInteger(0); public enum OOCType { Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable @@ -42,6 +54,7 @@ public enum OOCType { protected final OOCInstruction.OOCType _ooctype; protected final boolean _requiresLabelUpdate; + protected Set> _queues; protected OOCInstruction(OOCInstruction.OOCType type, String opcode, String istr) { this(type, null, opcode, istr); @@ -90,10 +103,134 @@ public void postprocessInstruction(ExecutionContext ec) { ec.maintainLineageDebuggerInfo(this); } - protected void submitOOCTask(Runnable r, LocalTaskQueue... 
queues) { + protected void addInStream(OOCStream... queue) { + if (_queues == null) + _queues = new HashSet<>(); + _queues.addAll(List.of(queue)); + } + + protected void addOutStream(OOCStream... queue) { + // Currently same behavior as addInQueue + addInStream(queue); + } + + protected OOCStream createWritableStream() { + return new SubscribableTaskQueue<>(); + } + + protected void mapOOC(OOCStream qIn, OOCStream qOut, Function mapper) { + addInStream(qIn); + addOutStream(qOut); + + submitOOCTasks(qIn, tmp -> { + try { + R r = mapper.apply(tmp); + qOut.enqueue(r); + } catch (Exception e) { + throw e instanceof DMLRuntimeException ? (DMLRuntimeException) e : new DMLRuntimeException(e); + } + }, qOut::closeInput); + } + + protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream qIn2, OOCStream qOut, BiFunction mapper, Function on) { + return joinOOC(qIn1, qIn2, qOut, mapper, on, on); + } + + protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream qIn2, OOCStream qOut, BiFunction mapper, Function onLeft, Function onRight) { + addInStream(qIn1, qIn2); + addOutStream(qOut); + + final CompletableFuture future = new CompletableFuture<>(); + + final OOCJoin join = new OOCJoin<>((idx, left, right) -> qOut.enqueue(mapper.apply(left, right))); + + submitOOCTasks(List.of(qIn1, qIn2), (i, tmp) -> { + if (i == 0) + join.addLeft(onLeft.apply(tmp), tmp); + else + join.addRight(onRight.apply(tmp), tmp); + }, () -> { + join.close(); + qOut.closeInput(); + future.complete(null); + }); + + return future; + } + + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer) { + addInStream(queues.toArray(OOCStream[]::new)); ExecutorService pool = CommonThreadPool.get(); + final AtomicInteger activeTaskCtr = new AtomicInteger(0); + + final Object lock = new Object(); + + List> futures = new ArrayList<>(queues.size()); + final List streamsClosed = new ArrayList<>(queues.size()); + for (int i = 0; i < queues.size(); i++) { + 
streamsClosed.add(new AtomicBoolean(false)); + } + + int i = 0; + final int streamId = nextStreamId.getAndIncrement(); + //System.out.println("New stream: (id " + streamId + ", size " + queues.size() + ", initiator '" + this.getClass().getSimpleName() + "')"); + + for (OOCStream queue : queues) { + final int k = i; + final CompletableFuture localFuture = new CompletableFuture<>(); + final AtomicBoolean localStreamClosed = streamsClosed.get(k); + futures.add(localFuture); + //System.out.println("Substream (k " + k + ", id " + streamId + ", type '" + queue.getClass().getSimpleName() + "', stream_id " + queue.hashCode() + ")"); + queue.setSubscriber(() -> { + try { + activeTaskCtr.incrementAndGet(); + pool.submit(oocTask(() -> { + try { + T item = queue.dequeue(); + if(item != null) { + //System.out.println("Accept" + ((IndexedMatrixValue)item).getIndexes() + " (k " + k + ", id " + streamId + ")"); + consumer.accept(k, item); + } else { + //System.out.println("Close substream (k " + k + ", id " + streamId + ")"); + localStreamClosed.set(true); + } + activeTaskCtr.decrementAndGet(); + + boolean shutdown; + synchronized(lock) { + shutdown = activeTaskCtr.get() == 0 && streamsClosed.stream().allMatch(AtomicBoolean::get); + } + + if(shutdown) { + //System.out.println("Shutdown (id " + streamId + ")"); + finalizer.run(); + } + } + catch(Exception e) { + throw (e instanceof DMLRuntimeException ? (DMLRuntimeException)e : new DMLRuntimeException(e)); + } + }, localFuture, _queues.toArray(OOCStream[]::new))); + } catch (Exception e) { + throw new DMLRuntimeException(e); + } + }); + + i++; + } + + pool.shutdown(); + return CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)); + } + + protected void submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer) { + submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer); + } + + protected CompletableFuture submitOOCTask(Runnable r, OOCStream... 
queues) { + ExecutorService pool = CommonThreadPool.get(); + final CompletableFuture future = new CompletableFuture<>(); try { - pool.submit(oocTask(r, queues)); + pool.submit(oocTask(() -> {r.run();future.complete(null);}, future, queues)); } catch (Exception ex) { throw new DMLRuntimeException(ex); @@ -101,9 +238,11 @@ protected void submitOOCTask(Runnable r, LocalTaskQueue... queues) { finally { pool.shutdown(); } + + return future; } - private Runnable oocTask(Runnable r, LocalTaskQueue... queues) { + private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream... queues) { return () -> { try { r.run(); @@ -111,10 +250,12 @@ private Runnable oocTask(Runnable r, LocalTaskQueue... queues) { catch (Exception ex) { DMLRuntimeException re = ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex); - for (LocalTaskQueue q : queues) { + for (OOCStream q : queues) { q.propagateFailure(re); } + future.completeExceptionally(re); + // Rethrow to ensure proper future handling throw re; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java new file mode 100644 index 00000000000..30603db9d6d --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; + +public interface OOCStream extends OOCStreamable { + void enqueue(T t); + + T dequeue(); + + void closeInput(); + + LocalTaskQueue toLocalTaskQueue(); + + void setSubscriber(Runnable subscriber); + + void propagateFailure(DMLRuntimeException re); +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java new file mode 100644 index 00000000000..bdc4086bdcd --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +public interface OOCStreamable { + OOCStream getReadStream(); + + OOCStream getWriteStream(); + + boolean isProcessed(); + + void setSubscriber(Runnable subscriber); +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java new file mode 100644 index 00000000000..db33687eb5c --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; + +public class PlaybackStream implements OOCStream, OOCStreamable { + private final CachingStream _streamCache; + private int _streamIdx; + + public PlaybackStream(CachingStream streamCache) { + this._streamCache = streamCache; + this._streamIdx = 0; + } + + @Override + public void enqueue(IndexedMatrixValue t) { + throw new DMLRuntimeException("Cannot enqueue to a playback stream"); + } + + @Override + public void closeInput() { + throw new DMLRuntimeException("Cannot close a playback stream"); + } + + @Override + public LocalTaskQueue toLocalTaskQueue() { + final SubscribableTaskQueue q = new SubscribableTaskQueue<>(); + setSubscriber(() -> q.enqueue(dequeue())); + return q; + } + + @Override + public synchronized IndexedMatrixValue dequeue() { + try { + return _streamCache.get(_streamIdx++); + } catch (InterruptedException e) { + throw new DMLRuntimeException(e); + } + } + + @Override + public OOCStream getReadStream() { + return _streamCache.getReadStream(); + } + + @Override + public OOCStream getWriteStream() { + return _streamCache.getWriteStream(); + } + + @Override + public boolean isProcessed() { + return false; + } + + @Override + public void setSubscriber(Runnable subscriber) { + _streamCache.setSubscriber(subscriber); + } + + @Override + public void propagateFailure(DMLRuntimeException re) { + _streamCache.getWriteStream().propagateFailure(re); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java index 3c78879b45d..52a34ffbede 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java +++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java @@ -75,7 +75,7 @@ public void processInstruction(ExecutionContext ec) { //TODO support other formats than binary //create queue, spawn thread for asynchronous reading, and return - LocalTaskQueue q = new LocalTaskQueue(); + OOCStream q = createWritableStream(); submitOOCTask(() -> readBinaryBlock(q, min.getFileName()), q); MatrixObject mout = ec.getMatrixObject(output); @@ -83,7 +83,7 @@ public void processInstruction(ExecutionContext ec) { } @SuppressWarnings("resource") - private void readBinaryBlock(LocalTaskQueue q, String fname) { + private void readBinaryBlock(OOCStream q, String fname) { try { //prepare file access JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); @@ -102,7 +102,7 @@ private void readBinaryBlock(LocalTaskQueue q, String fname) MatrixIndexes key = new MatrixIndexes(); MatrixBlock value = new MatrixBlock(); while( reader.next(key, value) ) - q.enqueueTask(new IndexedMatrixValue(key, new MatrixBlock(value))); + q.enqueue(new IndexedMatrixValue(key, new MatrixBlock(value))); } } q.closeInput(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java deleted file mode 100644 index 6179811f7a7..00000000000 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.runtime.instructions.ooc; - -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; -import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; -import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; - - -/** - * A wrapper around LocalTaskQueue to consume the source stream and reset to - * consume again for other operators. - *

- * Uses OOCEvictionManager for out-of-core caching. - * - */ -public class ResettableStream extends LocalTaskQueue { - - // original live stream - private final LocalTaskQueue _source; - - private static final IDSequence _streamSeq = new IDSequence(); - // stream identifier - private final long _streamId; - - // block counter - private int _numBlocks = 0; - - - // state flags - private boolean _cacheInProgress = true; // caching in progress, in the first pass. - private int _replayPosition = 0; // slider position in the stream - - public ResettableStream(LocalTaskQueue source) { - this(source, _streamSeq.getNextID()); - } - public ResettableStream(LocalTaskQueue source, long streamId) { - _source = source; - _streamId = streamId; - } - - /** - * Dequeues a task. If it is the first, it reads from the disk and stores in the cache. - * For subsequent passes it reads from the memory. - * - * @return The next matrix value in the stream, or NO_MORE_TASKS - */ - @Override - public synchronized IndexedMatrixValue dequeueTask() - throws InterruptedException { - if (_cacheInProgress) { - // First pass: Read value from the source and cache it, and return. - IndexedMatrixValue task = _source.dequeueTask(); - if (task != NO_MORE_TASKS) { - - OOCEvictionManager.put(_streamId, _numBlocks, task); - _numBlocks++; - - return task; - } else { - _cacheInProgress = false; // caching is complete - _source.closeInput(); // close source stream - - // Notify all the waiting consumers waiting for cache to fill with this stream - notifyAll(); - return (IndexedMatrixValue) NO_MORE_TASKS; - } - } else { - // Replay pass: read from the buffer - if (_replayPosition < _numBlocks) { - return OOCEvictionManager.get(_streamId, _replayPosition++); - } else { - return (IndexedMatrixValue) NO_MORE_TASKS; - } - } - } - - /** - * Resets the stream to beginning to read the stream from start. - * This can only be called once the stream is fully consumed once. 
- */ - public synchronized void reset() throws InterruptedException { - while (_cacheInProgress) { - // Attempted to reset a stream that's not been fully cached yet. - wait(); - } - _replayPosition = 0; - } - - @Override - public synchronized void closeInput() { - _source.closeInput(); - } - - @Override - public synchronized boolean isProcessed() { - return false; - } -} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java new file mode 100644 index 00000000000..698ce841fe0 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; + +public class SubscribableTaskQueue extends LocalTaskQueue implements OOCStream { + private Runnable _subscriber; + + @Override + public void enqueue(T t) { + try { + super.enqueueTask(t); + } + catch (InterruptedException e) { + throw new DMLRuntimeException(e); + } + + if(_subscriber != null) + _subscriber.run(); + } + + @Override + public T dequeue() { + try { + return super.dequeueTask(); + } + catch (InterruptedException e) { + throw new DMLRuntimeException(e); + } + } + + @Override + public void closeInput() { + super.closeInput(); + + if(_subscriber != null) { + _subscriber.run(); + _subscriber = null; + } + } + + @Override + public LocalTaskQueue toLocalTaskQueue() { + return this; + } + + @Override + public OOCStream getReadStream() { + return this; + } + + @Override + public OOCStream getWriteStream() { + return this; + } + + @Override + public void setSubscriber(Runnable subscriber) { + int queueSize; + + synchronized (this) { + if(_subscriber != null) + throw new DMLRuntimeException("Cannot set multiple subscribers"); + + _subscriber = subscriber; + queueSize = _data.size(); + queueSize += _closedInput ? 
1 : 0; // To trigger the NO_MORE_TASK element + } + + for (int i = 0; i < queueSize; i++) + subscriber.run(); + } + + @Override + public void propagateFailure(DMLRuntimeException re) { + super.propagateFailure(re); + + if(_subscriber != null) + _subscriber.run(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java index b3f302c2045..9040c369a24 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java @@ -66,7 +66,7 @@ public void processInstruction( ExecutionContext ec ) { int nCols = (int) min.getDataCharacteristics().getCols(); int bLen = min.getDataCharacteristics().getBlocksize(); - LocalTaskQueue qIn = min.getStreamHandle(); + OOCStream qIn = min.getStreamHandle(); BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); //validation check TODO extend compiler to not create OOC otherwise @@ -81,7 +81,7 @@ public void processInstruction( ExecutionContext ec ) { try { IndexedMatrixValue tmp = null; // aggregate partial tsmm outputs into result as inputs stream in - while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + while((tmp = qIn.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { MatrixBlock partialResult = ((MatrixBlock) tmp.getValue()) .transposeSelfMatrixMultOperations(new MatrixBlock(), _type); resultBlock.binaryOperationsInPlace(plus, partialResult); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java index baf3ecea242..aa3be6c1f41 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -44,11 +44,11 @@ public static 
TeeOOCInstruction parseInstruction(String str) { public void processInstruction( ExecutionContext ec ) { //get input stream MatrixObject min = ec.getMatrixObject(input1); - LocalTaskQueue qIn = min.getStreamHandle(); + OOCStream qIn = min.getStreamHandle(); //get output and create new resettable stream MatrixObject mo = ec.getMatrixObject(output); - mo.setStreamHandle(new ResettableStream(qIn)); + mo.setStreamHandle(new CachingStream(qIn)); mo.setMetaData(min.getMetaData()); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java index 05e31830a56..6558145ec21 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java @@ -19,10 +19,8 @@ package org.apache.sysds.runtime.instructions.ooc; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; @@ -53,26 +51,17 @@ public void processInstruction( ExecutionContext ec ) { // Create thread and process the transpose operation MatrixObject min = ec.getMatrixObject(input1); - LocalTaskQueue qIn = min.getStreamHandle(); - LocalTaskQueue qOut = new LocalTaskQueue<>(); + OOCStream qIn = min.getStreamHandle(); + OOCStream qOut = createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); - submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - MatrixBlock inBlock = (MatrixBlock)tmp.getValue(); 
- long oldRowIdx = tmp.getIndexes().getRowIndex(); - long oldColIdx = tmp.getIndexes().getColumnIndex(); + mapOOC(qIn, qOut, tmp -> { + MatrixBlock inBlock = (MatrixBlock) tmp.getValue(); + long oldRowIdx = tmp.getIndexes().getRowIndex(); + long oldColIdx = tmp.getIndexes().getColumnIndex(); - MatrixBlock outBlock = inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1); - qOut.enqueueTask(new IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock)); - } - qOut.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - }, qIn, qOut); + MatrixBlock outBlock = inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1); + return new IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock); + }); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java index 173486844a6..08f00f86d2f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java @@ -19,10 +19,8 @@ package org.apache.sysds.runtime.instructions.ooc; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; @@ -53,25 +51,15 @@ public void processInstruction( ExecutionContext ec ) { UnaryOperator uop = (UnaryOperator) _uop; // Create thread and process the unary operation MatrixObject min = ec.getMatrixObject(input1); - LocalTaskQueue qIn = min.getStreamHandle(); - 
LocalTaskQueue qOut = new LocalTaskQueue<>(); + OOCStream qIn = min.getStreamHandle(); + OOCStream qOut = createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); - - submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - IndexedMatrixValue tmpOut = new IndexedMatrixValue(); - tmpOut.set(tmp.getIndexes(), - tmp.getValue().unaryOperations(uop, new MatrixBlock())); - qOut.enqueueTask(tmpOut); - } - qOut.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - }, qIn, qOut); + mapOOC(qIn, qOut, tmp -> { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp.getIndexes(), + tmp.getValue().unaryOperations(uop, new MatrixBlock())); + return tmpOut; + }); } } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java index f928f0440b8..489b277f748 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java @@ -55,6 +55,13 @@ public CMOperator(ValueFunction op, AggregateOperationTypes agg, int numThreads) _numThreads = numThreads; } + public CMOperator(CMOperator that) { + // Deep copy the stateful ValueFunction + fn = that.fn instanceof CM ? CM.getCMFnObject((CM)that.fn) : that.fn; + aggOpType = that.aggOpType; + _numThreads = that._numThreads; + } + public AggregateOperationTypes getAggOpType() { return aggOpType; } diff --git a/src/main/java/org/apache/sysds/runtime/util/OOCJoin.java b/src/main/java/org/apache/sysds/runtime/util/OOCJoin.java new file mode 100644 index 00000000000..81265b8a2d2 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/util/OOCJoin.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.util; + +import org.apache.logging.log4j.util.TriConsumer; +import org.apache.sysds.runtime.DMLRuntimeException; + +import java.util.HashMap; +import java.util.Map; + +public class OOCJoin { + private Map left; + private Map right; + private TriConsumer emitter; + + public OOCJoin(TriConsumer emitter) { + this.left = new HashMap<>(); + this.right = new HashMap<>(); + this.emitter = emitter; + } + + public void addLeft(T idx, O item) { + add(true, idx, item); + } + + public void addRight(T idx, O item) { + add(false, idx, item); + } + + public void close() { + synchronized (this) { + if (!left.isEmpty() || !right.isEmpty()) + throw new DMLRuntimeException("There are still unprocessed items in the OOC join"); + } + } + + public void add(boolean isLeft, T idx, O val) { + Map lookup = isLeft ? right : left; + Map store = isLeft ? left : right; + O val2; + + synchronized (this) { + val2 = lookup.remove(idx); + + if (val2 == null) + store.put(idx, val); + } + + if (val2 != null) + emitter.accept(idx, isLeft ? val : val2, isLeft ? 
val2 : val); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java new file mode 100644 index 00000000000..dfa9413bfb0 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class BinaryMatrixMatrixTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "BinaryMatrixMatrix"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + BinaryMatrixMatrixTest.class.getSimpleName() + "/"; + private final static double eps = 1e-8; + private static final String INPUT_NAME_1 = "X"; + private static final String INPUT_NAME_2 = "Y"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 1500; + private final static int cols = 1200; + private final static int maxVal = 7; + private final static double sparsity1 = 1; + private final static double sparsity2 = 0.05; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testBinaryMatrixMatrixDenseDense() { + runBinaryMatrixMatrixTest(false, false); + } + + @Test + public void testBinaryMatrixMatrixDenseSparse() { + runBinaryMatrixMatrixTest(false, true); + } + + @Test + public void testBinaryMatrixMatrixSparseDense() { + runBinaryMatrixMatrixTest(true, false); + } + + @Test + public 
void testBinaryMatrixMatrixSparseSparse() { + runBinaryMatrixMatrixTest(true, true); + } + + private void runBinaryMatrixMatrixTest(boolean sparse1, boolean sparse2) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] X_data = getRandomMatrix(rows, 1, 1, maxVal, sparse1 ? sparsity2 : sparsity1, 7); + double[][] Y_data = getRandomMatrix(rows, 1, 0, 1, sparse2 ? sparsity2 : sparsity1, 8); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); + MatrixBlock Y_mb = DataConverter.convertToMatrixBlock(Y_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. 
Write matrix A to a binary SequenceFile + writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1), rows, cols, 1000, X_mb.getNonZeros()); + writer.writeMatrixToHDFS(Y_mb, input(INPUT_NAME_2), rows, cols, 1000, Y_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, Y_mb.getNonZeros()), Types.FileFormat.BINARY); + + runTest(true, false, null, -1); + + //check tsmm OOC + Assert.assertTrue("OOC wasn't used for multiplication", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.MULT)); + + //compare results + + // rerun without ooc flag + programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java new file mode 100644 index 00000000000..e84d36e41b0 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class BinaryMatrixScalarTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "BinaryMatrixScalar"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + BinaryMatrixScalarTest.class.getSimpleName() + "/"; + private final static double eps = 1e-8; + private static final String INPUT_NAME_1 = "X"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 1500; + private final static int cols = 1200; + private final static int maxVal = 7; + private final static double sparsity1 = 1; + private final static double sparsity2 = 0.05; + + 
@Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testBinaryMatrixScalarDense() { + runBinaryMatrixScalarTest(false); + } + + @Test + public void testBinaryMatrixScalarSparse() { + runBinaryMatrixScalarTest(true); + } + + private void runBinaryMatrixScalarTest(boolean sparse) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] X_data = getRandomMatrix(rows, 1, 1, maxVal, sparse ? sparsity2 : sparsity1, 7); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. 
Write matrix X to a binary SequenceFile + writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1), rows, cols, 1000, X_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY); + + runTest(true, false, null, -1); + + //check that the OOC scalar division and addition were executed + Assert.assertTrue("OOC wasn't used for division", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.DIV)); + Assert.assertTrue("OOC wasn't used for addition", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.PLUS)); + + //compare results + + // rerun without ooc flag + programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/BinaryMatrixMatrix.dml b/src/test/scripts/functions/ooc/BinaryMatrixMatrix.dml new file mode 100644 index 00000000000..ad7ed6bb554 --- /dev/null +++ b/src/test/scripts/functions/ooc/BinaryMatrixMatrix.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the two input matrices as streams +X = read($1); +Y = read($2); + +res = X * Y +res = res * X + +write(res, $3, format="binary"); diff --git a/src/test/scripts/functions/ooc/BinaryMatrixScalar.dml b/src/test/scripts/functions/ooc/BinaryMatrixScalar.dml new file mode 100644 index 00000000000..e5b19fe5a75 --- /dev/null +++ b/src/test/scripts/functions/ooc/BinaryMatrixScalar.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# +#------------------------------------------------------------- + +# Read the input matrix as a stream +X = read($1); + +OOC = 5 / X +res = OOC + 3 + +write(res, $2, format="binary"); From 1ac69bd806055df2b0392803aecf286ea690c4c7 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 3 Nov 2025 11:51:39 +0100 Subject: [PATCH 2/4] Remove redundant definition `setSubscriber` --- .../org/apache/sysds/runtime/instructions/ooc/OOCStream.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java index 30603db9d6d..06f347c0ceb 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java @@ -31,7 +31,5 @@ public interface OOCStream extends OOCStreamable { LocalTaskQueue toLocalTaskQueue(); - void setSubscriber(Runnable subscriber); - void propagateFailure(DMLRuntimeException re); } From caa2ef408231883cf10d60f0bd99373125088ebd Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Thu, 6 Nov 2025 13:58:50 +0100 Subject: [PATCH 3/4] Bugfix for joinOOC with OOCEvictionManager --- .../instructions/ooc/CachingStream.java | 17 +++++++++++++++- .../instructions/ooc/OOCInstruction.java | 20 +++++++++++++++---- .../runtime/instructions/ooc/OOCStream.java | 4 ++++ .../instructions/ooc/PlaybackStream.java | 10 ++++++++++ .../ooc/SubscribableTaskQueue.java | 10 ++++++++++ 5 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index 668e11581db..d95d69a71be 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -23,6 +23,10 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; + +import java.util.HashMap; +import java.util.Map; /** * A wrapper around LocalTaskQueue to consume the source stream and reset to @@ -31,7 +35,7 @@ */ public class CachingStream implements OOCStreamable { - private static final IDSequence _streamSeq = new IDSequence(); + public static final IDSequence _streamSeq = new IDSequence(); // original live stream private final OOCStream _source; @@ -46,6 +50,7 @@ public class CachingStream implements OOCStreamable { // state flags private boolean _cacheInProgress = true; // caching in progress, in the first pass. + private Map _index; public CachingStream(OOCStream source) { this(source, _streamSeq.getNextID()); @@ -86,6 +91,8 @@ private boolean fetchFromStream() throws InterruptedException { synchronized (this) { if(task != LocalTaskQueue.NO_MORE_TASKS) { OOCEvictionManager.put(_streamId, _numBlocks, task); + if (_index != null) + _index.put(task.getIndexes(), _numBlocks); _numBlocks++; notifyAll(); return false; @@ -109,6 +116,14 @@ else if (!_cacheInProgress) } } + public synchronized IndexedMatrixValue findCached(MatrixIndexes idx) { + return OOCEvictionManager.get(_streamId, _index.get(idx)); + } + + public synchronized void activateIndexing() { + _index = new HashMap<>(); + } + @Override public OOCStream getReadStream() { return new PlaybackStream(this); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 4ceabc0802c..095e9f71df7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -25,7 +25,9 @@ import 
org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.OOCJoin; @@ -142,13 +144,23 @@ protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream final CompletableFuture future = new CompletableFuture<>(); - final OOCJoin join = new OOCJoin<>((idx, left, right) -> qOut.enqueue(mapper.apply(left, right))); + // We need to construct our own stream to properly manage the cached items in the hash join + CachingStream leftCache = qIn1.hasStreamCache() ? qIn1.getStreamCache() : new CachingStream((SubscribableTaskQueue)qIn1); // We have to assume this generic type for now + CachingStream rightCache = qIn2.hasStreamCache() ? 
qIn2.getStreamCache() : new CachingStream((SubscribableTaskQueue)qIn2); // We have to assume this generic type for now + leftCache.activateIndexing(); + rightCache.activateIndexing(); - submitOOCTasks(List.of(qIn1, qIn2), (i, tmp) -> { + final OOCJoin join = new OOCJoin<>((idx, left, right) -> { + T leftObj = (T) leftCache.findCached(left); + T rightObj = (T) rightCache.findCached(right); + qOut.enqueue(mapper.apply(leftObj, rightObj)); + }); + + submitOOCTasks(List.of(leftCache.getReadStream(), rightCache.getReadStream()), (i, tmp) -> { if (i == 0) - join.addLeft(onLeft.apply(tmp), tmp); + join.addLeft(onLeft.apply((T)tmp), ((IndexedMatrixValue) tmp).getIndexes()); else - join.addRight(onRight.apply(tmp), tmp); + join.addRight(onRight.apply((T)tmp), ((IndexedMatrixValue) tmp).getIndexes()); }, () -> { join.close(); qOut.closeInput(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java index 06f347c0ceb..1a12cb138b7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java @@ -32,4 +32,8 @@ public interface OOCStream extends OOCStreamable { LocalTaskQueue toLocalTaskQueue(); void propagateFailure(DMLRuntimeException re); + + boolean hasStreamCache(); + + CachingStream getStreamCache(); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java index db33687eb5c..6edc4ecf270 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java @@ -82,4 +82,14 @@ public void setSubscriber(Runnable subscriber) { public void propagateFailure(DMLRuntimeException re) { _streamCache.getWriteStream().propagateFailure(re); } + + @Override + public boolean 
hasStreamCache() { + return true; + } + + @Override + public CachingStream getStreamCache() { + return _streamCache; + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java index 698ce841fe0..060fb102f75 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java @@ -97,4 +97,14 @@ public void propagateFailure(DMLRuntimeException re) { if(_subscriber != null) _subscriber.run(); } + + @Override + public boolean hasStreamCache() { + return false; + } + + @Override + public CachingStream getStreamCache() { + return null; + } } From 1e33cd54cfb76b76c5f84264d7cdf012dec6aef9 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Thu, 6 Nov 2025 15:43:21 +0100 Subject: [PATCH 4/4] Error handling improvements and better integration with OOCEvictionManager --- .../instructions/ooc/CachingStream.java | 3 +- .../instructions/ooc/OOCInstruction.java | 173 ++++++++++++------ 2 files changed, 122 insertions(+), 54 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index d95d69a71be..1a540302806 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -121,7 +121,8 @@ public synchronized IndexedMatrixValue findCached(MatrixIndexes idx) { } public synchronized void activateIndexing() { - _index = new HashMap<>(); + if (_index == null) + _index = new HashMap<>(); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 095e9f71df7..5aade709ea2 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.instructions.ooc; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; @@ -45,18 +46,21 @@ import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Stream; public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); private static final AtomicInteger nextStreamId = new AtomicInteger(0); public enum OOCType { - Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable + Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing } protected final OOCInstruction.OOCType _ooctype; protected final boolean _requiresLabelUpdate; - protected Set> _queues; + protected Set> _inQueues; + protected Set> _outQueues; + private boolean _failed; protected OOCInstruction(OOCInstruction.OOCType type, String opcode, String istr) { this(type, null, opcode, istr); @@ -69,6 +73,7 @@ protected OOCInstruction(OOCInstruction.OOCType type, Operator op, String opcode instOpcode = opcode; _requiresLabelUpdate = super.requiresLabelUpdate(); + _failed = false; } @Override @@ -106,25 +111,34 @@ public void postprocessInstruction(ExecutionContext ec) { } protected void addInStream(OOCStream... queue) { - if (_queues == null) - _queues = new HashSet<>(); - _queues.addAll(List.of(queue)); + if (_inQueues == null) + _inQueues = new HashSet<>(); + _inQueues.addAll(List.of(queue)); } protected void addOutStream(OOCStream... 
queue) { // Currently same behavior as addInQueue - addInStream(queue); + if (_outQueues == null) + _outQueues = new HashSet<>(); + _outQueues.addAll(List.of(queue)); } protected OOCStream createWritableStream() { return new SubscribableTaskQueue<>(); } - protected void mapOOC(OOCStream qIn, OOCStream qOut, Function mapper) { + protected CompletableFuture filterOOC(OOCStream qIn, Consumer processor, Function predicate, Runnable finalizer) { + if (_inQueues == null || _outQueues == null) + throw new NotImplementedException("filterOOC requires manual specification of all input and output streams for error propagation"); + + return submitOOCTasks(qIn, processor, finalizer, predicate); + } + + protected CompletableFuture mapOOC(OOCStream qIn, OOCStream qOut, Function mapper) { addInStream(qIn); addOutStream(qOut); - submitOOCTasks(qIn, tmp -> { + return submitOOCTasks(qIn, tmp -> { try { R r = mapper.apply(tmp); qOut.enqueue(r); @@ -171,71 +185,119 @@ protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream } protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer) { + List> futures = new ArrayList<>(queues.size()); + + for (int i = 0; i < queues.size(); i++) + futures.add(new CompletableFuture<>()); + + return submitOOCTasks(queues, consumer, finalizer, futures, null); + } + + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer, List> futures, BiFunction predicate) { addInStream(queues.toArray(OOCStream[]::new)); ExecutorService pool = CommonThreadPool.get(); - final AtomicInteger activeTaskCtr = new AtomicInteger(0); - final Object lock = new Object(); - - List> futures = new ArrayList<>(queues.size()); + final List activeTaskCtrs = new ArrayList<>(queues.size()); final List streamsClosed = new ArrayList<>(queues.size()); + for (int i = 0; i < queues.size(); i++) { + activeTaskCtrs.add(new AtomicInteger(0)); streamsClosed.add(new AtomicBoolean(false)); } + final 
AtomicInteger globalTaskCtr = new AtomicInteger(0); + final CompletableFuture globalFuture = CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)); + final Runnable oocFinalizer = oocTask(finalizer, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new)); + final Object globalLock = new Object(); + int i = 0; final int streamId = nextStreamId.getAndIncrement(); //System.out.println("New stream: (id " + streamId + ", size " + queues.size() + ", initiator '" + this.getClass().getSimpleName() + "')"); for (OOCStream queue : queues) { final int k = i; - final CompletableFuture localFuture = new CompletableFuture<>(); + final AtomicInteger localTaskCtr = activeTaskCtrs.get(k); final AtomicBoolean localStreamClosed = streamsClosed.get(k); - futures.add(localFuture); + final CompletableFuture localFuture = futures.get(k); + //System.out.println("Substream (k " + k + ", id " + streamId + ", type '" + queue.getClass().getSimpleName() + "', stream_id " + queue.hashCode() + ")"); - queue.setSubscriber(() -> { - try { - activeTaskCtr.incrementAndGet(); - pool.submit(oocTask(() -> { - try { - T item = queue.dequeue(); - if(item != null) { - //System.out.println("Accept" + ((IndexedMatrixValue)item).getIndexes() + " (k " + k + ", id " + streamId + ")"); - consumer.accept(k, item); - } else { - //System.out.println("Close substream (k " + k + ", id " + streamId + ")"); - localStreamClosed.set(true); - } - activeTaskCtr.decrementAndGet(); - - boolean shutdown; - synchronized(lock) { - shutdown = activeTaskCtr.get() == 0 && streamsClosed.stream().allMatch(AtomicBoolean::get); - } - - if(shutdown) { - //System.out.println("Shutdown (id " + streamId + ")"); - finalizer.run(); - } - } - catch(Exception e) { - throw (e instanceof DMLRuntimeException ? 
(DMLRuntimeException)e : new DMLRuntimeException(e)); - } - }, localFuture, _queues.toArray(OOCStream[]::new))); - } catch (Exception e) { - throw new DMLRuntimeException(e); + queue.setSubscriber(oocTask(() -> { + final T item = queue.dequeue(); + + if (predicate != null && item != null && !predicate.apply(k, item)) // Can get closed due to cancellation + return; + + synchronized (globalLock) { + if (localFuture.isDone()) + return; + + globalTaskCtr.incrementAndGet(); } - }); + + localTaskCtr.incrementAndGet(); + + pool.submit(oocTask(() -> { + if(item != null) { + //System.out.println("Accept" + ((IndexedMatrixValue)item).getIndexes() + " (k " + k + ", id " + streamId + ")"); + consumer.accept(k, item); + } + else { + //System.out.println("Close substream (k " + k + ", id " + streamId + ")"); + localStreamClosed.set(true); + } + + boolean runFinalizer = false; + + synchronized (globalLock) { + int localTasks = localTaskCtr.decrementAndGet(); + boolean finalizeStream = localTasks == 0 && localStreamClosed.get(); + + int globalTasks = globalTaskCtr.get() - 1; + + if (finalizeStream || (globalFuture.isDone() && localTasks == 0)) { + localFuture.complete(null); + + if (globalFuture.isDone() && globalTasks == 0) + runFinalizer = true; + } + + globalTaskCtr.decrementAndGet(); + } + + if (runFinalizer) + oocFinalizer.run(); + }, localFuture, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); + }, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); i++; } pool.shutdown(); - return CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)); + + globalFuture.whenComplete((res, e) -> { + if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) + futures.forEach(f -> f.cancel(true)); + + boolean runFinalizer; + + synchronized (globalLock) { + runFinalizer = globalTaskCtr.get() == 0; + } + + if (runFinalizer) + oocFinalizer.run(); + + //System.out.println("Shutdown (id " + 
streamId + ")"); + }); + return globalFuture; } - protected void submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer) { - submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer); + protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer) { + return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer); + } + + protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer, Function predicate) { + return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer, List.of(new CompletableFuture()), (i, tmp) -> predicate.apply(tmp)); } protected CompletableFuture submitOOCTask(Runnable r, OOCStream... queues) { @@ -254,7 +316,7 @@ protected CompletableFuture submitOOCTask(Runnable r, OOCStream... queu return future; } - private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream... queues) { + private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream... queues) { return () -> { try { r.run(); @@ -262,11 +324,16 @@ private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream q : queues) { + if (_failed) // Failure was already propagated once; rethrow without propagating again to avoid infinite cycles + throw re; + + _failed = true; + + for (OOCStream q : queues) q.propagateFailure(re); - } - future.completeExceptionally(re); + if (future != null) + future.completeExceptionally(re); // Rethrow to ensure proper future handling throw re;