diff --git a/src/main/java/org/apache/sysds/lops/CSVReBlock.java b/src/main/java/org/apache/sysds/lops/CSVReBlock.java index b554b787713..5820a732ee9 100644 --- a/src/main/java/org/apache/sysds/lops/CSVReBlock.java +++ b/src/main/java/org/apache/sysds/lops/CSVReBlock.java @@ -44,8 +44,8 @@ public CSVReBlock(Lop input, int blen, DataType dt, ValueType vt, ExecType et) _blocksize = blen; - if(et == ExecType.SPARK) { - lps.setProperties( inputs, ExecType.SPARK); + if(et == ExecType.SPARK || et == ExecType.OOC) { + lps.setProperties( inputs, et ); } else { throw new LopsException("Incorrect execution type for CSVReblock:" + et); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 4e9a92ecb78..03c806b09ae 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -25,6 +25,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.CSVReblockOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.CtableOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; @@ -56,6 +57,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str switch(ooctype) { case Reblock: return ReblockOOCInstruction.parseInstruction(str); + case CSVReblock: + return CSVReblockOOCInstruction.parseInstruction(str); case AggregateUnary: return AggregateUnaryOOCInstruction.parseInstruction(str); case Unary: diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CSVReblockOOCInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CSVReblockOOCInstruction.java new file mode 100644 index 00000000000..a4f8c497050 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CSVReblockOOCInstruction.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.io.FileFormatProperties; +import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; +import org.apache.sysds.runtime.io.ReaderTextCSVParallel; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +public class CSVReblockOOCInstruction extends ComputationOOCInstruction { + private final int blen; + + private CSVReblockOOCInstruction(Operator op, CPOperand in, CPOperand out, int blocklength, String opcode, + String instr) { + super(OOCType.Reblock, op, in, out, opcode, instr); + blen = blocklength; + } + + public static CSVReblockOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + if(!opcode.equals(Opcodes.CSVRBLK.toString())) + throw new DMLRuntimeException("Incorrect opcode for CSVReblockOOCInstruction:" + opcode); + + CPOperand in = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[2]); + int blen = Integer.parseInt(parts[3]); + return new CSVReblockOOCInstruction(null, in, out, blen, opcode, str); + } + + @Override + public void processInstruction(ExecutionContext ec) { + MatrixObject min = ec.getMatrixObject(input1); + DataCharacteristics mc = ec.getDataCharacteristics(input1.getName()); + DataCharacteristics mcOut = ec.getDataCharacteristics(output.getName()); + mcOut.set(mc.getRows(), mc.getCols(), blen, mc.getNonZeros()); + + OOCStream qOut = 
createWritableStream(); + addOutStream(qOut); + + FileFormatProperties props = min.getFileFormatProperties(); + final FileFormatPropertiesCSV csvProps = props instanceof FileFormatPropertiesCSV ? (FileFormatPropertiesCSV) props + : new FileFormatPropertiesCSV(); + + final ReaderTextCSVParallel reader = new ReaderTextCSVParallel(csvProps); + final String fileName = min.getFileName(); + final long rows = mc.getRows(); + final long cols = mc.getCols(); + final long nnz = mc.getNonZeros(); + + submitOOCTask(() -> { + try { + reader.readMatrixAsStream(qOut, fileName, rows, cols, blen, nnz); + } + catch(Exception ex) { + throw (ex instanceof DMLRuntimeException) ? (DMLRuntimeException) ex : new DMLRuntimeException(ex); + } + }, qOut); + + MatrixObject mout = ec.getMatrixObject(output); + mout.setStreamHandle(qOut); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSVParallel.java b/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSVParallel.java index 5b297a5d530..3e9a7a881a6 100644 --- a/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSVParallel.java +++ b/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSVParallel.java @@ -22,9 +22,12 @@ import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; +import java.util.Map; import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.fs.FileSystem; @@ -43,8 +46,11 @@ import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.io.IOUtilFunctions.CountRowsTask; import 
org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.UtilFunctions; @@ -65,6 +71,7 @@ public class ReaderTextCSVParallel extends MatrixReader { protected int _rLen; protected int _cLen; protected JobConf _job; + protected boolean _streamSparse = false; public ReaderTextCSVParallel(FileFormatPropertiesCSV props) { _numThreads = OptimizerUtils.getParallelTextReadParallelism(); @@ -97,7 +104,7 @@ public MatrixBlock readMatrixFromHDFS(String fname, long rlen, long clen, int bl MatrixBlock ret = computeCSVSizeAndCreateOutputMatrixBlock(splits, path, rlen, clen, blen, estnnz); // Second Read Pass (read, parse strings, append to matrix block) - readCSVMatrixFromHDFS(splits, path, ret); + readCSVMatrixFromHDFS(splits, path, ret, null); // post-processing (representation-specific, change of sparse/dense block representation) // - no sorting required for CSV because it is read in sorted order per row @@ -112,6 +119,53 @@ public MatrixBlock readMatrixFromHDFS(String fname, long rlen, long clen, int bl return ret; } + public MatrixBlock readMatrixAsStream(OOCStream outStream, String fname, long rlen, long clen, + int blen, long estnnz) throws IOException, DMLRuntimeException { + _bLen = blen; + + // prepare file access + _job = new JobConf(ConfigurationManager.getCachedJobConf()); + + Path path = new Path(fname); + FileSystem fs = IOUtilFunctions.getFileSystem(path, _job); + + FileInputFormat.addInputPath(_job, path); + TextInputFormat informat = new TextInputFormat(); + informat.configure(_job); + + InputSplit[] splits = informat.getSplits(_job, _numThreads); + splits = IOUtilFunctions.sortInputSplits(splits); + + // check existence and non-empty file + checkValidInputFile(fs, path); + + // count rows/cols to populate meta data and split offsets + long estnnz2; + ExecutorService pool = CommonThreadPool.get(_numThreads); + try { + 
estnnz2 = computeCSVSize(splits, path, rlen, clen, estnnz, pool); + } + catch(Exception e) { + throw new IOException("Thread pool Error " + e.getMessage(), e); + } + finally { + pool.shutdown(); + } + + _streamSparse = MatrixBlock.evalSparseFormatInMemory(_rLen, _cLen, estnnz2); + + // stream CSV into blen x blen blocks + try { + BlockBuffer buffer = new BlockBuffer(outStream, _streamSparse); + readCSVMatrixFromHDFS(splits, path, null, buffer); + buffer.flushRemaining(); + } + finally { + outStream.closeInput(); + } + return null; + } + @Override public MatrixBlock readMatrixFromInputStream(InputStream is, long rlen, long clen, int blen, long estnnz) throws IOException, DMLRuntimeException { @@ -119,7 +173,8 @@ public MatrixBlock readMatrixFromInputStream(InputStream is, long rlen, long cle return new ReaderTextCSV(_props).readMatrixFromInputStream(is, rlen, clen, blen, estnnz); } - private void readCSVMatrixFromHDFS(InputSplit[] splits, Path path, MatrixBlock dest) throws IOException { + private void readCSVMatrixFromHDFS(InputSplit[] splits, Path path, MatrixBlock dest, BlockBuffer streamBuffer) + throws IOException { FileInputFormat.addInputPath(_job, path); TextInputFormat informat = new TextInputFormat(); @@ -131,17 +186,19 @@ private void readCSVMatrixFromHDFS(InputSplit[] splits, Path path, MatrixBlock d // create read tasks for all splits ArrayList> tasks = new ArrayList<>(); int splitCount = 0; + final boolean sparseOut = (streamBuffer != null) ? 
streamBuffer.isSparseBlocks() : + dest.isInSparseFormat(); for(InputSplit split : splits) { - if(dest.isInSparseFormat() && _props.getNAStrings() != null) - tasks.add(new CSVReadSparseNanTask(split, informat, dest, splitCount++)); - else if(dest.isInSparseFormat() && _props.getFillValue() == 0) - tasks.add(new CSVReadSparseNoNanTaskAndFill(split, informat, dest, splitCount++)); - else if(dest.isInSparseFormat()) - tasks.add(new CSVReadSparseNoNanTask(split, informat, dest, splitCount++)); + if(sparseOut && _props.getNAStrings() != null) + tasks.add(new CSVReadSparseNanTask(split, informat, dest, splitCount++, streamBuffer)); + else if(sparseOut && _props.getFillValue() == 0) + tasks.add(new CSVReadSparseNoNanTaskAndFill(split, informat, dest, splitCount++, streamBuffer)); + else if(sparseOut) + tasks.add(new CSVReadSparseNoNanTask(split, informat, dest, splitCount++, streamBuffer)); else if(_props.getNAStrings() != null) - tasks.add(new CSVReadDenseNanTask(split, informat, dest, splitCount++)); + tasks.add(new CSVReadDenseNanTask(split, informat, dest, splitCount++, streamBuffer)); else - tasks.add(new CSVReadDenseNoNanTask(split, informat, dest, splitCount++)); + tasks.add(new CSVReadDenseNoNanTask(split, informat, dest, splitCount++, streamBuffer)); } // check return codes and aggregate nnz @@ -149,7 +206,8 @@ else if(_props.getNAStrings() != null) for(Future rt : pool.invokeAll(tasks)) lnnz += rt.get(); - dest.setNonZeros(lnnz); + if(dest != null) + dest.setNonZeros(lnnz); } catch(Exception e) { throw new IOException("Thread pool issue, while parallel read.", e); @@ -159,6 +217,7 @@ else if(_props.getNAStrings() != null) } } + private MatrixBlock computeCSVSizeAndCreateOutputMatrixBlock(InputSplit[] splits, Path path, long rlen, long clen, int blen, long estnnz) throws IOException, DMLRuntimeException { _rLen = 0; @@ -172,10 +231,29 @@ private MatrixBlock computeCSVSizeAndCreateOutputMatrixBlock(InputSplit[] splits Future ret = (rlen<0 || clen<0 || estnnz<0) ? 
null : pool.submit(() -> createOutputMatrixBlock(rlen, clen, blen, estnnz, true, true)); + long estnnz2 = computeCSVSize(splits, path, rlen, clen, estnnz, pool); + return (ret!=null) ? UtilFunctions.getSafe(ret) : + createOutputMatrixBlock(_rLen, _cLen, blen, estnnz2, true, true); + } + catch(Exception e) { + throw new IOException("Thread pool Error " + e.getMessage(), e); + } + finally{ + pool.shutdown(); + } + } + + private long computeCSVSize(InputSplit[] splits, + Path path, long rlen, long clen, long estnnz, ExecutorService pool) throws IOException { + _rLen = 0; + _cLen = 0; + + // count rows in parallel per split + try { FileInputFormat.addInputPath(_job, path); TextInputFormat informat = new TextInputFormat(); informat.configure(_job); - + // count number of entities in the first non-header row LongWritable key = new LongWritable(); Text oneLine = new Text(); @@ -196,7 +274,7 @@ private MatrixBlock computeCSVSizeAndCreateOutputMatrixBlock(InputSplit[] splits tasks.add(new CountRowsTask(split, informat, _job, hasHeader)); hasHeader = false; } - + // collect row counts for offset computation // early error notify in case not all tasks successful _offsets = new SplitOffsetInfos(tasks.size()); @@ -208,7 +286,7 @@ private MatrixBlock computeCSVSizeAndCreateOutputMatrixBlock(InputSplit[] splits _rLen = _rLen + lnrow; i++; } - + // robustness for wrong dimensions which are already compiled into the plan if((rlen != -1 && _rLen != rlen) || (clen != -1 && _cLen != clen)) { @@ -229,8 +307,7 @@ private MatrixBlock computeCSVSizeAndCreateOutputMatrixBlock(InputSplit[] splits // allocate target matrix block based on given size; // need to allocate sparse as well since lock-free insert into target long estnnz2 = (estnnz < 0) ? (long) _rLen * _cLen : estnnz; - return (ret!=null) ? 
UtilFunctions.getSafe(ret) : - createOutputMatrixBlock(_rLen, _cLen, blen, estnnz2, true, true); + return estnnz2; } catch(Exception e) { throw new IOException("Thread pool Error " + e.getMessage(), e); @@ -271,16 +348,19 @@ private abstract class CSVReadTask implements Callable { protected final InputSplit _split; protected final TextInputFormat _informat; protected final MatrixBlock _dest; + protected final BlockBuffer _streamBuffer; protected final boolean _isFirstSplit; protected final int _splitCount; protected int _row = 0; protected int _col = 0; - public CSVReadTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount) { + public CSVReadTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount, + BlockBuffer buffer) { _split = split; _informat = informat; _dest = dest; + _streamBuffer = buffer; _isFirstSplit = (splitCount == 0); _splitCount = splitCount; } @@ -335,24 +415,184 @@ protected void verifyRows(Text value) throws IOException { + value); } } + + protected void finishRow(int row) { + if(_streamBuffer != null) + _streamBuffer.finishRow(row); + } + } + + private interface RowWriter { + void set(int col, double value); + } + + private static class DenseRowWriter implements RowWriter { + private final double[] _vals; + private final int _pos; + + public DenseRowWriter(DenseBlock block, int row) { + _vals = block.values(row); + _pos = block.pos(row); + } + + @Override + public void set(int col, double value) { + _vals[_pos + col] = value; + } + } + + private static class SparseRowWriter implements RowWriter { + private final SparseRow _row; + + public SparseRowWriter(SparseBlock block, int row) { + block.allocate(row); + _row = block.get(row); + } + + @Override + public void set(int col, double value) { + _row.append(col, value); + } + } + + private class BlockBuffer { + private final OOCStream _stream; + private final boolean _sparseBlocks; + private final int _numBlockCols; + private final 
ConcurrentHashMap _states = new ConcurrentHashMap<>(); + + public BlockBuffer(OOCStream stream, boolean sparseBlocks) { + _stream = stream; + _sparseBlocks = sparseBlocks; + _numBlockCols = Math.max(1, (int) Math.ceil((double) _cLen / _bLen)); + } + + public boolean isSparseBlocks() { + return _sparseBlocks; + } + + public RowWriter getRowWriter(int row) { + int brow = row / _bLen; + BlockRowState state = _states.computeIfAbsent(brow, BlockRowState::new); + return state.createRowWriter(row % _bLen); + } + + public void finishRow(int row) { + int brow = row / _bLen; + BlockRowState state = _states.get(brow); + if(state != null && state.finishRow()) { + if(_states.remove(brow, state)) + state.flush(brow); + } + } + + public void flushRemaining() { + for(Map.Entry entry : _states.entrySet()) { + if(_states.remove(entry.getKey(), entry.getValue())) + entry.getValue().flush(entry.getKey()); + } + } + + private class StreamRowWriter implements RowWriter { + private final BlockRowState _state; + private final int _rowInBlock; + + public StreamRowWriter(BlockRowState state, int rowInBlock) { + _state = state; + _rowInBlock = rowInBlock; + } + + @Override + public void set(int col, double value) { + if(value == 0) + return; + int bcol = col / _bLen; + MatrixBlock block = _state.getOrCreateBlock(bcol); + int localCol = col % _bLen; + if(_sparseBlocks) { + SparseBlock sb = block.getSparseBlock(); + sb.allocate(_rowInBlock); + sb.get(_rowInBlock).append(localCol, value); + } + else { + DenseBlock db = block.getDenseBlock(); + double[] vals = db.values(_rowInBlock); + int pos = db.pos(_rowInBlock); + vals[pos + localCol] = value; + } + } + } + + private class BlockRowState { + private final MatrixBlock[] _blocks; + private final int _rowsInBlock; + private final AtomicInteger _rowsCompleted = new AtomicInteger(); + + public BlockRowState(int brow) { + _blocks = new MatrixBlock[_numBlockCols]; + _rowsInBlock = Math.min(_bLen, _rLen - brow * _bLen); + } + + public RowWriter 
createRowWriter(int rowInBlock) { + return new StreamRowWriter(this, rowInBlock); + } + + public boolean finishRow() { + return _rowsCompleted.incrementAndGet() == _rowsInBlock; + } + + public void flush(int brow) { + for(int bci = 0; bci < _blocks.length; bci++) { + MatrixBlock block = _blocks[bci]; + if(block == null) + continue; + block.recomputeNonZeros(); + if(block.getNonZeros() == 0) + continue; + block.examSparsity(); + MatrixIndexes idx = new MatrixIndexes(brow + 1, bci + 1); + _stream.enqueue(new IndexedMatrixValue(idx, block)); + } + } + + private MatrixBlock getOrCreateBlock(int bcol) { + MatrixBlock block = _blocks[bcol]; + if(block == null) { + synchronized(this) { + block = _blocks[bcol]; + if(block == null) { + int cols = Math.min(_bLen, _cLen - bcol * _bLen); + block = new MatrixBlock(_rowsInBlock, cols, _sparseBlocks); + if(_sparseBlocks) + block.allocateSparseRowsBlock(); + else + block.allocateDenseBlock(); + _blocks[bcol] = block; + } + } + } + return block; + } + } } private class CSVReadDenseNoNanTask extends CSVReadTask { - public CSVReadDenseNoNanTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount) { - super(split, informat, dest, splitCount); + public CSVReadDenseNoNanTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount, + BlockBuffer buffer) { + super(split, informat, dest, splitCount, buffer); } protected long parse(RecordReader reader, LongWritable key, Text value) throws IOException { - DenseBlock a = _dest.getDenseBlock(); + DenseBlock a = (_streamBuffer == null) ? _dest.getDenseBlock() : null; double cellValue = 0; long nnz = 0; boolean noFillEmpty = false; while(reader.next(key, value)) { // foreach line final String cellStr = value.toString().trim(); - double[] avals = a.values(_row); - int apos = a.pos(_row); + RowWriter rowWriter = (_streamBuffer != null) ? + _streamBuffer.getRowWriter(_row) : new DenseRowWriter(a, _row); final String[] parts = _cLen == 1 ? 
null : IOUtilFunctions.split(cellStr, _props.getDelim()); @@ -365,14 +605,14 @@ protected long parse(RecordReader reader, LongWritable key, else { cellValue = Double.parseDouble(part); } - if(cellValue != 0) { - avals[apos+j] = cellValue; + rowWriter.set(j, cellValue); + if(cellValue != 0) nnz++; - } } // sanity checks (number of columns, fill values) IOUtilFunctions.checkAndRaiseErrorCSVEmptyField(cellStr, _props.isFill(), noFillEmpty); IOUtilFunctions.checkAndRaiseErrorCSVNumColumns(_split, cellStr, parts, _cLen); + finishRow(_row); _row++; } @@ -383,20 +623,21 @@ protected long parse(RecordReader reader, LongWritable key, private class CSVReadDenseNanTask extends CSVReadTask { - public CSVReadDenseNanTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount) { - super(split, informat, dest, splitCount); + public CSVReadDenseNanTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount, + BlockBuffer buffer) { + super(split, informat, dest, splitCount, buffer); } protected long parse(RecordReader reader, LongWritable key, Text value) throws IOException { - DenseBlock a = _dest.getDenseBlock(); + DenseBlock a = (_streamBuffer == null) ? _dest.getDenseBlock() : null; double cellValue = 0; boolean noFillEmpty = false; long nnz = 0; while(reader.next(key, value)) { // foreach line String cellStr = value.toString().trim(); String[] parts = IOUtilFunctions.split(cellStr, _props.getDelim()); - double[] avals = a.values(_row); - int apos = a.pos(_row); + RowWriter rowWriter = (_streamBuffer != null) ? 
+ _streamBuffer.getRowWriter(_row) : new DenseRowWriter(a, _row); for(int j = 0; j < _cLen; j++) { // foreach cell String part = parts[j].trim(); if(part.isEmpty()) { @@ -406,14 +647,14 @@ protected long parse(RecordReader reader, LongWritable key, else cellValue = UtilFunctions.parseToDouble(part, _props.getNAStrings()); - if(cellValue != 0) { - avals[apos+j] = cellValue; + rowWriter.set(j, cellValue); + if(cellValue != 0) nnz++; - } } // sanity checks (number of columns, fill values) IOUtilFunctions.checkAndRaiseErrorCSVEmptyField(cellStr, _props.isFill(), noFillEmpty); IOUtilFunctions.checkAndRaiseErrorCSVNumColumns(_split, cellStr, parts, _cLen); + finishRow(_row); _row++; } return nnz; @@ -422,23 +663,23 @@ protected long parse(RecordReader reader, LongWritable key, private class CSVReadSparseNanTask extends CSVReadTask { - public CSVReadSparseNanTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount) { - super(split, informat, dest, splitCount); + public CSVReadSparseNanTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount, + BlockBuffer buffer) { + super(split, informat, dest, splitCount, buffer); } protected long parse(RecordReader reader, LongWritable key, Text value) throws IOException { boolean noFillEmpty = false; double cellValue = 0; - final SparseBlock sb = _dest.getSparseBlock(); + final SparseBlock sb = (_streamBuffer == null) ? _dest.getSparseBlock() : null; long nnz = 0; while(reader.next(key, value)) { final String cellStr = value.toString().trim(); final String[] parts = IOUtilFunctions.split(cellStr, _props.getDelim()); _col = 0; - sb.allocate(_row); - SparseRow r = sb.get(_row); - + RowWriter rowWriter = (_streamBuffer != null) ? 
+ _streamBuffer.getRowWriter(_row) : new SparseRowWriter(sb, _row); for(String part : parts) { part = part.trim(); if(part.isEmpty()) { @@ -450,7 +691,7 @@ protected long parse(RecordReader reader, LongWritable key, } if(cellValue != 0) { - r.append(_col, cellValue); + rowWriter.set(_col, cellValue); nnz++; } _col++; @@ -460,6 +701,7 @@ protected long parse(RecordReader reader, LongWritable key, IOUtilFunctions.checkAndRaiseErrorCSVEmptyField(cellStr, _props.isFill(), noFillEmpty); IOUtilFunctions.checkAndRaiseErrorCSVNumColumns(_split, cellStr, parts, _cLen); + finishRow(_row); _row++; } return nnz; @@ -467,12 +709,13 @@ protected long parse(RecordReader reader, LongWritable key, } private class CSVReadSparseNoNanTask extends CSVReadTask { - public CSVReadSparseNoNanTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount) { - super(split, informat, dest, splitCount); + public CSVReadSparseNoNanTask(InputSplit split, TextInputFormat informat, MatrixBlock dest, int splitCount, + BlockBuffer buffer) { + super(split, informat, dest, splitCount, buffer); } protected long parse(RecordReader reader, LongWritable key, Text value) throws IOException { - final SparseBlock sb = _dest.getSparseBlock(); + final SparseBlock sb = (_streamBuffer == null) ? _dest.getSparseBlock() : null; long nnz = 0; double cellValue = 0; boolean noFillEmpty = false; @@ -480,8 +723,8 @@ protected long parse(RecordReader reader, LongWritable key, _col = 0; final String cellStr = value.toString().trim(); final String[] parts = IOUtilFunctions.split(cellStr, _props.getDelim()); - sb.allocate(_row); - SparseRow r = sb.get(_row); + RowWriter rowWriter = (_streamBuffer != null) ? 
+ _streamBuffer.getRowWriter(_row) : new SparseRowWriter(sb, _row); for(String part : parts) { part = part.trim(); if(part.isEmpty()) { @@ -493,7 +736,7 @@ protected long parse(RecordReader reader, LongWritable key, } if(cellValue != 0) { - r.append(_col, cellValue); + rowWriter.set(_col, cellValue); nnz++; } _col++; @@ -503,6 +746,7 @@ protected long parse(RecordReader reader, LongWritable key, IOUtilFunctions.checkAndRaiseErrorCSVEmptyField(cellStr, _props.isFill(), noFillEmpty); IOUtilFunctions.checkAndRaiseErrorCSVNumColumns(_split, cellStr, parts, _cLen); + finishRow(_row); _row++; } return nnz; @@ -511,25 +755,25 @@ protected long parse(RecordReader reader, LongWritable key, private class CSVReadSparseNoNanTaskAndFill extends CSVReadTask { public CSVReadSparseNoNanTaskAndFill(InputSplit split, TextInputFormat informat, MatrixBlock dest, - int splitCount) { - super(split, informat, dest, splitCount); + int splitCount, BlockBuffer buffer) { + super(split, informat, dest, splitCount, buffer); } protected long parse(RecordReader reader, LongWritable key, Text value) throws IOException { - final SparseBlock sb = _dest.getSparseBlock(); + final SparseBlock sb = (_streamBuffer == null) ? _dest.getSparseBlock() : null; long nnz = 0; double cellValue = 0; while(reader.next(key, value)) { _col = 0; final String cellStr = value.toString().trim(); final String[] parts = IOUtilFunctions.split(cellStr, _props.getDelim()); - sb.allocate(_row); - SparseRow r = sb.get(_row); + RowWriter rowWriter = (_streamBuffer != null) ? 
+ _streamBuffer.getRowWriter(_row) : new SparseRowWriter(sb, _row); for(String part : parts) { if(!part.isEmpty()) { cellValue = Double.parseDouble(part); if(cellValue != 0) { - r.append(_col, cellValue); + rowWriter.set(_col, cellValue); nnz++; } } @@ -538,6 +782,7 @@ protected long parse(RecordReader reader, LongWritable key, IOUtilFunctions.checkAndRaiseErrorCSVNumColumns(_split, cellStr, parts, _cLen); + finishRow(_row); _row++; } return nnz; diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CSVReaderTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CSVReaderTest.java new file mode 100644 index 00000000000..5f5f7fb42f6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/CSVReaderTest.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.IOException; + +public class CSVReaderTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "CSVReader"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CSVReaderTest.class.getSimpleName() + "/"; + private final static double eps = 1e-8; + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME = "res"; + + private final static int maxVal = 7; + private final static double sparsity1 = 0.65; + private final static double sparsity2 = 0.05; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testCSVReaderDense1() { + runCSVReaderTest(false, 1800, 1100); + } + + @Test + public void testCSVReaderSparse1() { + runCSVReaderTest(true, 1800, 1100); + } + + @Test + public void testCSVReaderDenseWide() { + runCSVReaderTest(false, 50, 12100); + } + + @Test + public void testCSVReaderSparseWide() { + runCSVReaderTest(true, 500, 50000); + } + + @Test + public void testCSVReaderDenseUltraWide() { + runCSVReaderTest(false, 50, 200000); + } + + @Test + public void testCSVReaderDenseLarge() { + runCSVReaderTest(false, 750, 50000); + } + + @Test + public void testCSVReaderSparseLarge() { + 
runCSVReaderTest(true, 500, 50000); + } + + @Test + public void testCSVReaderDenseLarge2() { + runCSVReaderTest(false, 1200, 25000); + } + + private void runCSVReaderTest(boolean sparse, int rows, int cols) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as a double array + double[][] A_data = getRandomMatrix(rows, cols, 1, maxVal, sparse ? sparsity2 : sparsity1, 7); + + // 2. Convert the double array to a MatrixBlock + MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data); + + // 3. Create a CSV matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.CSV); + + // 4. Write matrix A as a CSV file together with its metadata (.mtd) file + writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, 1000, A_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.CSV); + + runTest(true, false, null, -1); + + //compare results + + // rerun without ooc flag + programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME), + output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/CSVReader.dml 
b/src/test/scripts/functions/ooc/CSVReader.dml new file mode 100644 index 00000000000..12e5b02cd0c --- /dev/null +++ b/src/test/scripts/functions/ooc/CSVReader.dml @@ -0,0 +1,23 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); + +write(A, $2, format="binary");