diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 58e33a616ca..dae13ed9f94 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -65,6 +65,7 @@ import org.apache.sysds.runtime.compress.lib.CLALibRexpand; import org.apache.sysds.runtime.compress.lib.CLALibScalar; import org.apache.sysds.runtime.compress.lib.CLALibSlice; +import org.apache.sysds.runtime.compress.lib.CLALibSort; import org.apache.sysds.runtime.compress.lib.CLALibSquash; import org.apache.sysds.runtime.compress.lib.CLALibTSMM; import org.apache.sysds.runtime.compress.lib.CLALibTernaryOp; @@ -847,9 +848,8 @@ public CmCovObject covOperations(COVOperator op, MatrixBlock that, MatrixBlock w } @Override - public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { - MatrixBlock right = getUncompressed(weights); - return getUncompressed("sortOperations").sortOperations(right, result); + public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result, int k) { + return CLALibSort.sort(this, weights, result, k); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index f30cf8b17b2..354325e293b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -974,6 +974,16 @@ public AColGroup[] splitReshapePushDown(final int multiplier, final int nRow, fi return splitReshape(multiplier, nRow, nColOrg); } + /** + * Sort the values of the column group according to double comparison operations and return as another compressed + * group. + * + * This sorting assumes that the column group is sorted independently of everything else. 
+	 *
+	 * @return The sorted group
+	 */
+	public abstract AColGroup sort();
+
 	@Override
 	public String toString() {
 		StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
index 7d0b2469ec8..64f3f4fda07 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
@@ -769,4 +769,9 @@ public AColGroup removeEmptyRows(boolean[] selectV, int rOut) {
 	protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) {
 		return ColGroupConst.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols()));
 	}
+
+	@Override
+	public AColGroup sort() {
+		return this;
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
index 6ac1544e61e..b316e48474a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
@@ -1178,4 +1178,36 @@ public AColGroup convertToDeltaDDC() {
 	public AColGroup convertToDDCLZW() {
 		return ColGroupDDCLZW.create(_colIndexes, _dict, _data, null);
 	}
+
+	/**
+	 * Sort the values of this single-column group and return them as a new DDC group.
+	 *
+	 * The dictionary is reused unchanged; only the row-to-dictionary mapping is rebuilt so that every dictionary
+	 * entry covers one contiguous run of rows, in the order given by the dictionary's sort index.
+	 *
+	 * @return The sorted group
+	 * @throws org.apache.commons.lang3.NotImplementedException if the group covers more than one column
+	 */
+	@Override
+	public AColGroup sort() {
+		// The dictionary sort index orders individual values, which only lines up with the per-tuple counts for a
+		// single column. Reject wider groups the same way the SDC and uncompressed implementations do.
+		if(getNumCols() > 1)
+			throw new org.apache.commons.lang3.NotImplementedException();
+		// TODO restore support for run length encoding to exploit the runs
+
+		int[] counts = getCounts();
+		// get the sort index
+		int[] r = _dict.sort();
+
+		AMapToData m = MapToFactory.create(_data.size(), counts.length);
+		int off = 0;
+		for(int i = 0; i < counts.length; i++) {
+			for(int j = 0; j < counts[r[i]]; j++) {
+				m.set(off++, r[i]);
+			}
+		}
+
+		return ColGroupDDC.create(_colIndexes, _dict, m, counts);
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
index 6a4a92469d2..d8a8ed1ade6 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
@@ -571,4 +571,36 @@ public String toString() {
 		sb.append(Arrays.toString(_reference));
 		return sb.toString();
 	}
+
+	/**
+	 * Sort the values of this single-column group and return them as a new DDCFOR group.
+	 *
+	 * The dictionary and the reference are reused unchanged; only the row-to-dictionary mapping is rebuilt so that
+	 * every dictionary entry covers one contiguous run of rows, in the order given by the dictionary's sort index.
+	 *
+	 * @return The sorted group
+	 * @throws org.apache.commons.lang3.NotImplementedException if the group covers more than one column
+	 */
+	@Override
+	public AColGroup sort() {
+		// The dictionary sort index orders individual values, which only lines up with the per-tuple counts for a
+		// single column. Reject wider groups the same way the SDC and uncompressed implementations do.
+		if(getNumCols() > 1)
+			throw new org.apache.commons.lang3.NotImplementedException();
+		// TODO restore support for run length encoding.
+
+		int[] counts = getCounts();
+		// get the sort index
+		int[] r = _dict.sort();
+
+		AMapToData m = MapToFactory.create(_data.size(), counts.length);
+		int off = 0;
+		for(int i = 0; i < counts.length; i++) {
+			for(int j = 0; j < counts[r[i]]; j++) {
+				m.set(off++, r[i]);
+			}
+		}
+
+		return ColGroupDDCFOR.create(_colIndexes, _dict, m, counts, _reference);
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCLZW.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCLZW.java
index c820f875a05..1f3e5934288 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCLZW.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCLZW.java
@@ -1022,4 +1022,10 @@ protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs,
 		ColGroupDDC g = (ColGroupDDC) convertToDDC();
 		return g.removeEmptyColsSubset(newColumnIDs, selectedColumns);
 	}
+
+	@Override
+	public AColGroup sort() {
+		ColGroupDDC g = (ColGroupDDC) convertToDDC();
+		return g.sort();
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java
index 7c0a15e123b..64114a054ab 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java
@@ -488,4 +488,9 @@ public AColGroup removeEmptyRows(boolean[] selectV, int rOut){
protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ return new ColGroupEmpty(newColumnIDs); } + + @Override + public AColGroup sort() { + return this; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java index 5ac168b9406..fa8aa104ffb 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java @@ -750,4 +750,9 @@ public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); } + + @Override + public AColGroup sort() { + throw new NotImplementedException("Unimplemented method 'sort'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index 5833729c378..a251d828b5f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -741,4 +741,9 @@ public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); } + + @Override + public AColGroup sort() { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index c9fc920a845..347cea9c0da 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -1200,4 +1200,9 @@ public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); } + + @Override + public AColGroup sort() { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index 5522a33e3e0..faa5ca7fa27 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -903,4 +903,50 @@ public String toString() { sb.append(_data.toString()); return sb.toString(); } + + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. 
+ final double def = _defaultTuple[0]; + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= def) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDC.create(_colIndexes, _numRows, _dict, _defaultTuple, o, m, counts); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 2ef7f3012bc..815ecacf378 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -651,4 +651,49 @@ public String toString() { return sb.toString(); } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. 
+ int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCFOR.create(_colIndexes, _numRows, _dict, o, m, counts, _reference); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index 0f89e54d975..8a9f401c10c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -746,4 +746,24 @@ public String toString() { sb.append(_indexes.toString()); return sb.toString(); } + + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + + // Only a single non-default value exists, so sorting is a contiguous block of that value placed before the + // default values if it is smaller than the default, and after them otherwise. + final int[] counts = getCounts(); + final int nondefault = counts[0]; + final int defaultLength = _numRows - nondefault; + final int base = _dict.getValue(0, 0, 1) >= _defaultTuple[0] ? 
defaultLength : 0; + + final int[] offsets = new int[nondefault]; + for(int j = 0; j < nondefault; j++) + offsets[j] = base + j; + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCSingle.create(_colIndexes, _numRows, _dict, _defaultTuple, o, counts); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index d9341bb9ea8..26b3cc4ee37 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -1071,4 +1071,24 @@ public String toString() { sb.append(_indexes.toString()); return sb.toString(); } + + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + + // Only a single non-default value exists, so sorting is a contiguous block of that value placed before the + // zeros (default) if it is negative, and after the zeros otherwise. + final int[] counts = getCounts(); + final int nondefault = counts[0]; + final int defaultLength = _numRows - nondefault; + final int base = _dict.getValue(0, 0, 1) >= 0 ? 
defaultLength : 0; + + final int[] offsets = new int[nondefault]; + for(int j = 0; j < nondefault; j++) + offsets[j] = base + j; + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCSingleZeros.create(_colIndexes, _numRows, _dict, o, counts); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index 86cd9866a75..09f222bfeee 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -1113,4 +1113,49 @@ public String toString() { sb.append(_data); return sb.toString(); } + + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. 
+ int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCZeros.create(_colIndexes, _numRows, _dict, o, m, counts); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index e4e98da46f2..611add6480f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -56,6 +56,7 @@ import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.ReduceAll; import org.apache.sysds.runtime.functionobjects.ReduceRow; +import org.apache.sysds.runtime.functionobjects.SortIndex; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.cp.CmCovObject; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; @@ -65,6 +66,7 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.CMOperator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import 
org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; @@ -1331,4 +1333,14 @@ public String toString() { return sb.toString(); } + + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // sortOperations builds a value/weight table for quantiles; for an ascending column sort we reorder the rows. + MatrixBlock sorted = _data.reorgOperations(new ReorgOperator(new SortIndex(1, false, false), 1), + new MatrixBlock(), 0, 0, 0); + return create(sorted, _colIndexes); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index 0c8f07685b6..51e26a3f9d2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -293,4 +293,9 @@ public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); } + + @Override + public AColGroup sort() { + throw new NotImplementedException("Unimplemented method 'sort'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java index 17b382f06ad..a7e715b59b8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.compress.colgroup.dictionary; +import org.apache.commons.lang3.NotImplementedException; import 
org.apache.sysds.runtime.compress.DMLCompressionException;
 
 public abstract class AIdentityDictionary extends ACachingMBDictionary {
@@ -74,4 +75,9 @@ public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) {
 			ret[ret.length - 1] *= defaultTuple[i];
 		return ret;
 	}
+
+	@Override
+	public int[] sort(){
+		throw new NotImplementedException();
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
index c26de004373..9a0412145f0 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
@@ -142,4 +142,9 @@ public IDictionary clone() {
 	public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){
 		throw new NotImplementedException();
 	}
+
+	@Override
+	public int[] sort() {
+		throw new NotImplementedException();
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index 06bd811b50b..fd8dfd127db 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -1348,4 +1348,88 @@ public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) {
 		return getMBDict(nCol).sliceColumns(selectedColumns, nCol);
 	}
 
+	@Override
+	public int[] sort() {
+		return sort(_values);
+	}
+
+	/**
+	 * Compute the permutation that orders the given values ascending, leaving the values themselves untouched.
+	 *
+	 * @param values The values to order
+	 * @return The indexes into values, arranged such that the referenced values are ascending
+	 */
+	protected static int[] sort(double[] values) {
+		int[] indices = new int[values.length];
+		for(int i = 0; i < indices.length; i++) {
+			indices[i] = i;
+		}
+
+		// Zero or one value is already ordered. Return before the range stack is used: it holds values.length
+		// entries, while pushing the initial range alone needs two.
+		if(values.length < 2)
+			return indices;
+
+		// quicksort with stack
+		int[] stack = new int[values.length];
+
+		int top = -1;
+		stack[++top] = 0;
+		stack[++top] = values.length - 1;
+
+		while(top >= 0) {
+			int high = stack[top--];
+			int low = stack[top--];
+
+			if(low < high) {
+
+				int pivotIndex = partition(indices, values, low, high);
+				// Left side
+				if(pivotIndex - 1 > low) {
+					stack[++top] = low;
+					stack[++top] = pivotIndex - 1;
+				}
+
+				// Right side
+				if(pivotIndex + 1 < high) {
+					stack[++top] = pivotIndex + 1;
+					stack[++top] = high;
+				}
+			}
+		}
+
+		return indices;
+	}
+
+	/**
+	 * Partition indices[low..high] around the value referenced by indices[high].
+	 *
+	 * @param indices The index permutation to rearrange
+	 * @param values  The values the indexes refer to
+	 * @param low     The first position of the range
+	 * @param high    The last position of the range, holding the pivot
+	 * @return The final position of the pivot
+	 */
+	private static int partition(int[] indices, double[] values, int low, int high) {
+		double pivotValue = values[indices[high]];
+		int i = low - 1;
+
+		for(int j = low; j < high; j++) {
+			if(values[indices[j]] <= pivotValue) {
+				i++;
+				swap(indices, i, j);
+			}
+		}
+
+		swap(indices, i + 1, high);
+		return i + 1;
+	}
+
+	/** Exchange the elements at positions i and j of arr. */
+	private static void swap(int[] arr, int i, int j) {
+		int tmp = arr[i];
+		arr[i] = arr[j];
+		arr[j] = tmp;
+	}
+
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
index 726df96d5c8..c8ddfc4883a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
@@ -1062,4 +1062,13 @@ public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thi
 	 */
 	public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol);
 
+	/**
+	 * Sort the values of this dictionary via an index of how the values mapped previously.
+	 *
+	 * In practice this design means we can reuse the previous dictionary for the resulting column group
+	 *
+	 * @return The sorted index.
+ */ + public int[] sort(); + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index c1d2ecc5296..b77eacd2205 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -2845,4 +2845,13 @@ public static double[] sliceColumns(MatrixBlock mb, IntArrayList selectedColumns return ret; } + @Override + public int[] sort() { + if(_data.getNumColumns() > 1) + throw new RuntimeException("Not supported sort on multicolumn dictionaries"); + _data.sparseToDense(); + + return Dictionary.sort(_data.getDenseBlockValues()); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index 2d9075f73c9..c38af0be122 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -107,4 +107,9 @@ public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { throw new RuntimeException("Invalid call"); } + @Override + public int[] sort() { + throw new RuntimeException("Invalid call"); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index 30b9d806c1f..6912ee12525 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -23,6 +23,7 @@ import java.io.DataOutput; import java.io.IOException; +import 
org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -282,4 +283,9 @@ public MatrixBlockDictionary createMBDict(int nCol) { public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { return getMBDict().sliceColumns(selectedColumns, nCol); } + + @Override + public int[] sort() { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java index d587d26c3cb..5cfcf223213 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java @@ -32,6 +32,7 @@ import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; +import org.apache.sysds.runtime.functionobjects.SortIndex; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -65,12 +66,19 @@ else if(op.fn instanceof SwapIndex) { // the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025 return transpose(cmb, ret, op.getNumThreads()); } - else { - String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null; - MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads()); - warned = true; - return tmp.reorgOperations(op, ret, startRow, startColumn, length); + else if(op.fn instanceof SortIndex) { + // order: keep the result compressed when a single column / single group is sorted ascending. 
+ MatrixBlock res = CLALibSort.sort(cmb, (SortIndex) op.fn); + if(res != null) + return res; + // otherwise fall through to the decompression fallback below. } + + // Decompression fallback for reorg operations not supported directly on the compressed representation. + String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null; + MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads()); + warned = true; + return tmp.reorgOperations(op, ret, startRow, startColumn, length); } private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java new file mode 100644 index 00000000000..b94f11ae723 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.functionobjects.SortIndex;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+
+public final class CLALibSort {
+
+	private CLALibSort() {
+		// private constructor for utility class.
+	}
+
+	/**
+	 * Sort (order) a compressed matrix for the {@code order} built-in, while keeping the result compressed.
+	 *
+	 * The compressed fast-path only supports the case the user can benefit from: a single column held in a single column
+	 * group, sorted ascending and returning the sorted values (not the index permutation). For everything else (multiple
+	 * columns, multiple column groups, descending order, index return, or a column-group encoding without a sort
+	 * implementation) this returns {@code null} so the caller can fall back to a decompressed reorg.
+	 *
+	 * @param mb the compressed matrix to sort
+	 * @param fn the sort specification carried by the reorg operator
+	 * @return the sorted compressed matrix, or {@code null} if the compressed fast-path does not apply
+	 */
+	public static MatrixBlock sort(CompressedMatrixBlock mb, SortIndex fn) {
+		final boolean singleColumn = mb.getNumColumns() == 1 && mb.getColGroups().size() == 1;
+		if(!singleColumn || fn.getDecreasing() || fn.getIndexReturn())
+			return null;
+
+		final AColGroup sorted = sortSingleColumn(mb);
+		if(sorted == null)
+			return null;
+
+		final List<AColGroup> rg = new ArrayList<>(1);
+		rg.add(sorted);
+		return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, rg);
+	}
+
+	/**
+	 * Compute the sorted value/weight table used by the quantile/median/IQM operations (the {@code sort} / qsort lop),
+	 * exploiting compression to sort the few distinct values instead of all rows.
+	 *
+	 * The compressed fast-path applies to an unweighted sort of a single column held in a single column group. The
+	 * produced table is bit-for-bit identical to {@link MatrixBlock#sortOperations(MatrixValue, MatrixBlock, int)}: a
+	 * {@code (1 + nnz) x 2} matrix holding one row per non-zero value (weight 1) plus a single collapsed row for the
+	 * zeros (weight = number of zeros), sorted ascending by value. For every other case (weights present, multiple
+	 * columns or groups, or an encoding without a sort implementation) it falls back to a decompressed sort.
+	 *
+	 * @param mb      the compressed matrix to sort
+	 * @param weights optional per-row weights, or {@code null}
+	 * @param result  the result matrix (reused by the fallback)
+	 * @param k       the parallelization degree
+	 * @return the sorted value/weight table
+	 */
+	public static MatrixBlock sort(CompressedMatrixBlock mb, MatrixValue weights, MatrixBlock result, int k) {
+		final MatrixBlock w = CompressedMatrixBlock.getUncompressed(weights);
+		if(w == null && mb.getNumColumns() == 1 && mb.getColGroups().size() == 1) {
+			final MatrixBlock fast = sortTableSingleColumn(mb, result, k);
+			if(fast != null)
+				return fast;
+		}
+
+		// fallback to uncompressed sort.
+		return CompressedMatrixBlock.getUncompressed(mb, "sortOperations", k).sortOperations(w, result, k);
+	}
+
+	private static AColGroup sortSingleColumn(CompressedMatrixBlock mb) {
+		try {
+			return mb.getColGroups().get(0).sort();
+		}
+		catch(NotImplementedException e) {
+			// the column-group encoding does not implement sort -> let the caller decompress.
+			return null;
+		}
+	}
+
+	private static MatrixBlock sortTableSingleColumn(CompressedMatrixBlock mb, MatrixBlock result, int k) {
+		final long lnnz = mb.getNonZeros();
+		if(lnnz < 0) // unknown number of non-zeros, cannot size the table.
+			return null;
+
+		final AColGroup sorted = sortSingleColumn(mb);
+		if(sorted == null)
+			return null;
+
+		final int nRows = mb.getNumRows();
+		final int nnz = (int) lnnz;
+		final int zeroCount = nRows - nnz;
+
+		// decompress the already-sorted single column once (ascending, zeros contiguous).
+		final List<AColGroup> rg = new ArrayList<>(1);
+		rg.add(sorted);
+		final MatrixBlock sortedCol = new CompressedMatrixBlock(nRows, 1, lnnz, false, rg).decompress(k);
+
+		// build the value/weight table: one row per non-zero value (weight 1) plus a single
+		// collapsed zero row (weight = number of zeros). The row order is irrelevant because the
+		// table is sorted by the reorg below, exactly as MatrixBlock.sortOperations does.
+		final MatrixBlock tdw = new MatrixBlock(1 + nnz, 2, false);
+		tdw.allocateDenseBlock();
+		int w = 0;
+		for(int i = 0; i < nRows; i++) {
+			final double v = sortedCol.get(i, 0);
+			if(v != 0) {
+				tdw.set(w, 0, v);
+				tdw.set(w, 1, 1);
+				w++;
+			}
+		}
+		tdw.set(w, 0, 0); // collapsed zero row (weight 0 when the column is dense)
+		tdw.set(w, 1, zeroCount);
+
+		// Emit through the same reorg used by MatrixBlock.sortOperations so the produced table is
+		// bit-for-bit identical to the uncompressed path, including its (intentionally unmaintained)
+		// non-zero metadata. This keeps downstream quantile/median consumers and result comparisons
+		// consistent regardless of whether the input was compressed.
+		if(result == null)
+			result = new MatrixBlock(1 + nnz, 2, false);
+		else
+			result.reset(1 + nnz, 2, false);
+		final ReorgOperator rop = new ReorgOperator(new SortIndex(1, false, false), k);
+		LibMatrixReorg.reorg(tdw, result, rop);
+		return result;
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedSortTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedSortTest.java
new file mode 100644
index 00000000000..7ab7187eb78
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedSortTest.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.EnumSet;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
+import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
+import org.apache.sysds.runtime.functionobjects.SortIndex;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests the {@code order} (sort) reorg operation on compressed matrices. A single column held in a single column group
+ * is sorted ascending while staying compressed (via {@link org.apache.sysds.runtime.compress.lib.CLALibSort}); every
+ * other configuration falls back to a decompressed reorg. In all cases the result must match the uncompressed reference.
+ */
+public class CompressedSortTest {
+
+	private static final int ROWS = 1000;
+
+	private static final ReorgOperator ASC = new ReorgOperator(new SortIndex(1, false, false), 1);
+	private static final ReorgOperator DESC = new ReorgOperator(new SortIndex(1, true, false), 1);
+	private static final ReorgOperator INDEX = new ReorgOperator(new SortIndex(1, false, true), 1);
+
+	@Test
+	public void sortDDC() {
+		runCompressed(generate(ROWS, 1, 8, 1.0, 1, 50, 7), CompressionType.DDC);
+	}
+
+	@Test
+	public void sortDDCWithNegatives() {
+		runCompressed(generate(ROWS, 1, 10, 1.0, -25, 25, 13), CompressionType.DDC);
+	}
+
+	@Test
+	public void sortSDCZeros() {
+		runCompressed(generate(ROWS, 1, 6, 0.2, 1, 40, 23), CompressionType.SDC);
+	}
+
+	@Test
+	public void sortSDCWithNegatives() {
+		runCompressed(generate(ROWS, 1, 8, 0.3, -20, 20, 41), CompressionType.SDC);
+	}
+
+	@Test
+	public void sortSDCSingleValueZeros() {
+		// sparse with a single distinct non-zero value -> SDCSingleZeros
+		runCompressed(generate(ROWS, 1, 1, 0.25, 5, 5, 99), CompressionType.SDC);
+	}
+
+	@Test
+	public void sortSDCSingleNonZeroDefault() {
+		// two distinct non-zero values, one dominant default -> SDCSingle
+		runCompressed(twoValueColumn(3, 7), CompressionType.SDC);
+	}
+
+	@Test
+	public void sortSDCSingleNonZeroDefaultNegative() {
+		// dominant non-zero default with a single smaller (negative) value -> SDCSingle
+		runCompressed(twoValueColumn(-4, 7), CompressionType.SDC);
+	}
+
+	@Test
+	public void sortConst() {
+		MatrixBlock mb = new MatrixBlock(ROWS, 1, false);
+		for(int i = 0; i < ROWS; i++)
+			mb.set(i, 0, 3);
+		mb.recomputeNonZeros();
+		runCompressed(mb, CompressionType.CONST);
+	}
+
+	@Test
+	public void sortUncompressedColGroup() {
+		// a CompressedMatrixBlock holding a single uncompressed column group must also sort correctly
+		MatrixBlock raw = generate(ROWS, 1, ROWS, 1.0, -100000, 100000, 5);
+		List<AColGroup> groups = new ArrayList<>(1);
+		groups.add(ColGroupUncompressed.create(raw, ColIndexFactory.create(1)));
+		CompressedMatrixBlock cmb = new CompressedMatrixBlock(ROWS, 1, raw.getNonZeros(), false, groups);
+
+		MatrixBlock actual = cmb.reorgOperations(ASC, new MatrixBlock(), 0, 0, 0);
+		assertTrue("Expected the sorted result to stay compressed", actual instanceof CompressedMatrixBlock);
+		MatrixBlock expected = raw.reorgOperations(ASC, new MatrixBlock(), 0, 0, 0);
+		TestUtils.compareMatrices(expected, CompressedMatrixBlock.getUncompressed(actual, "sort"), 0.0,
+			"sort UNCOMPRESSED colgroup");
+	}
+
+	@Test
+	public void sortDescendingFallback() {
+		// descending order is not supported by the compressed fast-path -> decompress fallback
+		runFallback(generate(ROWS, 1, 8, 1.0, 1, 50, 7), CompressionType.DDC, DESC);
+	}
+
+	@Test
+	public void sortMultiColumnFallback() {
+		// order on a multi-column matrix sorts rows by the first column -> decompress fallback
+		runFallback(generate(ROWS, 3, 6, 1.0, 1, 30, 31), CompressionType.DDC, ASC);
+	}
+
+	@Test
+	public void sortIndexReturnFallback() {
+		// returning the sort permutation (index.return=TRUE) is not supported by the fast-path -> decompress fallback
+		runFallback(generate(ROWS, 1, 8, 1.0, 1, 50, 7), CompressionType.DDC, INDEX);
+	}
+
+	@Test
+	public void sortUnsupportedEncodingFallback() {
+		// OLE does not implement colgroup sort -> the fast-path declines and the order falls back to decompression
+		runFallback(generate(ROWS, 1, 8, 0.3, 1, 40, 23), CompressionType.OLE, ASC);
+	}
+
+	@Test
+	public void quantileTableDDC() {
+		runQuantile(generate(ROWS, 1, 8, 1.0, 1, 50, 7), CompressionType.DDC);
+	}
+
+	@Test
+	public void quantileTableDDCWithNegatives() {
+		runQuantile(generate(ROWS, 1, 10, 1.0, -25, 25, 13), CompressionType.DDC);
+	}
+
+	@Test
+	public void quantileTableSDCZeros() {
+		runQuantile(generate(ROWS, 1, 6, 0.2, 1, 40, 23), CompressionType.SDC);
+	}
+
+	@Test
+	public void quantileTableSDCWithNegatives() {
+		runQuantile(generate(ROWS, 1, 8, 0.3, -20, 20, 41), CompressionType.SDC);
+	}
+
+	@Test
+	public void quantileTableAllNegative() {
+		runQuantile(generate(ROWS, 1, 8, 0.4, -50, -1, 57), CompressionType.SDC);
+	}
+
+	@Test
+	public void quantileTableAllNegativeDense() {
+		// dense column with no zeros -> the collapsed zero row carries weight 0
+		runQuantile(generate(ROWS, 1, 8, 1.0, -50, -1, 57), CompressionType.DDC);
+	}
+
+	@Test
+	public void quantileTableConst() {
+		MatrixBlock mb = new MatrixBlock(ROWS, 1, false);
+		for(int i = 0; i < ROWS; i++)
+			mb.set(i, 0, 3);
+		mb.recomputeNonZeros();
+		runQuantile(mb, CompressionType.CONST);
+	}
+
+	@Test
+	public void quantileTableUnsupportedEncodingFallback() {
+		// OLE does not implement colgroup sort -> the quantile table is built via the decompressed fallback
+		runQuantile(generate(ROWS, 1, 8, 0.3, 1, 40, 23), CompressionType.OLE);
+	}
+
+	@Test
+	public void quantileWeightedFallback() {
+		MatrixBlock mb = generate(ROWS, 1, 8, 1.0, 1, 50, 7);
+		MatrixBlock weights = new MatrixBlock(ROWS, 1, false);
+		Random r = new Random(123);
+		for(int i = 0; i < ROWS; i++)
+			weights.set(i, 0, r.nextInt(4) + 1);
+		weights.recomputeNonZeros();
+		MatrixBlock expected = new MatrixBlock(mb).sortOperations(weights, new MatrixBlock(), 1);
+
+		CompressedMatrixBlock cmb = compress(mb, CompressionType.DDC);
+		MatrixBlock actual = cmb.sortOperations(weights, new MatrixBlock(), 1);
+
+		expected.recomputeNonZeros();
+		actual.recomputeNonZeros();
+		TestUtils.compareMatrices(expected, actual, 0.0, "weighted sortOperations fallback");
+	}
+
+	private void runQuantile(MatrixBlock mb, CompressionType ct) {
+		// reference is computed on a copy because compression may consume the input.
+		MatrixBlock expected = new MatrixBlock(mb).sortOperations(null, new MatrixBlock(), 1);
+
+		CompressedMatrixBlock cmb = compress(mb, ct);
+		assertEquals("Expected a single column group", 1, cmb.getColGroups().size());
+
+		MatrixBlock actual = cmb.sortOperations(null, new MatrixBlock(), 1);
+
+		// sortOperations leaves the non-zero count unmaintained; recompute so the value comparison reads the data.
+		expected.recomputeNonZeros();
+		actual.recomputeNonZeros();
+
+		// the value/weight table must match the uncompressed reference bit-for-bit ...
+		TestUtils.compareMatrices(expected, actual, 0.0, "sortOperations table " + ct);
+		// ... so the downstream median/quantile picks are identical.
+		assertEquals("median " + ct, expected.median(), actual.median(), 0.0);
+		assertEquals("q25 " + ct, expected.pickValue(0.25), actual.pickValue(0.25), 0.0);
+		assertEquals("q90 " + ct, expected.pickValue(0.90), actual.pickValue(0.90), 0.0);
+	}
+
+	private void runCompressed(MatrixBlock mb, CompressionType ct) {
+		// reference is computed before compressing because compression may consume the input.
+		MatrixBlock expected = mb.reorgOperations(ASC, new MatrixBlock(), 0, 0, 0);
+
+		CompressedMatrixBlock cmb = compress(mb, ct);
+		assertEquals("Expected a single column group", 1, cmb.getColGroups().size());
+		MatrixBlock actual = cmb.reorgOperations(ASC, new MatrixBlock(), 0, 0, 0);
+		assertTrue("Expected the sorted result to stay compressed for " + ct,
+			actual instanceof CompressedMatrixBlock);
+		TestUtils.compareMatrices(expected, CompressedMatrixBlock.getUncompressed(actual, "sort"), 0.0, "sort " + ct);
+	}
+
+	private void runFallback(MatrixBlock mb, CompressionType ct, ReorgOperator op) {
+		MatrixBlock expected = mb.reorgOperations(op, new MatrixBlock(), 0, 0, 0); // reference before compressing
+
+		CompressedMatrixBlock cmb = compress(mb, ct);
+		MatrixBlock actual = cmb.reorgOperations(op, new MatrixBlock(), 0, 0, 0);
+		TestUtils.compareMatrices(expected, CompressedMatrixBlock.getUncompressed(actual, "sort"), 0.0,
+			"sort fallback " + ct);
+	}
+
+	private static CompressedMatrixBlock compress(MatrixBlock mb, CompressionType ct) {
+		CompressionSettingsBuilder csb = new CompressionSettingsBuilder().setMinimumCompressionRatio(0.0)
+			.setValidCompressions(EnumSet.of(ct));
+		MatrixBlock compressed = CompressedMatrixBlockFactory.compress(mb, 1, csb).getLeft();
+		assertTrue("Expected the input to compress into a " + ct + " backed block",
+			compressed instanceof CompressedMatrixBlock);
+		return (CompressedMatrixBlock) compressed;
+	}
+
+	private static MatrixBlock twoValueColumn(int rare, int dominant) {
+		MatrixBlock mb = new MatrixBlock(ROWS, 1, false);
+		for(int i = 0; i < ROWS; i++)
+			mb.set(i, 0, i % 10 < 3 ? rare : dominant);
+		mb.recomputeNonZeros();
+		return mb;
+	}
+
+	private static MatrixBlock generate(int rows, int cols, int unique, double sparsity, int min, int max, int seed) {
+		final MatrixBlock mb = new MatrixBlock(rows, cols, false);
+		final Random pos = new Random(seed);
+		final Random val = new Random(seed * 31 + 1);
+		final double[] values = new double[Math.max(unique, 1)];
+		for(int i = 0; i < values.length; i++)
+			values[i] = min + (max > min ? val.nextInt(max - min + 1) : 0);
+		for(int i = 0; i < rows; i++)
+			for(int j = 0; j < cols; j++)
+				if(pos.nextDouble() < sparsity)
+					mb.set(i, j, values[pos.nextInt(values.length)]);
+		mb.recomputeNonZeros();
+		return mb;
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java
index 8aea861d9ee..b2165201c52 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java
@@ -29,9 +29,11 @@
 import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
 import org.apache.sysds.runtime.functionobjects.CM;
+import org.apache.sysds.runtime.functionobjects.SortIndex;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.CMOperator;
 import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.test.TestUtils;
 import org.apache.sysds.test.component.compress.TestConstants.MatrixTypology;
 import org.apache.sysds.test.component.compress.TestConstants.OverLapping;
@@ -147,6 +149,26 @@ public void testSortOperations() {
 		}
 	}
 
+	@Test
+	public void testSort() {
+		try {
+			if(!(cmb instanceof CompressedMatrixBlock) || cols != 1)
+				return; // Input was not compressed then just pass test
+
+			// order() builtin: sort the single column ascending (compressed fast-path or decompress fallback).
+			ReorgOperator op = new ReorgOperator(new SortIndex(1, false, false), _k);
+			MatrixBlock ret1 = mb.reorgOperations(op, new MatrixBlock(), 0, 0, 0);
+			MatrixBlock ret2 = cmb.reorgOperations(op, new MatrixBlock(), 0, 0, 0);
+
+			compareResultMatrices(ret1, ret2, 1);
+
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+			throw new RuntimeException(bufferedToString + "\n" + e.getMessage(), e);
+		}
+	}
+
 	@Test
 	public void testReExpandRow() {
 		// does not make much sense since it would entail the compression was on a matrix with one row.
diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java
index af21b14206a..81fdcb2a876 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java
@@ -481,6 +481,12 @@ protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList s
 			// TODO Auto-generated method stub
 			throw new UnsupportedOperationException("Unimplemented method 'removeEmptyColsSubset'");
 		}
+
+		@Override
+		public AColGroup sort() {
+			// sort is not implemented by this stub; calling it always throws.
+			throw new UnsupportedOperationException("Unimplemented method 'sort'");
+		}
 	}
 
 	private class FakeDictBasedColGroup extends ADictBasedColGroup {
@@ -802,5 +808,11 @@ protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList s
 			// TODO Auto-generated method stub
 			throw new UnsupportedOperationException("Unimplemented method 'removeEmptyColsSubset'");
 		}
+
+		@Override
+		public AColGroup sort() {
+			// sort is not implemented by this stub; calling it always throws.
+			throw new UnsupportedOperationException("Unimplemented method 'sort'");
+		}
 	}
 }
diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
index 3dd48636ae4..58f8bb2df32 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
@@ -365,6 +365,41 @@ public void sum3() {
 		assertEquals(as, bs, 0.0000001);
 	}
 
+	@Test
+	public void sort() {
+		if(nCol != 1)
+			return; // sort is only defined for single-column dictionaries.
+
+		final double[] sa = sortedValues(a);
+		final double[] sb = sortedValues(b);
+		if(sa == null && sb == null)
+			return; // neither encoding implements sort -> nothing to compare.
+
+		if(sa != null && sb != null) // if only one encoding implements sort there is no counterpart to compare with.
+			TestUtils.compareMatricesBitAvgDistance(sa, sb, 10, 10, "Sorted values differ between dictionaries");
+	}
+
+	/**
+	 * Reorders the dictionary values by {@link IDictionary#sort()} and asserts the permutation has one index per value
+	 * and yields a non-decreasing sequence. Returns {@code null} when the encoding does not implement sort.
+	 */
+	private double[] sortedValues(IDictionary d) {
+		final int[] perm;
+		try {
+			perm = d.sort();
+		}
+		catch(NotImplementedException e) {
+			return null; // encoding does not support sort.
+		}
+		assertEquals("sort must return one index per value", nRow, perm.length);
+		final double[] sorted = new double[perm.length];
+		for(int i = 0; i < perm.length; i++)
+			sorted[i] = d.getValue(perm[i], 0, 1);
+		for(int i = 1; i < sorted.length; i++)
+			assertTrue("sort did not produce a non-decreasing sequence", sorted[i - 1] <= sorted[i]);
+		return sorted;
+	}
+
 	@Test
 	public void getValues() {
 		try {