diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 1fc582924e4..1541f16c96d 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -607,17 +607,23 @@ public double getAsNaNDouble(int i) { private static double getAsDouble(String s) { try { - return DoubleArray.parseDouble(s); } catch(Exception e) { - String ls = s.toLowerCase(); - if(ls.equals("true") || ls.equals("t")) + // fallback for boolean-like tokens, without allocating a lower-cased copy + final int len = s.length(); + if(len == 1) { + final char c = s.charAt(0); + if(c == 't' || c == 'T') + return 1; + else if(c == 'f' || c == 'F') + return 0; + } + else if(len == 4 && s.compareToIgnoreCase("true") == 0) return 1; - else if(ls.equals("false") || ls.equals("f")) + else if(len == 5 && s.compareToIgnoreCase("false") == 0) return 0; - else - throw new DMLRuntimeException("Unable to change to double: " + s, e); + throw new DMLRuntimeException("Unable to change to double: " + s, e); } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index 724af1be630..1f731fc3aa5 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -23,13 +23,22 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.runtime.util.UtilFunctions; /** * Base class for all transform decoders providing both a row and block @@ -43,11 +52,61 @@ public abstract class Decoder implements Externalizable{ protected ValueType[] _schema; protected int[] _colList; protected String[] _colnames = null; + // dummycoded columns that were feature-hashed: domain size K is read from the meta cell, not + // numDistinct. Only used during initMetaData (driver side), so not serialized. + protected transient int[] _dcHashCols = null; + protected Decoder(ValueType[] schema, int[] colList) { _schema = schema; _colList = colList; } + protected boolean isHashCol(int colID) { + return ArrayUtils.contains(_dcHashCols, colID); + } + + /** + * Domain size of a dummycoded source column: the hash domain K from the meta cell for + * feature-hashed columns, otherwise the column's {@code numDistinct} (0 when unset). + * + * @param meta transform meta frame + * @param colID 1-based column id of the dummycoded source column + * @param isHash whether the column was feature-hashed + * @return the domain size, never negative + */ + protected static int getNumDummycodeDistinct(FrameBlock meta, int colID, boolean isHash) { + if(isHash) { + Object o = meta.get(0, colID - 1); + return (o == null) ? 0 : (int) UtilFunctions.parseToLong(o.toString()); + } + ColumnMetadata d = meta.getColumnMetadata()[colID - 1]; + int ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + return Math.max(ndist, 0); + } + + /** + * Maps output column ids ({@code _colList}) to source positions in the encoded matrix, shifting past the column + * expansion of any dummycoded columns that precede them. Returns {@code _colList} directly when none apply. + */ + protected int[] buildSrcCols(FrameBlock meta, int[] dcCols) { + if(dcCols == null || dcCols.length == 0) + return _colList; + int[] srcCols = new int[_colList.length]; + int ix1 = 0, ix2 = 0, off = 0; + while(ix1 < _colList.length) { + if(ix2 >= dcCols.length || _colList[ix1] < dcCols[ix2]) { + srcCols[ix1] = _colList[ix1] + off; + ix1++; + } + else { // skip past the dummycode expansion + int dcCol = dcCols[ix2]; + off += getNumDummycodeDistinct(meta, dcCol, isHashCol(dcCol)) - 1; + ix2++; + } + } + return srcCols; + } + public ValueType[] getSchema() { return _schema; } @@ -77,8 +136,35 @@ public String[] getColnames() { * @param k Parallelization degree * @return returns the given output frame block for convenience */ - public FrameBlock decode(MatrixBlock in, FrameBlock out, int k) { - return decode(in, out); + public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + if(k <= 1) + return decode(in, out); + final ExecutorService pool = CommonThreadPool.get(k); + out.ensureAllocatedColumns(in.getNumRows()); + try { + final List> tasks = new ArrayList<>(); + int blz = Math.max((in.getNumRows() + k) / k, 1000); + + for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + tasks.add(pool.submit(() -> decode(in, out, start, end))); + } + + for(Future f : tasks) + f.get(); + return out; + } + catch(InterruptedException e) { + Thread.currentThread().interrupt(); + throw new DMLRuntimeException("Parallel decode interrupted", e); + } + catch(ExecutionException e) { + throw new DMLRuntimeException("Parallel decode failed", e); + } + finally { + pool.shutdown(); + } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index edee095f612..a286c03dce8 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -41,8 +41,9 @@ public class DecoderBin extends Decoder { private static final long serialVersionUID = -3784249774608228805L; - // a) column bin boundaries - private int[] _numBins; + // dummycoded source columns and the resulting output->source column mapping + private int[] _dcCols = null; + private int[] _srcCols = null; private double[][] _binMins = null; private double[][] _binMaxs = null; @@ -50,8 +51,10 @@ public DecoderBin() { super(null, null); } - protected DecoderBin(ValueType[] schema, int[] binCols) { + protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols, int[] hashCols) { super(schema, binCols); + _dcCols = dcCols; + _dcHashCols = hashCols; } @Override @@ -66,14 +69,19 @@ public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { for( int i=rl; i< ru; i++ ) { for( int j=0; j<_colList.length; j++ ) { final Array a = out.getColumn(_colList[j] - 1); - final double val = in.get(i, _colList[j] - 1); + final double val = in.get(i, _srcCols[j] - 1); if(!Double.isNaN(val)){ final int key = (int) Math.round(val); - double bmin = _binMins[j][key - 1]; - double bmax = _binMaxs[j][key - 1]; - double oval = bmin + (bmax - bmin) / 2 // bin center - + (val - key) * (bmax - bmin); // bin fractions - a.set(i, oval); + if(key == 0){ + a.set(i, _binMins[j][key]); + } + else{ + double bmin = _binMins[j][key - 1]; + double bmax = _binMaxs[j][key - 1]; + double oval = bmin + (bmax - bmin) / 2 // bin center + + (val - key) * (bmax - bmin); // bin fractions + a.set(i, oval); + } } else a.set(i, val); // NaN @@ -90,7 +98,6 @@ public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { @Override public void initMetaData(FrameBlock meta) { //initialize bin boundaries - _numBins = new int[_colList.length]; _binMins = new double[_colList.length][]; _binMaxs = new double[_colList.length][]; @@ -111,34 +118,52 @@ public void initMetaData(FrameBlock meta) { _binMaxs[j][i] = Double.parseDouble(parts[1]); } } + + _srcCols = buildSrcCols(meta, _dcCols); } @Override public void writeExternal(ObjectOutput out) throws IOException { super.writeExternal(out); + // bin boundaries; the per-column bin count is the length of the boundary arrays for( int i=0; i<_colList.length; i++ ) { - int len = _numBins[i]; + int len = _binMins[i].length; out.writeInt(len); for(int j=0; j> tasks = new ArrayList<>(); - int blz = Math.max(in.getNumRows() / k, 1000); - // Parallelize over row blocks (not over decoders): all decoders must - // run in order within a block, e.g. recode-on-output depends on the - // category indexes produced by the preceding dummycode decoder. - for(int i = 0; i < in.getNumRows(); i += blz){ - final int start = i; - final int end = Math.min(in.getNumRows(), i + blz); - tasks.add(pool.submit(() -> decode(in, out, start, end))); - } - for(Future f : tasks) - f.get(); - return out; - } - catch(Exception e) { - throw new RuntimeException(e); - } - finally { - pool.shutdown(); - } - } - @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru){ for( Decoder decoder : _decoders ) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index 0c4c6b42690..ee1a33c49fd 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -27,31 +27,33 @@ import java.util.List; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.UtilFunctions; /** - * Simple atomic decoder for dummycoded columns. This decoder builds internally - * inverted column mappings from the given frame meta data. - * + * Simple atomic decoder for dummycoded columns. This decoder builds internally inverted column mappings from the given + * frame meta data. + * */ -public class DecoderDummycode extends Decoder -{ +public class DecoderDummycode extends Decoder { private static final long serialVersionUID = 4758831042891032129L; - + private int[] _clPos = null; private int[] _cuPos = null; - + protected DecoderDummycode(ValueType[] schema, int[] dcCols) { - //dcCols refers to column IDs in output (non-dc) + this(schema, dcCols, null); + } + + protected DecoderDummycode(ValueType[] schema, int[] dcCols, int[] hashCols) { + // dcCols refers to column IDs in output (non-dc) super(schema, dcCols); + _dcHashCols = hashCols; } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { - //TODO perf (exploit sparse representation for better asymptotic behavior) out.ensureAllocatedColumns(in.getNumRows()); decode(in, out, 0, in.getNumRows()); return out; @@ -59,59 +61,97 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { - //TODO perf (exploit sparse representation for better asymptotic behavior) - // out.ensureAllocatedColumns(in.getNumRows()); - for( int i=rl; i= low && aix[h] < high) { + int k = aix[h]; + int col = _colList[j] - 1; + out.getColumn(col).set(i, k - low + 1); + } + // limit the binary search. + apos = h; + } + + } + @Override public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { List dcList = new ArrayList<>(); List clPosList = new ArrayList<>(); List cuPosList = new ArrayList<>(); - + // get the column IDs for the sub range of the dummycode columns and their destination positions, // where they will be decoded to - for( int j=0; j<_colList.length; j++ ) { + for(int j = 0; j < _colList.length; j++) { int colID = _colList[j]; - if (colID >= colStart && colID < colEnd) { + if(colID >= colStart && colID < colEnd) { dcList.add(colID - (colStart - 1)); clPosList.add(_clPos[j] - dummycodedOffset); cuPosList.add(_cuPos[j] - dummycodedOffset); } } - if (dcList.isEmpty()) + if(dcList.isEmpty()) return null; // create sub-range decoder int[] colList = dcList.stream().mapToInt(i -> i).toArray(); - DecoderDummycode subRangeDecoder = new DecoderDummycode( - Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), colList); + DecoderDummycode subRangeDecoder = new DecoderDummycode(Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), + colList); subRangeDecoder._clPos = clPosList.stream().mapToInt(i -> i).toArray(); subRangeDecoder._cuPos = cuPosList.stream().mapToInt(i -> i).toArray(); return subRangeDecoder; } - + @Override public void updateIndexRanges(long[] beginDims, long[] endDims) { if(_colList == null) return; - + long lowerColDest = beginDims[1]; long upperColDest = endDims[1]; for(int i = 0; i < _colList.length; i++) { long numDistinct = _cuPos[i] - _clPos[i]; - + if(_cuPos[i] <= beginDims[1] + 1) if(numDistinct > 0) lowerColDest -= numDistinct - 1; - + if(_cuPos[i] <= endDims[1] + 1) if(numDistinct > 0) upperColDest -= numDistinct - 1; @@ -119,16 +159,16 @@ public void updateIndexRanges(long[] beginDims, long[] endDims) { beginDims[1] = lowerColDest; endDims[1] = upperColDest; } - + @Override public void initMetaData(FrameBlock meta) { - _clPos = new int[_colList.length]; //col lower pos - _cuPos = new int[_colList.length]; //col upper pos - for( int j=0, off=0; j<_colList.length; j++ ) { + _clPos = new int[_colList.length]; // col lower pos + _cuPos = new int[_colList.length]; // col upper pos + for(int j = 0, off = 0; j < _colList.length; j++) { int colID = _colList[j]; - ColumnMetadata d = meta.getColumnMetadata()[colID-1]; - int ndist = d.isDefault() ? 0 : (int)d.getNumDistinct(); - ndist = ndist < -1 ? 0: ndist; + // hash columns store the domain size K in the meta cell; others use numDistinct + int ndist = getNumDummycodeDistinct(meta, colID, isHashCol(colID)); + _clPos[j] = off + colID; _cuPos[j] = _clPos[j] + ndist; off += ndist - 1; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java index 0a400e6da92..8f6c45d63e8 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java @@ -64,41 +64,67 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] try { //parse transform specification JSONObject jSpec = new JSONObject(spec); - List ldecoders = new ArrayList<>(); - //create decoders 'bin', 'recode', 'dummy' and 'pass-through' + //create decoders 'bin', 'recode', 'hash', 'dummy', and 'pass-through' List binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); List rcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol))); + List hcIDs = Arrays.asList(ArrayUtils.toObject( + TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol))); List dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); + // only specially treat the columns with both recode and dictionary rcIDs = unionDistinct(rcIDs, dcIDs); + // hashing is a lossy, one-way transform with no inverse recode map, so hash columns + // are never recode-decoded; exclude them from the recode set + rcIDs = except(rcIDs, hcIDs); + + // dummycoded hash columns: domain size K lives in the meta cell, so the decoders + // need to know which dummycoded columns to read it from + List hcdcIDs = new ArrayList<>(dcIDs); + hcdcIDs.retainAll(hcIDs); + int[] hashCols = ArrayUtils.toPrimitive(hcdcIDs.toArray(new Integer[0])); + int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); - List ptIDs = except(except(UtilFunctions.getSeqList(1, len, 1), rcIDs), binIDs); - + + // set the remaining columns to passthrough. + List ptIDs = UtilFunctions.getSeqList(1, len, 1); + // except recoded columns + ptIDs = except(ptIDs, rcIDs); + // binned columns + ptIDs = except(ptIDs, binIDs); + // dummycoded columns (incl. dummycoded hash) are rebuilt by the dummycode decoder; + // hash columns without dummycode stay in passthrough so their bucket code survives + ptIDs = except(ptIDs, dcIDs); + //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { schema = UtilFunctions.nCopies(len, ValueType.STRING); for( Integer col : ptIDs ) schema[col-1] = ValueType.FP64; } + + // collect all the decoders in one list. + List ldecoders = new ArrayList<>(); if( !binIDs.isEmpty() ) { ldecoders.add(new DecoderBin(schema, - ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])), + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])), hashCols)); } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, - ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])), hashCols)); } if( !rcIDs.isEmpty() ) { + // recode on output (after dummycode rebuilds the categorical columns) when dummycoding is present ldecoders.add(new DecoderRecode(schema, !dcIDs.isEmpty(), ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])))); } if( !ptIDs.isEmpty() ) { ldecoders.add(new DecoderPassThrough(schema, ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])), - ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])), hashCols)); } //create composite decoder of all created decoders @@ -121,6 +147,8 @@ else if( decoder instanceof DecoderRecode ) return DecoderType.Recode.ordinal(); else if( decoder instanceof DecoderPassThrough ) return DecoderType.PassThrough.ordinal(); + else if( decoder instanceof DecoderBin ) + return DecoderType.Bin.ordinal(); throw new DMLRuntimeException("Unsupported decoder type: " + decoder.getClass().getCanonicalName()); } @@ -130,6 +158,7 @@ public static Decoder createInstance(int type) { // create instance switch(dtype) { + case Bin: return new DecoderBin(); case Dummycode: return new DecoderDummycode(null, null); case PassThrough: return new DecoderPassThrough(null, null, null); case Recode: return new DecoderRecode(null, false, null); diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index 5b6bf7a093e..9b134601419 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -28,9 +28,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.UtilFunctions; /** * Simple atomic decoder for passing through numeric columns to the output. @@ -45,8 +43,13 @@ public class DecoderPassThrough extends Decoder private int[] _srcCols = null; protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols) { + this(schema, ptCols, dcCols, null); + } + + protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols, int[] hashCols) { super(schema, ptCols); _dcCols = dcCols; + _dcHashCols = hashCols; } public DecoderPassThrough() { super(null, null); } @@ -61,13 +64,12 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { int clen = Math.min(_colList.length, out.getNumColumns()); - for( int i=rl; i 0 ) { - //prepare source column id mapping w/ dummy coding - _srcCols = new int[_colList.length]; - int ix1 = 0, ix2 = 0, off = 0; - while( ix1<_colList.length ) { - if( ix2>=_dcCols.length || _colList[ix1] < _dcCols[ix2] ) { - _srcCols[ix1] = _colList[ix1] + off; - ix1 ++; - } - else { //_colList[ix1] > _dcCols[ix2] - ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; - off += d.isDefault() ? -1 : d.getNumDistinct() - 1; - ix2 ++; - } - } - } - else { - //prepare direct source column mapping - _srcCols = _colList; - } + _srcCols = buildSrcCols(meta, _dcCols); } @Override @@ -134,8 +117,8 @@ public void writeExternal(ObjectOutput os) for(int i = 0; i < _srcCols.length; i++) os.writeInt(_srcCols[i]); - os.writeInt(_dcCols.length); - for(int i = 0; i < _dcCols.length; i++) + os.writeInt(_dcCols == null ? 0 : _dcCols.length); + for(int i = 0; _dcCols != null && i < _dcCols.length; i++) os.writeInt(_dcCols[i]); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 33459a1c4f9..11dd2c7faa5 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -29,6 +29,7 @@ import java.util.Map.Entry; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.Pair; @@ -46,7 +47,6 @@ public class DecoderRecode extends Decoder private static final long serialVersionUID = -3784249774608228805L; private HashMap[] _rcMaps = null; - private Object[][] _rcMapsDirect = null; private boolean _onOut = false; public DecoderRecode() { @@ -59,8 +59,7 @@ protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { } public Object getRcMapValue(int i, long key) { - return (_rcMapsDirect != null && key > 0) ? - _rcMapsDirect[i][(int)key-1] : _rcMaps[i].get(key); + return _rcMaps[i].get(key); } @Override @@ -125,31 +124,26 @@ public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { public void initMetaData(FrameBlock meta) { //initialize recode maps according to schema _rcMaps = new HashMap[_colList.length]; - long[] max = new long[_colList.length]; for( int j=0; j<_colList.length; j++ ) { HashMap map = new HashMap<>(); for( int i=0; i v < Integer.MAX_VALUE) ) { - _rcMapsDirect = new Object[_rcMaps.length][]; - for( int i=0; i<_rcMaps.length; i++ ) { - Object[] arr = new Object[(int)max[i]]; - for(Entry e1 : _rcMaps[i].entrySet()) - arr[e1.getKey().intValue()-1] = e1.getValue(); - _rcMapsDirect[i] = arr; - } - } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 400b7f64ffc..cd9a583d60f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -146,7 +146,9 @@ public FrameBlock getMetaData(FrameBlock meta) { return meta; meta.ensureAllocatedColumns(1); + // store the hash domain size K in the single meta cell meta.set(0, _colID - 1, String.valueOf(_K)); + return meta; } diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java index df386d4659d..5119cbadd65 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java @@ -2859,4 +2859,29 @@ public void stringArrayGetDoubleNaN(){ assertTrue(Double.isNaN(s.getAsNaNDouble(i))); } } + + @Test + public void stringArrayGetDoubleBooleanTokens() { + // non-numeric boolean-like tokens fall back to 1/0 (case insensitive, single char or full word) + String[] truthy = new String[] {"true", "True", "TRUE", "t", "T"}; + String[] falsy = new String[] {"false", "False", "FALSE", "f", "F"}; + Array t = ArrayFactory.create(truthy); + for(int i = 0; i < t.size(); i++) + assertEquals(1.0, t.getAsDouble(i), 0.0); + Array f = ArrayFactory.create(falsy); + for(int i = 0; i < f.size(); i++) + assertEquals(0.0, f.getAsDouble(i), 0.0); + } + + @Test(expected = DMLRuntimeException.class) + public void stringArrayGetDoubleInvalidThrows() { + // a token that is neither numeric nor a boolean word/char must throw + ArrayFactory.create(new String[] {"notabool"}).getAsDouble(0); + } + + @Test(expected = DMLRuntimeException.class) + public void stringArrayGetDoubleAmbiguousLengthThrows() { + // length matches neither 1, 4, nor 5 boolean tokens -> reject + ArrayFactory.create(new String[] {"tru"}).getAsDouble(0); + } } diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java new file mode 100644 index 00000000000..b2d31f43b83 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java @@ -0,0 +1,590 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.concurrent.CountDownLatch; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +/** + * Exact inverse correctness tests for the transform decoders. Recode and dummycode are lossless category encodings, so a + * decode of the encoded matrix must reconstruct the original categorical frame. These tests assert exact reconstruction + * for the dense path, the sparse path, and the parallel path so that the dummycode sparse binary search and the parallel + * block split are validated against ground truth rather than only against each other. + */ +public class TransformDecodeRoundTripTest { + protected static final Log LOG = LogFactory.getLog(TransformDecodeRoundTripTest.class.getName()); + + @Before + public void setUp() { + // name must contain "main" so the parallel decode path reuses the shared thread pool + Thread.currentThread().setName("main_test_decode"); + } + + private static FrameBlock categoricalFrame() { + final String[] values = new String[] { + "apple", "banana", "apple", "cherry", "banana", "date", "apple", "cherry", "date", "banana", "elderberry", + "apple", "fig", "banana", "cherry", "apple", "date", "fig", "elderberry", "banana"}; + final FrameBlock f = new FrameBlock(new ValueType[] {ValueType.STRING}); + f.ensureAllocatedColumns(values.length); + for(int i = 0; i < values.length; i++) + f.set(i, 0, values[i]); + return f; + } + + @Test + public void recodeReconstructsOriginalDense() { + roundTrip("{ids:true, recode:[1]}", false, 1); + } + + @Test + public void recodeReconstructsOriginalSparse() { + roundTrip("{ids:true, recode:[1]}", true, 1); + } + + @Test + public void recodeReconstructsOriginalParallel() { + roundTrip("{ids:true, recode:[1]}", false, 4); + } + + @Test + public void dummycodeReconstructsOriginalDense() { + roundTrip("{ids:true, recode:[1], dummycode:[1]}", false, 1); + } + + @Test + public void dummycodeReconstructsOriginalSparse() { + // the one-hot encoded matrix is sparse, so this drives the dummycode sparse binary-search decode path + roundTrip("{ids:true, recode:[1], dummycode:[1]}", true, 1); + } + + @Test + public void dummycodeReconstructsOriginalParallel() { + roundTrip("{ids:true, recode:[1], dummycode:[1]}", false, 4); + } + + /** + * Binning a column while a different column is dummycoded shifts the bin column's source position in the encoded + * matrix. The bin decoder must rebuild that source-column mapping from the dummycode domain sizes. This asserts the + * dense, sparse, and parallel decode paths agree for that layout (bin output is lossy, so exact reconstruction is + * not asserted, only cross-mode consistency and dimensions). + */ + @Test + public void binWithDummycodeOnOtherColumnConsistency() { + // bin column (1) precedes the dummycode column (2): the bin decoder takes the direct + // source-column path because no expanded column sits before it + final FrameBlock original = TestUtils.generateRandomFrameBlock(150, + new ValueType[] {ValueType.FP32, ValueType.UINT4, ValueType.UINT8}, 4242); + binConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[2]}", original); + } + + /** + * Dummycode on an earlier column (1) shifts the bin column (2) to the right in the encoded matrix. The bin decoder + * must walk the dummycode domain sizes to recover the bin column's true source position. This drives the + * non-magic offset branch of the bin source-column mapping. + */ + @Test + public void binAfterDummycodeOnEarlierColumnConsistency() { + final FrameBlock original = TestUtils.generateRandomFrameBlock(150, + new ValueType[] {ValueType.UINT4, ValueType.FP32, ValueType.UINT8}, 4242); + binConsistency("{ids:true, recode:[1], dummycode:[1], bin:[{id:2, method:equi-width, numbins:4}]}", original); + } + + /** + * Same right-shift as above, but the earlier column is feature-hashed before being dummycoded. The hash domain + * size K is stored as a plain integer in the single meta cell, so the bin source-column mapping reads it (instead + * of numDistinct) to compute the offset. + */ + @Test + public void binAfterHashDummycodeOnEarlierColumnConsistency() { + final FrameBlock original = TestUtils.generateRandomFrameBlock(150, + new ValueType[] {ValueType.UINT4, ValueType.FP32, ValueType.UINT8}, 4242); + binConsistency("{ids:true, hash:[1], K:6, dummycode:[1], bin:[{id:2, method:equi-width, numbins:4}]}", + original); + } + + /** + * Encode then decode the dense, parallel and sparse paths and assert they agree. Bin output is lossy, so only + * cross-mode consistency and row count are asserted (not exact reconstruction). + */ + private void binConsistency(String spec, FrameBlock original) { + try { + final String[] colnames = original.getColumnNames(); + + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + + final MatrixBlock dense = new MatrixBlock(); + dense.copy(encoded); + if(dense.isInSparseFormat()) + dense.sparseToDense(); + + final MatrixBlock sparse = new MatrixBlock(); + sparse.copy(encoded); + if(!sparse.isInSparseFormat()) + sparse.denseToSparse(); + + final FrameBlock reference = decodeOnce(spec, colnames, meta, dense, 1); + final FrameBlock parallel = decodeOnce(spec, colnames, meta, dense, 4); + final FrameBlock fromSparse = decodeOnce(spec, colnames, meta, sparse, 1); + + org.junit.Assert.assertEquals(original.getNumRows(), reference.getNumRows()); + TestUtils.compareFrames(reference, parallel, false); + TestUtils.compareFrames(reference, fromSparse, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + /** + * The bin encoder always emits codes >= 1, but the decoder defensively handles a 0 code by mapping it to the + * first bin's lower boundary. Inject a 0 into an otherwise validly encoded matrix to exercise that branch. + */ + @Test + public void binDecodeZeroCodeUsesFirstBinBoundary() { + final String spec = "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}"; + try { + final FrameBlock original = TestUtils.generateRandomFrameBlock(50, new ValueType[] {ValueType.FP32}, 13); + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + if(encoded.isInSparseFormat()) + encoded.sparseToDense(); + final FrameBlock meta = encoder.getMetaData(null); + + encoded.set(0, 0, 0); // force a 0 bin code + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock decoded = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), 1); + + final double first = Double.parseDouble(decoded.get(0, 0).toString()); + final double second = Double.parseDouble(decoded.get(1, 0).toString()); + // the 0-coded row decodes to the first bin lower bound, which is <= any properly binned center + org.junit.Assert.assertTrue("0-code must map to the lowest bin boundary", first <= second); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + /** + * Spark broadcasts the decoder to executors via Java serialization without re-running initMetaData, so the + * decoder must round-trip all of its decode state through writeExternal/readExternal. Decode with a freshly + * deserialized decoder and assert it matches the in-memory decode. Covers plain bin and bin-with-dummycode + * (the latter exercises the serialized _srcCols/_dcCols source-column mapping). + */ + @Test + public void binDecoderSurvivesSerialization() { + final FrameBlock original = TestUtils.generateRandomFrameBlock(80, new ValueType[] {ValueType.FP32}, 21); + serializeRoundTrip("{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}", original); + } + + @Test + public void binWithDummycodeDecoderSurvivesSerialization() { + final FrameBlock original = TestUtils.generateRandomFrameBlock(80, + new ValueType[] {ValueType.UINT4, ValueType.FP32}, 21); + serializeRoundTrip("{ids:true, recode:[1], dummycode:[1], bin:[{id:2, method:equi-width, numbins:4}]}", + original); + } + + private void serializeRoundTrip(String spec, FrameBlock original) { + try { + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + if(encoded.isInSparseFormat()) + encoded.sparseToDense(); + final FrameBlock meta = encoder.getMetaData(null); + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock expected = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), 1); + + final Decoder restored = serializeDeserialize(decoder); + final FrameBlock actual = restored.decode(encoded, new FrameBlock(restored.getSchema()), 1); + + TestUtils.compareFrames(expected, actual, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + private static Decoder serializeDeserialize(Decoder decoder) throws Exception { + final ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try(ObjectOutputStream oos = new ObjectOutputStream(bos)) { + oos.writeObject(decoder); + } + try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) { + return (Decoder) ois.readObject(); + } + } + + /** + * Feature hashing is non-invertible, so the decode contract for a hash column that is NOT dummycoded is that the + * encoded bucket code passes through unchanged. Regression test: a hash-only column must not be dropped from the + * decoded frame (it previously was, because hash columns were excluded from passthrough). + */ + @Test + public void hashWithoutDummycodeDecodesToBucketCode() { + final String spec = "{ids:true, hash:[1], K:8}"; + try { + final FrameBlock original = categoricalFrame(); + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + if(encoded.isInSparseFormat()) + encoded.sparseToDense(); + final FrameBlock meta = encoder.getMetaData(null); + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock decoded = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), 1); + + org.junit.Assert.assertEquals(1, decoded.getNumColumns()); + for(int i = 0; i < original.getNumRows(); i++) { + final Object v = decoded.get(i, 0); + org.junit.Assert.assertNotNull("hash column must survive decode at row " + i, v); + org.junit.Assert.assertEquals("hash bucket code must pass through at row " + i, encoded.get(i, 0), + Double.parseDouble(v.toString()), 0.0); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + /** + * A corrupt recode meta entry (no token/code separator) must surface as a {@link DMLRuntimeException} during + * meta-data initialization rather than a raw parsing exception, so callers get an actionable error. Covers the + * defensive try/catch added around the recode-map reconstruction. + */ + @Test + public void recodeInitMetaDataRejectsCorruptEntry() { + final String spec = "{ids:true, recode:[1]}"; + try { + final FrameBlock original = categoricalFrame(); + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + // overwrite the first recode entry with a value lacking the token/code separator + meta.set(0, 0, "corrupt-entry-without-separator"); + + try { + DecoderFactory.createDecoder(spec, colnames, null, meta, original.getNumColumns()); + fail("expected a corrupt recode entry to be rejected"); + } + catch(DMLRuntimeException expected) { + assertTrue("error should identify the recode map reinitialization, got: " + messageChain(expected), + messageChain(expected).contains("recode map")); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + /** + * Federated transform-decode slices a global decoder per worker via {@link Decoder#updateIndexRanges} and + * {@link Decoder#subRangeDecoder}. For a single worker covering the whole matrix, the dummycode expansion must + * collapse the encoded column count down to the decoded column count, and the resulting sub-range decoder must + * reproduce the global decode exactly. Exercises the dummycode index-range and sub-range mapping. + */ + @Test + public void dummycodeSubRangeFullRangeMatchesGlobalDecode() { + final String spec = "{ids:true, recode:[1], dummycode:[1]}"; + try { + final FrameBlock original = TestUtils.generateRandomFrameBlock(60, + new ValueType[] {ValueType.UINT4, ValueType.FP32}, 91); + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + if(encoded.isInSparseFormat()) + encoded.sparseToDense(); + final FrameBlock meta = encoder.getMetaData(null); + + final Decoder global = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock full = global.decode(encoded, new FrameBlock(global.getSchema()), 1); + + // single worker covering the whole matrix: map encoded column range to decoded column range + final long[] beginDims = {0, 0}; + final long[] endDims = {encoded.getNumRows(), encoded.getNumColumns()}; + global.updateIndexRanges(beginDims, endDims); + + org.junit.Assert.assertEquals("begin column must stay at 0", 0, beginDims[1]); + org.junit.Assert.assertEquals("dummycode expansion must collapse to the decoded column count", + full.getNumColumns(), (int) endDims[1]); + + final Decoder sub = global.subRangeDecoder(1, (int) endDims[1] + 1, 0); + final FrameBlock subDecoded = sub.decode(encoded, new FrameBlock(sub.getSchema()), 1); + TestUtils.compareFrames(full, subDecoded, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + /** + * A federated worker holding only the columns after a dummycoded column must shift its index range left by the + * dummycode expansion and receive a sub-range decoder containing just the trailing pass-through columns (the + * dummycode and recode decoders drop out). Mirrors the {@code updateIndexRanges} + {@code subRangeDecoder} call + * sequence in federated transform-decode, covering the index-range shift for a fully-preceding dummycode column and + * the empty sub-range branch. + */ + @Test + public void dummycodeSubRangeExcludingDummycodedColumnKeepsRemaining() { + final String spec = "{ids:true, recode:[1], dummycode:[1]}"; + try { + final FrameBlock original = TestUtils.generateRandomFrameBlock(40, + new ValueType[] {ValueType.UINT4, ValueType.FP32, ValueType.FP32}, 73); + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + if(encoded.isInSparseFormat()) + encoded.sparseToDense(); + final FrameBlock meta = encoder.getMetaData(null); + + final Decoder global = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + + // the dummycode column expands to (encodedCols - 2) one-hot columns; a worker owning only the two trailing + // pass-through columns starts after that expanded block in encoded column space + final int dcWidth = encoded.getNumColumns() - 2; + final long[] beginDims = {0, dcWidth}; + final long[] endDims = {encoded.getNumRows(), dcWidth + 2}; + final int colStartBefore = (int) beginDims[1]; + global.updateIndexRanges(beginDims, endDims); + + // after collapsing the preceding dummycode expansion, the worker maps to decoded columns 2..3 + org.junit.Assert.assertEquals(1, beginDims[1]); + org.junit.Assert.assertEquals(3, endDims[1]); + + final Decoder sub = global.subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore); + org.junit.Assert.assertNotNull("pass-through columns must still yield a decoder", sub); + org.junit.Assert.assertEquals("only the two trailing pass-through columns remain", 2, + sub.getSchema().length); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + /** + * Two recode columns with different domain sizes leave trailing empty (null) cells in the shorter column's + * recode-map column. Reconstructing that map must stop at the first null rather than read past it. Recode is + * lossless, so the decode must reconstruct the original frame exactly. + */ + @Test + public void recodeMultiColumnWithTrailingNullMapEntries() { + final String spec = "{ids:true, recode:[1, 2]}"; + try { + final FrameBlock original = new FrameBlock(new ValueType[] {ValueType.STRING, ValueType.STRING}); + final String[] high = {"a", "b", "c", "d", "e", "f", "g", "h"}; + final String[] low = {"x", "y"}; + final int n = 16; + original.ensureAllocatedColumns(n); + for(int i = 0; i < n; i++) { + original.set(i, 0, high[i % high.length]); + original.set(i, 1, low[i % low.length]); + } + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + if(encoded.isInSparseFormat()) + encoded.sparseToDense(); + final FrameBlock meta = encoder.getMetaData(null); + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock decoded = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), 1); + TestUtils.compareFrames(original, decoded, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + /** + * The parallel decode path runs per-row-block decode tasks on a thread pool; a failure inside a worker must not be + * swallowed but resurface as an unchecked exception to the caller. Feeding a matrix with far fewer columns than the + * decoder expects forces an out-of-range access in a worker, which the parallel wrapper must propagate. + */ + @Test + public void parallelDecodeWrapsWorkerException() { + final String spec = "{ids:true, recode:[1], dummycode:[1]}"; + try { + final FrameBlock original = categoricalFrame(); + final String[] colnames = original.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + + // far fewer columns than the dummycode decoder reads -> a parallel worker accesses out of range + final MatrixBlock broken = new MatrixBlock(2, 1, false); + broken.allocateDenseBlock(); + try { + decoder.decode(broken, new FrameBlock(decoder.getSchema()), 4); + fail("expected the parallel decode wrapper to propagate the worker failure"); + } + catch(DMLRuntimeException expected) { + assertNotNull("parallel decode wrapper must retain the worker exception as cause", + expected.getCause()); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + /** + * Interrupting a worker mid parallel-decode must restore the caller's interrupt flag (which {@code Future.get} + * clears when it throws) and surface the failure as a {@link DMLRuntimeException}. The same minimal decoder also + * exercises the base sub-range contract, which rejects decoders that do not implement column sub-ranging. + */ + @Test + public void parallelDecodeInterruptionRestoresFlagAndRejectsSubRange() { + final Thread caller = Thread.currentThread(); + final CountDownLatch release = new CountDownLatch(1); + final Decoder decoder = new Decoder(new ValueType[] {ValueType.FP64}, new int[] {1}) { + private static final long serialVersionUID = 1L; + + @Override + public FrameBlock decode(MatrixBlock in, FrameBlock out) { + return out; + } + + @Override + public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { + // interrupt the thread blocked in Future.get and stay unfinished so the interrupt is observed + caller.interrupt(); + try { + release.await(); + } + catch(InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + @Override + public void initMetaData(FrameBlock meta) { + // no meta data needed + } + }; + + try { + decoder.subRangeDecoder(1, 2, 0); + fail("a decoder without sub-range support must reject the request"); + } + catch(DMLRuntimeException expected) { + assertTrue(messageChain(expected).contains("sub-range")); + } + + final MatrixBlock in = new MatrixBlock(1, 1, false); + in.allocateDenseBlock(); + try { + decoder.decode(in, new FrameBlock(decoder.getSchema()), 2); + fail("an interrupted parallel decode must throw"); + } + catch(DMLRuntimeException expected) { + assertTrue("the interrupt flag must be restored", Thread.currentThread().isInterrupted()); + assertNotNull(expected.getCause()); + } + finally { + release.countDown(); + Thread.interrupted(); // clear so the interrupt does not leak into other tests + } + } + + private static String messageChain(Throwable t) { + final StringBuilder sb = new StringBuilder(); + for(Throwable c = t; c != null; c = c.getCause()) + sb.append(c.getMessage()).append('\n'); + return sb.toString(); + } + + private static FrameBlock decodeOnce(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); + return decoder.decode(in, new FrameBlock(decoder.getSchema()), k); + } + + private void roundTrip(String spec, boolean sparse, int k) { + try { + final FrameBlock original = categoricalFrame(); + final String[] colnames = original.getColumnNames(); + + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + MatrixBlock encoded = encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + + if(sparse && !encoded.isInSparseFormat()) + encoded.denseToSparse(); + else if(!sparse && encoded.isInSparseFormat()) + encoded.sparseToDense(); + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock decoded = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), k); + + TestUtils.compareFrames(original, decoded, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " (sparse=" + sparse + ", k=" + k + ") : " + e.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java new file mode 100644 index 00000000000..54bd1679716 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.logging.Level; +import java.util.logging.Logger; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +/** + * Component tests for the transform decoders. These exercise the row-block and parallel decode paths, the sparse and + * dense dummycode decode paths, the binning source-column offset mapping, and feature-hash column handling end-to-end + * through an encode followed by decode round trip. + */ +@RunWith(value = Parameterized.class) +public class TransformDecodeTest { + protected static final Log LOG = LogFactory.getLog(TransformDecodeTest.class.getName()); + + private final FrameBlock data; + private final int k; + + public TransformDecodeTest(FrameBlock data, int k) { + // name must contain "main" so the parallel decode path reuses the shared thread pool + Thread.currentThread().setName("main_test_decode"); + Logger.getLogger(CommonThreadPool.class.getName()).setLevel(Level.OFF); + this.data = data; + this.k = k; + } + + @Parameters + public static Collection data() { + final ArrayList tests = new ArrayList<>(); + final int[] threads = new int[] {1, 4}; + try { + final FrameBlock[] blocks = new FrameBlock[] { + // single low-cardinality categorical column + TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231), + // single categorical column with nulls + TestUtils.generateRandomFrameBlock(64, new ValueType[] {ValueType.UINT4}, 99, 0.2), + // multi column: dummycode/bin on col1 must offset the trailing passthrough columns + TestUtils.generateRandomFrameBlock(120, + new ValueType[] {ValueType.UINT4, ValueType.UINT8, ValueType.FP32}, 17), + // large enough to split into multiple row blocks in the parallel decode path + TestUtils.generateRandomFrameBlock(2500, new ValueType[] {ValueType.UINT4}, 7)}; + + for(FrameBlock block : blocks) + for(int k : threads) + tests.add(new Object[] {block, k}); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + return tests; + } + + @Test + public void testPassThrough() { + decodeConsistency("{ids:true}"); + } + + @Test + public void testRecode() { + decodeConsistency("{ids:true, recode:[1]}"); + } + + @Test + public void testDummycode() { + decodeConsistency("{ids:true, recode:[1], dummycode:[1]}"); + } + + @Test + public void testBinWidth() { + decodeConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}"); + } + + @Test + public void testBinHeight() { + decodeConsistency("{ids:true, bin:[{id:1, method:equi-height, numbins:10}]}"); + } + + @Test + public void testBinSingleBin() { + // numbins:1 collapses every value into a single bin, exercising the degenerate boundary handling + decodeConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:1}]}"); + } + + @Test + public void testHashToDummy() { + // feature-hash columns store their domain size K as a plain integer in the single meta cell, which the + // dummycode decoder reads (instead of numDistinct) to reconstruct the one-hot column ranges + decodeConsistency("{ids:true, hash:[1], K:8, dummycode:[1]}"); + } + + @Test + public void testHashToDummyDomain1() { + decodeConsistency("{ids:true, hash:[1], K:1, dummycode:[1]}"); + } + + /** + * Encode the data, then decode the encoded matrix in three ways: serial dense, parallel dense, and serial sparse. + * All three must produce identical frames. This jointly exercises the parallel block-decode path in + * {@link Decoder#decode(MatrixBlock, FrameBlock, int)} and the separate sparse / dense dummycode decode paths. + */ + private void decodeConsistency(String spec) { + try { + final String[] colnames = data.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, data.getNumColumns(), null); + final MatrixBlock encoded = encoder.encode(data, 1); + final FrameBlock meta = encoder.getMetaData(null); + + final MatrixBlock dense = forceDense(encoded); + final MatrixBlock sparse = forceSparse(encoded); + + final FrameBlock reference = decode(spec, colnames, meta, dense, 1); + final FrameBlock parallel = decode(spec, colnames, meta, dense, k); + final FrameBlock fromSparse = decode(spec, colnames, meta, sparse, 1); + + assertEquals("decoded rows must match input rows", data.getNumRows(), reference.getNumRows()); + + TestUtils.compareFrames(reference, parallel, false); + TestUtils.compareFrames(reference, fromSparse, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + private static FrameBlock decode(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); + return decoder.decode(in, new FrameBlock(decoder.getSchema()), k); + } + + private static MatrixBlock forceDense(MatrixBlock in) { + final MatrixBlock out = new MatrixBlock(); + out.copy(in); + if(out.isInSparseFormat()) + out.sparseToDense(); + return out; + } + + private static MatrixBlock forceSparse(MatrixBlock in) { + final MatrixBlock out = new MatrixBlock(); + out.copy(in); + if(!out.isInSparseFormat()) + out.denseToSparse(); + return out; + } +}