diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 8400ec54e6f..9d08d32521d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -41,17 +41,21 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.caching.TensorObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.io.FileFormatProperties; import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; import org.apache.sysds.runtime.io.FileFormatPropertiesHDF5; import org.apache.sysds.runtime.io.FileFormatPropertiesLIBSVM; import org.apache.sysds.runtime.io.ListReader; import org.apache.sysds.runtime.io.ListWriter; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; import org.apache.sysds.runtime.io.WriterHDF5; import org.apache.sysds.runtime.io.WriterMatrixMarket; import org.apache.sysds.runtime.io.WriterTextCSV; @@ -59,6 +63,7 @@ import org.apache.sysds.runtime.lineage.LineageItemUtils; import org.apache.sysds.runtime.lineage.LineageTraceable; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaData; @@ -1060,19 +1065,46 @@ private void processWriteInstruction(ExecutionContext ec) { HDFSTool.writeScalarToHDFS(ec.getScalarInput(getInput1()), fname); } else if( getInput1().getDataType() == DataType.MATRIX ) { - if( fmt == FileFormat.MM ) - writeMMFile(ec, fname); - else if( fmt == FileFormat.CSV ) - writeCSVFile(ec, fname); - else if(fmt == FileFormat.LIBSVM) - writeLIBSVMFile(ec, fname); - else if(fmt == FileFormat.HDF5) - writeHDF5File(ec, fname); + + MatrixObject mo = ec.getMatrixObject(getInput1().getName()); + LocalTaskQueue stream = mo.getStreamHandle(); + + if (stream != null) { + + try { + IndexedMatrixValue tmp = null; + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(fmt); + + while((tmp = stream.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock mb = (MatrixBlock)tmp.getValue(); + MatrixIndexes mi = tmp.getIndexes(); + + // Construct a unique filename for each part-file inside the output directory + String partFilePath = fname + "/part-" + mi.getRowIndex() + "-" + mi.getColumnIndex(); + + writer.writeMatrixToHDFS(mb, partFilePath, mb.getNumRows(), mb.getNumColumns(), (int) mb.getLength() , mb.getNonZeros()); + } + HDFSTool.writeMetaDataFile(fname + "/.mtd", mo.getValueType(), + mo.getMetaData().getDataCharacteristics(), FileFormat.HDF5, _formatProperties); + } + catch(Exception ex) { + throw new DMLRuntimeException("Failed to write OOC stream to " + fname, ex); + } + } else { - // Default behavior (text, binary) - MatrixObject mo = ec.getMatrixObject(getInput1().getName()); - int blen = Integer.parseInt(getInput4().getName()); - mo.exportData(fname, fmtStr, new FileFormatProperties(blen)); + if( fmt == FileFormat.MM ) + writeMMFile(ec, fname); + else if( fmt == FileFormat.CSV ) + writeCSVFile(ec, fname); + else if(fmt == FileFormat.LIBSVM) + writeLIBSVMFile(ec, fname); + else if(fmt == FileFormat.HDF5) + writeHDF5File(ec, fname); + else { + // Default behavior (text, binary) + int blen = Integer.parseInt(getInput4().getName()); + mo.exportData(fname, fmtStr, new FileFormatProperties(blen)); + } } } else if( getInput1().getDataType() == DataType.FRAME ) { diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java index a81689af375..85f9166dc52 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java @@ -56,7 +56,7 @@ public void setUp() { } /** - * Test the sum of scalar multiplication, "sum(X*7)", with OOC backend. + * Test the unary operation, "ceil(X)", with OOC backend. */ @Test public void testUnary() { @@ -77,7 +77,7 @@ public void testUnaryOperation(boolean rewrite) programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; - int rows = 1000, cols = 1000; + int rows = 5000, cols = 5000; MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7); MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY); writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols); @@ -87,21 +87,25 @@ public void testUnaryOperation(boolean rewrite) runTest(true, false, null, -1); HashMap dmlfile = readDMLMatrixFromOutputDir(OUTPUT_NAME); - Double result = dmlfile.get(new MatrixValue.CellIndex(1, 1)); + double expected = 0.0; for(int i = 0; i < rows; i++) { for(int j = 0; j < cols; j++) { - expected += Math.ceil(mb.get(i, j)); + Double dmlResult = dmlfile.get(new MatrixValue.CellIndex(i+1 , j+1 )); // Note: MM format is 1-based index + double actualValue = (dmlResult == null) ? 0.0 : dmlResult; + expected = Math.abs(Math.ceil(mb.get(i, j))); + System.out.println("("+i+","+j+"): " + actualValue + "actual: " + expected); + Assert.assertEquals(expected, actualValue, 1e-10); } } - Assert.assertEquals(expected, result, 1e-10); - String prefix = Instruction.OOC_INST_PREFIX; Assert.assertTrue("OOC wasn't used for RBLK", heavyHittersContainsString(prefix + Opcodes.RBLK)); Assert.assertTrue("OOC wasn't used for CEIL", heavyHittersContainsString(prefix + Opcodes.CEIL)); + Assert.assertTrue("Stream Aware WRITE wasn't used", + heavyHittersContainsString(String.valueOf(Opcodes.WRITE))); } catch(Exception ex) { Assert.fail(ex.getMessage()); diff --git a/src/test/scripts/functions/ooc/Unary.dml b/src/test/scripts/functions/ooc/Unary.dml index 6d34e8fd763..404fc9b0100 100644 --- a/src/test/scripts/functions/ooc/Unary.dml +++ b/src/test/scripts/functions/ooc/Unary.dml @@ -21,9 +21,7 @@ # Read input matrix and operator from command line args X = read($1); -#print(toString(X)) Y = ceil(X); -#print(toString(Y)) -res = as.matrix(sum(Y)); + # Write the final matrix result -write(res, $2); +write(Y, $2, format="binary");