diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index e08f731e829..58e33a616ca 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -58,6 +58,7 @@ import org.apache.sysds.runtime.compress.lib.CLALibMMChain; import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; import org.apache.sysds.runtime.compress.lib.CLALibMerge; +import org.apache.sysds.runtime.compress.lib.CLALibRemoveEmpty; import org.apache.sysds.runtime.compress.lib.CLALibReplace; import org.apache.sysds.runtime.compress.lib.CLALibReorg; import org.apache.sysds.runtime.compress.lib.CLALibReshape; @@ -871,9 +872,7 @@ public MatrixBlock groupedAggOperations(MatrixValue tgt, MatrixValue wghts, Matr @Override public MatrixBlock removeEmptyOperations(MatrixBlock ret, boolean rows, boolean emptyReturn, MatrixBlock select) { - printDecompressWarning("removeEmptyOperations"); - MatrixBlock tmp = getUncompressed(); - return tmp.removeEmptyOperations(ret, rows, emptyReturn, select); + return CLALibRemoveEmpty.rmempty(this, ret, rows, emptyReturn, select); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index fbe04c732e6..f30cf8b17b2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -29,9 +29,9 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import 
org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult; @@ -41,6 +41,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -401,8 +402,9 @@ public final AColGroup rightMultByMatrix(MatrixBlock right) { * @param cru The right hand side column upper * @param nRows The number of rows in this column group */ - public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru){ - throw new NotImplementedException("not supporting right Decompressing Multiply on class: " + this.getClass().getSimpleName()); + public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru) { + throw new NotImplementedException( + "not supporting right Decompressing Multiply on class: " + this.getClass().getSimpleName()); } /** @@ -806,7 +808,7 @@ public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlo else denseSelection(selection, points, ret, rl, ru); } - + /** * Get an approximate sparsity of this column group * @@ -981,4 +983,70 @@ public String toString() { sb.append(_colIndexes); return sb.toString(); } + + /** + * Return a new column group containing only the selected rows in the given boolean vector. + * + * Whenever possible only modify the index structure, not the dictionary of the column groups. 
+ * + * @param selectV The selection vector + * @param rOut The number of rows in the output + * @return The new column group + */ + public abstract AColGroup removeEmptyRows(boolean[] selectV, int rOut); + + /** + * Return a new column group containing only the selected columns in the given boolean vector. + * + * Whenever possible only modify the column index, and reduce the dictionaries of the column groups. + * + * @param selectV The selection vector + * @return The new column group, or {@code null} if no column of this group is selected + */ + public AColGroup removeEmptyCols(boolean[] selectV) { + if(!inSelection(selectV)) + return null; + + final IntArrayList selectedColumns = new IntArrayList(); + final IntArrayList newIDs = new IntArrayList(); + int idx = 0; + int idxOwn = 0; + final int end = Math.min(selectV.length, _colIndexes.get(_colIndexes.size() - 1) + 1); + for(int i = 0; i < end; i++) { + + if(i == _colIndexes.get(idxOwn)) { + if(selectV[i]) { + selectedColumns.appendValue(idxOwn); + newIDs.appendValue(idx); + } + idxOwn++; + } + if(selectV[i]) + idx++; + } + + final IColIndex newColumnIDs = ColIndexFactory.create(newIDs); + if(newColumnIDs.size() == _colIndexes.size()) + return copyAndSet(newColumnIDs); + else + return removeEmptyColsSubset(newColumnIDs, selectedColumns); + } + + /** + * Using the selection of columns, slice out those and return in a new column group with the given column indexes. + * Ideally this method should only modify the dictionaries. 
+ * + * @param newColumnIDs the new column indexes + * @param selectedColumns The selected columns of this column group (guaranteed &lt; current number of columns) + * @return A new Column group + */ + protected abstract AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns); + + private boolean inSelection(boolean[] selection) { + for(int i = 0; i < _colIndexes.size(); i++) { /* tolerate a selection vector shorter than the largest column index, as removeEmptyCols does */ + if(_colIndexes.get(i) < selection.length && selection[_colIndexes.get(i)]) + return true; + } + return false; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java index 45358c7ce46..d825b91f089 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java @@ -59,8 +59,6 @@ public int getNumValues() { * produce an overhead in cases where the count is calculated, but the overhead will be limited to number of distinct * tuples in the dictionary. * - * The returned counts always contains the number of zero tuples as well if there are some contained, even if they - * are not materialized. * * @return The count of each value in the MatrixBlock. 
*/ @@ -212,6 +210,7 @@ public void clear() { counts = null; } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java index 8f2f0b46055..d114f029df8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java @@ -402,4 +402,5 @@ protected IDictionary combineDictionaries(int nCol, List right) { public double getSparsity() { return _dict.getSparsity(); } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java index 3de98a1c23f..30de5e120c5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java @@ -203,6 +203,22 @@ private final void leftMultByMatrixNoPreAggRowsDense(MatrixBlock mb, double[] re */ protected abstract void multiplyScalar(double v, double[] resV, int offRet, AIterator it); + public void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC, AIterator it) { + if(_dict instanceof MatrixBlockDictionary) { + final MatrixBlockDictionary md = (MatrixBlockDictionary) _dict; + final MatrixBlock mb = md.getMatrixBlock(); + // The dictionary is never empty. 
+ if(mb.isInSparseFormat()) + // TODO make sparse decompression where the iterator is known in argument + decompressToSparseBlockSparseDictionary(sb, rl, ru, offR, offC, mb.getSparseBlock()); + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(sb, rl, ru, offR, offC, mb.getDenseBlockValues(), + it); + } + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(sb, rl, ru, offR, offC, _dict.getValues(), it); + } + public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC, AIterator it) { if(_dict instanceof MatrixBlockDictionary) { final MatrixBlockDictionary md = (MatrixBlockDictionary) _dict; @@ -223,6 +239,9 @@ public void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, decompressToDenseBlockDenseDictionaryWithProvidedIterator(db, rl, ru, offR, offC, _dict.getValues(), it); } + public abstract void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock db, int rl, int ru, + int offR, int offC, double[] values, AIterator it); + public abstract void decompressToDenseBlockDenseDictionaryWithProvidedIterator(DenseBlock db, int rl, int ru, int offR, int offC, double[] values, AIterator it); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 94137eb6381..7d0b2469ec8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -46,6 +46,7 @@ import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; 
@@ -527,7 +528,7 @@ public CmCovObject centralMoment(CMOperator op, int nRows) { @Override public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); - if(d == null){ + if(d == null) { if(max <= 0) return null; return ColGroupEmpty.create(max); @@ -758,4 +759,14 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List right) protected boolean allowShallowIdentityRightMult() { return true; } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + return this; + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + return ColGroupConst.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols())); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index a3fdf1fc89f..6ac1544e61e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -26,8 +26,6 @@ import java.util.List; import java.util.concurrent.ExecutorService; -import jdk.incubator.vector.DoubleVector; -import jdk.incubator.vector.VectorSpecies; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; @@ -56,6 +54,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -71,6 +70,9 @@ import 
org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.jboss.netty.handler.codec.compression.CompressionException; +import jdk.incubator.vector.DoubleVector; +import jdk.incubator.vector.VectorSpecies; + /** * Class to encapsulate information about a column group that is encoded with dense dictionary encoding (DDC). */ @@ -672,7 +674,8 @@ private void defaultRightDecompressingMult(MatrixBlock right, MatrixBlock ret, i } } - final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k, int vLen, DoubleVector vVec) { + final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k, int vLen, + DoubleVector vVec) { vVec = vVec.broadcast(aa); final int offj = k * jd; final int end = endT + offj; @@ -1095,6 +1098,21 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E return res; } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + return ColGroupDDC.create(_colIndexes, _dict, _data.removeEmpty(selectV, rOut), null); + } + + @Override + protected boolean allowShallowIdentityRightMult() { + return true; + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + return ColGroupDDC.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols()), _data, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -1104,11 +1122,6 @@ public String toString() { return sb.toString(); } - @Override - protected boolean allowShallowIdentityRightMult() { - return true; - } - public AColGroup convertToDeltaDDC() { int numCols = _colIndexes.size(); int numRows = _data.size(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index d2ee8cd6673..6a4a92469d2 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; @@ -546,6 +547,20 @@ protected boolean allowShallowIdentityRightMult() { return false; } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + return ColGroupDDCFOR.create(_colIndexes, _dict, _data.removeEmpty(selectV, rOut), null, _reference); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _reference[selectedColumns.get(i)]; + } + return ColGroupDDCFOR.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols()), _data, null, ref); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCLZW.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCLZW.java index a3926948b83..c820f875a05 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCLZW.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCLZW.java @@ -1009,4 +1009,17 @@ protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) { for(int rix = rl; rix < ru; rix++) c[rix] *= preAgg[it.next()]; } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + ColGroupDDC g = (ColGroupDDC) 
convertToDDC(); + return g.removeEmptyRows(selectV, rOut); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, + org.apache.sysds.runtime.compress.utils.IntArrayList selectedColumns) { + ColGroupDDC g = (ColGroupDDC) convertToDDC(); + return g.removeEmptyColsSubset(newColumnIDs, selectedColumns); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index 6d7872fce54..7c0a15e123b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -476,4 +477,15 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List right) return new ColGroupEmpty(combinedIndex); } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut){ + return this; + } + + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + return new ColGroupEmpty(newColumnIDs); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java index f4e9007575c..6add5967fde 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java @@ -94,9 +94,7 @@ public static long getExactSizeOnDisk(List colGroups) { } ret += 
grp.getExactSizeOnDisk(); } - if(LOG.isWarnEnabled()) - LOG.warn(" duplicate dicts on exact Size on Disk : " + (colGroups.size() - dicts.size()) ); - + return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java index 4e9fffaf718..5ac168b9406 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java @@ -32,6 +32,7 @@ import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -740,4 +741,13 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + throw new NotImplementedException("Unimplemented method 'removeEmptyColsSubset'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index ea6d0f34c2a..5833729c378 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -26,15 +26,16 @@ import org.apache.commons.lang3.NotImplementedException; 
import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.bitmap.ABitmap; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -731,5 +732,13 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + throw new NotImplementedException("Unimplemented method 'removeEmptyColsSubset'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index 2b4b23792e3..c9fc920a845 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -39,6 +39,7 @@ import 
org.apache.sysds.runtime.compress.colgroup.scheme.RLEScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -1190,4 +1191,13 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + throw new NotImplementedException("Unimplemented method 'removeEmptyColsSubset'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index 4340637a737..5522a33e3e0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -42,6 +42,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -508,10 +509,10 @@ protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, in AOffset indexes, AMapToData data, 
int[] counts, int def, int nVal) { if(d == null) { - if(def <= 0){ + if(def <= 0) { if(max > 0) return ColGroupEmpty.create(max); - else + else return null; } else if(def > max && max > 0) @@ -873,6 +874,23 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + final AMapToData nm = _data.removeEmpty(offsetTmp.select); + return ColGroupSDC.create(_colIndexes, rOut, _dict, _defaultTuple, offsetTmp.retOffset, nm, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _defaultTuple[selectedColumns.get(i)]; + } + return ColGroupSDC.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), ref, + _indexes, _data, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 675c1120c38..2ef7f3012bc 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -39,6 +39,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import 
org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; @@ -620,6 +621,23 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + final AMapToData nm = _data.removeEmpty(offsetTmp.select); + return ColGroupSDCFOR.create(_colIndexes, rOut, _dict, offsetTmp.retOffset, nm, null, _reference); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _reference[selectedColumns.get(i)]; + } + return ColGroupSDCFOR.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), _indexes, _data, null, + ref); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index a954f380a04..0f89e54d975 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetEmpty; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; @@ -469,10 +470,10 @@ public AColGroup rexpandCols(int max, boolean ignore, boolean 
cast, int nRows) { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); final int def = (int) _defaultTuple[0]; if(d == null) { - if(def <= 0){ + if(def <= 0) { if(max > 0) return ColGroupEmpty.create(max); - else + else return null; } else if(def > max && max > 0) @@ -718,6 +719,23 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO optimize by not constructing boolean array. + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + return ColGroupSDCSingle.create(_colIndexes, rOut, _dict, _defaultTuple, offsetTmp.retOffset, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _defaultTuple[selectedColumns.get(i)]; + } + return ColGroupSDCSingle.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), ref, + _indexes, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index 9efd0c41098..d9341bb9ea8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import 
org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetEmpty; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; @@ -109,10 +110,8 @@ protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int return; else if(it.value() >= ru) return; - // _indexes.cacheIterator(it, ru); else { decompressToDenseBlockDenseDictionaryWithProvidedIterator(db, rl, ru, offR, offC, values, it); - // _indexes.cacheIterator(it, ru); } } @@ -238,7 +237,7 @@ protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); + return; else if(ru > last) { final int apos = sb.pos(0); final int alen = sb.size(0) + apos; @@ -277,8 +276,15 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); - else if(ru > _indexes.getOffsetToLast()) { + return; + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(ret, rl, ru, offR, offC, values, it); + } + + @Override + public void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock ret, int rl, int ru, int offR, + int offC, double[] values, final AIterator it) { + if(ru > _indexes.getOffsetToLast()) { final int nCol = _colIndexes.size(); final int lastOff = _indexes.getOffsetToLast(); int row = offR + it.value(); @@ -963,7 +969,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { throw new NotImplementedException(); } - + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { throw new NotImplementedException(); } @@ -1043,6 +1049,20 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; 
} + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO optimize by not constructing boolean array. + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + return ColGroupSDCSingleZeros.create(_colIndexes, rOut, _dict, offsetTmp.retOffset, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + + return ColGroupSDCSingleZeros.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), + _indexes, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index 69e0f776383..86cd9866a75 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -45,6 +45,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -184,8 +185,7 @@ private final void decompressToDenseBlockDenseDictionaryPostAllCols(DenseBlock d final double[] c = db.values(idx); final int off = db.pos(idx); final int offDict = _data.getIndex(it.getDataIndex()) * nCol; - for(int j = 0; j < nCol; j++) - c[off + j] += values[offDict + j]; + decompressSingleRow(values, nCol, c, off, offDict); if(it.value() == lastOff) return; it.next(); @@ -301,13 +301,19 @@ private void 
decompressToDenseBlockDenseDictionaryPreAllCols(DenseBlock db, int final double[] c = db.values(idx); final int off = db.pos(idx) + offC; final int offDict = _data.getIndex(it.getDataIndex()) * nCol; - for(int j = 0; j < nCol; j++) - c[off + j] += values[offDict + j]; + decompressSingleRow(values, nCol, c, off, offDict); it.next(); } } + private static void decompressSingleRow(double[] values, final int nCol, final double[] c, final int off, + final int offDict) { + final int end = nCol + off; + for(int j = off, k = offDict; j < end; j++, k++) + c[j] += values[k]; + } + @Override protected void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, SparseBlock sb) { @@ -438,8 +444,16 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); - else if(ru > _indexes.getOffsetToLast()) { + return; + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(ret, rl, ru, offR, offC, values, it); + + } + + @Override + public void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock ret, int rl, int ru, int offR, + int offC, double[] values, final AIterator it) { + if(ru > _indexes.getOffsetToLast()) { final int lastOff = _indexes.getOffsetToLast(); final int nCol = _colIndexes.size(); while(true) { @@ -467,7 +481,6 @@ else if(ru > _indexes.getOffsetToLast()) { } _indexes.cacheIterator(it, ru); } - } @Override @@ -899,7 +912,6 @@ public AColGroup morph(CompressionType ct, int nRow) { return super.morph(ct, nRow); } - @Override public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { final SparseBlock sr = ret.getSparseBlock(); @@ -942,14 +954,14 @@ protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret of = it.next(); } else if(points[c].o < of) - c++; + c++; else of = it.next(); - } - // increment the c pointer until it is pointing at 
least to last point or is done. - while(c < points.length && points[c].o < last) - c++; - c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); + } + // increment the c pointer until it is pointing at least to last point or is done. + while(c < points.length && points[c].o < last) + c++; + c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); } private int processRowSparse(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) { @@ -1078,6 +1090,19 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + final AMapToData nm = _data.removeEmpty(offsetTmp.select); + return ColGroupSDCZeros.create(_colIndexes, rOut, _dict, offsetTmp.retOffset, nm, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + return ColGroupSDCZeros.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), + _indexes, _data, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index 8d446575975..e4e98da46f2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -43,6 +43,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.EstimationFactors; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import 
org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; @@ -82,7 +83,8 @@ public class ColGroupUncompressed extends AColGroup { /** * Do not use this constructor of column group uncompressed, instead use the create constructor. - * @param mb The contained data. + * + * @param mb The contained data. * @param colIndexes Column indexes for this Columngroup */ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) { @@ -92,14 +94,15 @@ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) { /** * Do not use this constructor of column group quantization-fused uncompressed, instead use the create constructor. - * @param mb The contained data. - * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix - * @param colIndexes Column indexes for this Columngroup + * + * @param mb The contained data. + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param colIndexes Column indexes for this Columngroup */ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { super(colIndexes); - // Apply scaling and flooring - // TODO: Use internal matrix prod + // Apply scaling and flooring + // TODO: Use internal matrix prod for(int r = 0; r < mb.getNumRows(); r++) { double scaleFactor = scaleFactors.length == 1 ? scaleFactors[0] : scaleFactors[r]; for(int c = 0; c < mb.getNumColumns(); c++) { @@ -108,7 +111,8 @@ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes, double[] sc } } _data = mb; - } + } + /** * Create an Uncompressed Matrix Block, where the columns are offset by col indexes. * @@ -130,9 +134,9 @@ public static AColGroup create(MatrixBlock mb, IColIndex colIndexes) { * * It is assumed that the size of the colIndexes and number of columns in mb is matching. 
* - * @param mb The MB / data to contain in the uncompressed column - * @param colIndexes The column indexes for the group - * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param mb The MB / data to contain in the uncompressed column + * @param colIndexes The column indexes for the group + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix * @return An Uncompressed Column group */ public static AColGroup createQuantized(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { @@ -147,14 +151,15 @@ public static AColGroup createQuantized(MatrixBlock mb, IColIndex colIndexes, do /** * Main constructor for a quantization-fused uncompressed ColGroup. * - * @param colIndexes Indices (relative to the current block) of the columns that this column group represents. - * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor is - * called - * @param transposed Says if the input matrix raw block have been transposed. - * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param colIndexes Indices (relative to the current block) of the columns that this column group represents. + * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor is + * called + * @param transposed Says if the input matrix raw block have been transposed. + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix * @return AColGroup. 
*/ - public static AColGroup createQuantized(IColIndex colIndexes, MatrixBlock rawBlock, boolean transposed, double[] scaleFactors) { + public static AColGroup createQuantized(IColIndex colIndexes, MatrixBlock rawBlock, boolean transposed, + double[] scaleFactors) { // special cases if(rawBlock.isEmptyBlock(false)) // empty input @@ -187,22 +192,24 @@ else if(!transposed && colIndexes.size() == rawBlock.getNumColumns()) final int n = colIndexes.size(); if(transposed) { - if (scaleFactors.length == 1) { + if(scaleFactors.length == 1) { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[0])); - } else { + } + else { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[j])); } } else { - if (scaleFactors.length == 1) { + if(scaleFactors.length == 1) { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[0])); - } else { + } + else { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[i])); @@ -1075,7 +1082,6 @@ public AColGroup morph(CompressionType ct, int nRow) { return comp.get(0).copyAndSet(_colIndexes); } - @Override public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { if(_data.isInSparseFormat()) @@ -1092,7 +1098,6 @@ protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret denseSelectionDenseColumnGroup(selection, ret, rl, ru); } - private void sparseSelectionSparseColumnGroup(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { final SparseBlock sb = selection.getSparseBlock(); @@ -1192,7 +1197,7 @@ public AColGroup reduceCols() { else return new ColGroupUncompressed(mb, ColIndexFactory.createI(0)); } - + @Override public void 
decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) { if(_data.isInSparseFormat()) @@ -1289,11 +1294,25 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { for(int i = 0; i < multiplier; i++) for(int j = 0; j < s; j++) newColumns[i * s + j] = _colIndexes.get(j) + nColOrg * i; - MatrixBlock newData = _data.reshape(nRow/ multiplier, s * multiplier, true); - return new AColGroup[]{create(newData,ColIndexFactory.create(newColumns))}; + MatrixBlock newData = _data.reshape(nRow / multiplier, s * multiplier, true); + return new AColGroup[] {create(newData, ColIndexFactory.create(newColumns))}; // throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + MatrixBlock tmp = new MatrixBlock(); + tmp = LibMatrixReorg.removeEmptyRows(_data, tmp, false, false, selectV, rOut); + return ColGroupUncompressed.create(_colIndexes, tmp, false); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] vals = MatrixBlockDictionary.sliceColumns(_data, selectedColumns); + MatrixBlock ret = new MatrixBlock(_data.getNumRows(), selectedColumns.size(), vals); + return ColGroupUncompressed.create(newColumnIDs, ret, false); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index 08cbab30bcc..0c8f07685b6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -19,11 +19,13 @@ package org.apache.sysds.runtime.compress.colgroup; +import org.apache.commons.lang3.NotImplementedException; import 
org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -282,4 +284,13 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new UnsupportedOperationException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java index d667e76ed5e..c26de004373 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java @@ -24,6 +24,7 @@ import java.io.IOException; import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; @@ -136,4 +137,9 @@ public boolean equals(IDictionary o) { public IDictionary clone() { throw new NotImplementedException(); } + + 
@Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){ + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index e94cbd7c570..06bd811b50b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -31,6 +31,7 @@ import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; @@ -1341,4 +1342,10 @@ public IDictionary append(double[] row) { return new Dictionary(retV); } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + // TODO: make specialized version for this. 
+ return getMBDict(nCol).sliceColumns(selectedColumns, nCol); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index 49330ba2748..726df96d5c8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; @@ -1051,4 +1052,14 @@ public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thi * @return The nonzero count of each column in the dictionary. */ public int[] countNNZZeroColumns(int[] counts); + + /** + * Slice out the given selected columns of this dictionary. + * + * @param selectedColumns The columns to slice out and return as a new dictionary. + * @param nCol The number of columns in this dictionary. 
+ * @return A new dictionary containing only the selected columns + */ + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol); + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java index 40e1b065653..c2540de959a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java @@ -27,6 +27,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -540,9 +541,13 @@ public String getString(int colIndexes) { return "IdentityMatrix of size: " + nRowCol + " with empty: " + withEmpty; } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){ + return getMBDict().sliceColumns(selectedColumns, nCol); + } + @Override public String toString() { return "IdentityMatrix of size: " + nRowCol + " with empty: " + withEmpty; } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java index df702524d55..c7f642edfd0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java @@ -27,6 +27,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import 
org.apache.sysds.runtime.compress.DMLCompressionException; import 
org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -310,6 +311,11 @@ public String getString(int colIndexes) { return toString(); } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){ + return getMBDict().sliceColumns(selectedColumns, nCol); + } + @Override public String toString() { return "IdentityMatrixSlice of size: " + nRowCol + " l " + l + " u " + u; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 71a4112f157..c1d2ecc5296 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -27,8 +27,6 @@ import java.util.Arrays; import java.util.Set; -import jdk.incubator.vector.DoubleVector; -import jdk.incubator.vector.VectorSpecies; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; @@ -36,6 +34,7 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; @@ -61,6 +60,9 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; +import 
jdk.incubator.vector.DoubleVector; +import jdk.incubator.vector.VectorSpecies; + public class MatrixBlockDictionary extends ADictionary { private static final long serialVersionUID = 2535887782150955098L; @@ -2801,4 +2803,46 @@ private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, } } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + + final double[] ret = sliceColumns(_data, selectedColumns); + + return new Dictionary(ret); + } + + public static double[] sliceColumns(MatrixBlock mb, IntArrayList selectedColumns) { + // TODO: Optimize to allow sparse outputs. and change output type to MatrixBlock. + final int outC = selectedColumns.size(); + final int nRow = mb.getNumRows(); + if((long) nRow * outC > (long) Integer.MAX_VALUE) + throw new NotImplementedException("Not supported large output blocks for slicing dictionary columns"); + final double[] ret = new double[nRow * outC]; + if(mb.isEmpty()) + return ret; + + // Read through the current representation without mutating the (shared, immutable) dictionary block. 
+ if(mb.isInSparseFormat()) { + final SparseBlock sb = mb.getSparseBlock(); + for(int i = 0; i < nRow; i++) { + if(sb.isEmpty(i)) + continue; + final int offOut = i * outC; + for(int j = 0; j < outC; j++) + ret[offOut + j] = sb.get(i, selectedColumns.get(j)); + } + } + else { + final DenseBlock db = mb.getDenseBlock(); + for(int i = 0; i < nRow; i++) { + final double[] vals = db.values(i); + final int offIn = db.pos(i); + final int offOut = i * outC; + for(int j = 0; j < outC; j++) + ret[offOut + j] = vals[offIn + selectedColumns.get(j)]; + } + } + return ret; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index f5746647a37..2d9075f73c9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -23,6 +23,7 @@ import java.io.DataOutput; import java.io.IOException; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.io.IOUtilFunctions; public class PlaceHolderDict extends ADictionary { @@ -101,4 +102,9 @@ public DictType getDictType() { throw new RuntimeException("invalid to get dictionary type for PlaceHolderDict"); } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + throw new RuntimeException("Invalid call"); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index 6802d920b49..30b9d806c1f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -23,6 +23,7 @@ import java.io.DataOutput; import java.io.IOException; +import 
org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.utils.MemoryEstimates; @@ -277,4 +278,8 @@ public MatrixBlockDictionary createMBDict(int nCol) { return new MatrixBlockDictionary(mb); } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + return getMBDict().sliceColumns(selectedColumns, nCol); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java index 5fc2acaea7a..83a74972db7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java @@ -30,6 +30,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; @@ -39,6 +40,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -1041,4 +1043,36 @@ public String toString() { sb.append("]"); return sb.toString(); } + + public AMapToData removeEmpty(final boolean[] selectV, final int rOut) { + final int s = size(); + int trueCount = 0; + for(int i = 0; i < s; i++) + if(selectV[i]) + 
trueCount++; + if(trueCount != rOut) + throw new DMLRuntimeException( + "Invalid removeEmpty: number of selected rows " + trueCount + " does not match argument rOut " + rOut); + + final AMapToData ret = MapToFactory.create(rOut, getUnique()); + int t = 0; + for(int i = 0; i < s; i++) + if(selectV[i]) + ret.set(t++, getIndex(i)); + return ret; + } + + /** + * Use the offsets of the select vector to choose which values to keep. + * + * @param select The row indexes to keep + * @return A New MapToData + */ + public AMapToData removeEmpty(IntArrayList select) { + final int s = select.size(); + final AMapToData ret = MapToFactory.create(s, getUnique()); + for(int i = 0; i < s; i++) + ret.set(i, getIndex(select.get(i))); + return ret; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java index 45c78dd3abd..a809afccd3d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java @@ -71,8 +71,8 @@ public boolean isNotOver(int ub) { /** * Get the current data index associated with the index returned from value. * - * This index points to a position int the mapToData object, that then inturn can be used to lookup the dictionary - * entry in ADictionary. + * This index points to a position in the AMapToData object, that can be used to lookup the dictionary entry in + * ADictionary. * * @return The Data Index. 
*/ diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java index a961c1188bf..f65876b7f37 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java @@ -586,7 +586,7 @@ public OffsetSliceInfo slice(int l, int u) { else return new OffsetSliceInfo(0, s, moveIndex(l)); } - else if (u < first) + else if(u < first) return emptySlice(); final AIterator it = getIteratorSkipCache(l); @@ -781,6 +781,41 @@ public AOffset reverse(int numRows) { return OffsetFactory.createOffset(newOff); } + public RemoveEmptyOffsetsTmp removeEmptyRows(boolean[] selectV, int rOut) { + IntArrayList newOff = new IntArrayList(); + IntArrayList selectMTmp = new IntArrayList(); + + final AIterator it = getIterator(); + final int last = getOffsetToLast(); + int t = 0; + int o = 0; + while(it.value() < last) { + while(t < it.value()) { + if(selectV[t]) + o++; + t++; + } + if(selectV[it.value()]) { + newOff.appendValue(o); + selectMTmp.appendValue(it.getDataIndex()); + o++; + t++; + } + it.next(); + } + while(t < last) { + if(selectV[t]) + o++; + t++; + } + if(selectV[last]) { + newOff.appendValue(o); + selectMTmp.appendValue(it.getDataIndex()); + } + + return new RemoveEmptyOffsetsTmp(OffsetFactory.createOffset(newOff), selectMTmp); + } + /** * Offset slice info containing the start and end index an offset that contains the slice, and an new AOffset * containing only the sliced elements @@ -810,6 +845,16 @@ public String toString() { } + public static final class RemoveEmptyOffsetsTmp { + public final AOffset retOffset; + public final IntArrayList select; + + protected RemoveEmptyOffsetsTmp(AOffset retOffset, IntArrayList select) { + this.retOffset = retOffset; + this.select = select; + } + } + private static class OffsetCache { private final AIterator it; private final int 
row; @@ -841,4 +886,5 @@ public String toString() { return "r" + row + " d " + dataIndex + " o " + offIndex + "\n"; } } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java index acd3b0d04eb..866168ded2f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java @@ -76,6 +76,10 @@ public int getOffsetToLast() { public long getInMemorySize() { return estimateInMemorySize(); } + @Override + public boolean equals(AOffset b) { + return b instanceof OffsetEmpty; + } public static long estimateInMemorySize() { return 16; // object header diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java new file mode 100644 index 00000000000..3755e4040e7 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.compress.lib; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; + +public class CLALibRemoveEmpty { + protected static final Log LOG = LogFactory.getLog(CLALibRemoveEmpty.class.getName()); + + /** + * CP rmempty operation (single input, single output matrix) + * + * @param in The input matrix + * @param ret The output matrix + * @param rows If we are removing based on rows, or columns. + * @param emptyReturn Return row/column of zeros for empty input. + * @param select An optional selection vector, to remove based on rather than empty rows or columns + * @return The result MatrixBlock, can be a different object that the caller used. + */ + public static MatrixBlock rmempty(CompressedMatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, + MatrixBlock select) { + if(ret == null) + ret = new MatrixBlock(); + MatrixBlock ret2 = LibMatrixReorg.rmemptyEarlyAbort(in, ret, rows, emptyReturn, select); + if(ret2 != null) + return ret2; + + if(rows) + return rmEmptyRows(in, ret, emptyReturn, select); + else + return rmEmptyCols(in, ret, emptyReturn, select); + } + + private static MatrixBlock rmEmptyCols(CompressedMatrixBlock in, MatrixBlock ret, boolean emptyReturn, + MatrixBlock select) { + if(select == null) + return fallback(in, false, emptyReturn, select, ret); + + int cOut = (int) select.getNonZeros(); + if(cOut == -1) + cOut = (int) select.recomputeNonZeros(); + if(cOut == 0){ + ret.reset(in.getNumRows(), !emptyReturn ? 
0 : 1); + return ret; + } + + final boolean[] selectV = DataConverter + .convertToBooleanVector(CompressedMatrixBlock.getUncompressed(select, "decompressing selection in rmempty")); + + final List inG = in.getColGroups(); + final List retG = new ArrayList<>(inG.size()); + try { + for(int i = 0; i < inG.size(); i++) { + AColGroup tmp = inG.get(i).removeEmptyCols(selectV); + if(tmp != null) + retG.add(tmp); + } + } + catch(NotImplementedException e) { + // Some column-group encodings (e.g. OLE/RLE) do not support index-only column removal; + // decompress and remove on the uncompressed representation instead of failing. + return fallback(in, false, emptyReturn, select, ret); + } + return new CompressedMatrixBlock(in.getNumRows(), cOut, -1, in.isOverlapping(), retG); + + } + + private static MatrixBlock rmEmptyRows(CompressedMatrixBlock in, MatrixBlock ret, boolean emptyReturn, + MatrixBlock select) { + if(select == null) + return fallback(in, true, emptyReturn, select, ret); + + select = CompressedMatrixBlock.getUncompressed(select, "decompressing selection in rmempty"); + + int rOut = (int) select.getNonZeros(); + if(rOut == -1) + rOut = (int) select.recomputeNonZeros(); + if(rOut == 0){ + ret.reset(!emptyReturn ? 0 : 1, in.getNumColumns()); + return ret; + } + + // TODO: add optimization to avoid linear scan and make selectV indexes, if selection is small relative to number + // of rows + // TODO: add decompress to boolean vector. + final boolean[] selectV = DataConverter.convertToBooleanVector(select); + + + + final List inG = in.getColGroups(); + final List retG = new ArrayList<>(inG.size()); + try { + for(int i = 0; i < inG.size(); i++) { + retG.add(inG.get(i).removeEmptyRows(selectV, rOut)); + } + } + catch(NotImplementedException e) { + // Some column-group encodings (e.g. OLE/RLE) do not support index-only row removal; + // decompress and remove on the uncompressed representation instead of failing. 
+ return fallback(in, true, emptyReturn, select, ret); + } + + return new CompressedMatrixBlock(rOut, in.getNumColumns(), -1, in.isOverlapping(), retG); + } + + private static MatrixBlock fallback(CompressedMatrixBlock in, boolean rows, boolean emptyReturn, MatrixBlock select, + MatrixBlock ret) { + if(LOG.isDebugEnabled()) + LOG.debug("Decompressing for removeEmptyOperations with select: " + (select != null) + " rows: " + rows); + MatrixBlock tmp = CompressedMatrixBlock.getUncompressed(in); + MatrixBlock select2 = CompressedMatrixBlock.getUncompressed(select); + return LibMatrixReorg.rmemptyUnsafe(tmp, ret, rows, emptyReturn, select2); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 040a4e1dcb1..5f478979104 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -964,16 +964,45 @@ public static List reshape(IndexedMatrixValue in, DataCharac } /** - * CP rmempty operation (single input, single output matrix) + * CP rmempty operation (single input, single output matrix) * - * @param in input matrix - * @param ret output matrix - * @param rows ? - * @param emptyReturn return row/column of zeros for empty input - * @param select ? - * @return matrix block + * @param in The input matrix + * @param ret The output matrix + * @param rows If we are removing based on rows, or columns. 
+ * @param emptyReturn Return row/column of zeros for empty input + * @param select An optional selection vector, to remove based on rather than empty rows or columns + * @return The result MatrixBlock */ public static MatrixBlock rmempty(MatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, MatrixBlock select) { + if(ret == null) + ret = new MatrixBlock(); + MatrixBlock ret2 = rmemptyEarlyAbort(in, ret, rows, emptyReturn, select); + if(ret2 != null ) + return ret2; + // core removeEmpty + return rmemptyUnsafe(in, ret, rows, emptyReturn, select); + } + + public static MatrixBlock rmemptyUnsafe(MatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, + MatrixBlock select) { + if( rows ) + return removeEmptyRows(in, ret, select, emptyReturn); + else // cols + return removeEmptyColumns(in, ret, select, emptyReturn); + } + + /** + * Handle the early-termination cases of removeEmpty that do not require scanning for empty rows/columns. + * + * @param in The input matrix + * @param ret The output matrix, reused for the empty-input case + * @param rows If removing based on rows, or columns + * @param emptyReturn Return a row/column of zeros for empty input + * @param select An optional selection vector + * @return The early-abort result, or {@code null} if no early termination applies and the caller must continue. + * For the select-all case the returned block is the input {@code in} itself (a shallow alias, not a copy). 
+ */ + public static MatrixBlock rmemptyEarlyAbort(MatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, MatrixBlock select){ //check for empty inputs //(the semantics of removeEmpty are that for an empty m-by-n matrix, the output //is an empty 1-by-n or m-by-1 matrix because we don't allow matrices with dims 0) @@ -990,12 +1019,8 @@ public static MatrixBlock rmempty(MatrixBlock in, MatrixBlock ret, boolean rows, if( select != null && (select.nonZeros == (rows?in.rlen:in.clen)) ) { return in; } - - // core removeEmpty - if( rows ) - return removeEmptyRows(in, ret, select, emptyReturn); - else //cols - return removeEmptyColumns(in, ret, select, emptyReturn); + + return null; } /** @@ -3620,6 +3645,25 @@ private static MatrixBlock removeEmptyRows(MatrixBlock in, MatrixBlock ret, Matr rlen2 = (int)select.getNonZeros(); } + return removeEmptyRows(in, ret, emptyReturn, select == null, flags, rlen2); + } + + /** + * Remove selected rows, based on the boolean array given. Note this function is internal use only, and require a + * boolean vector to be constructed first. + * + * @param in Input to remove rows from + * @param ret Output to assign the result into + * @param emptyReturn If the output is allowed to be empty. + * @param selectNull If the original caller did not have a selection matrix. + * @param flags The boolean selection vector to specify which rows to keep. + * @param rlen2 The number of true values in the flags argument. + * @return Another reference to the ret matrix input argument. 
+ */ + public static MatrixBlock removeEmptyRows(MatrixBlock in, MatrixBlock ret, boolean emptyReturn, boolean selectNull, + boolean[] flags, int rlen2) { + final int m = in.rlen; + final int n = in.clen; //Step 2: reset result and copy rows //dense stays dense if correct input representation (but robust for any input), //sparse might be dense/sparse @@ -3629,7 +3673,7 @@ private static MatrixBlock removeEmptyRows(MatrixBlock in, MatrixBlock ret, Matr if( in.isEmptyBlock(false) ) return ret; - if( SHALLOW_COPY_REORG && m == rlen2 && select == null ) { + if( SHALLOW_COPY_REORG && m == rlen2 && selectNull ) { // the condition m==rlen2 is not enough with non-empty 1-row input but empty // 1-row select vector because if emptyReturn should output a single empty row ret.sparse = in.sparse; @@ -3672,7 +3716,7 @@ else if( !in.sparse && !ret.sparse ) //DENSE <- DENSE } //check sparsity - ret.nonZeros = (select==null) ? + ret.nonZeros = (selectNull) ? in.nonZeros : ret.recomputeNonZeros(); ret.examSparsity(); diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java index d36c6167cf7..934a5458557 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.Assume.assumeTrue; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -687,4 +688,113 @@ public void toRDDAndBack(int blen) { fail(e.getMessage()); } } + + @Test + public void removeEmptyOperationsBase1() { + removeEmptyOperations(false, false, null); + } + + @Test + public void removeEmptyOperationsBase2() { + removeEmptyOperations(true, false, null); + } + + @Test + public void removeEmptyOperationsBase3() { + removeEmptyOperations(false, 
true, null); + } + + @Test + public void removeEmptyOperationsBase4() { + removeEmptyOperations(true, true, null); + } + + @Test + public void removeEmptyOperationsSelect1() { + // limit to smaller row counts to keep the dense selection vector generation cheap + assumeTrue(rows < 5000); + MatrixBlock s = TestUtils.generateTestMatrixBlock(rows, 1, 1, 1, 0.05, 321); + removeEmptyOperations(true, false, s); + } + + @Test + public void removeEmptyOperationsSelect2() { + // limit to smaller row counts to keep the dense selection vector generation cheap + assumeTrue(rows < 5000); + MatrixBlock s = TestUtils.generateTestMatrixBlock(1, cols, 1, 1, 0.5, 321); + removeEmptyOperations(false, false, s); + } + + @Test + public void removeEmptyOperationsSelectRowsEmptyReturn() { + assumeTrue(rows < 5000); + MatrixBlock s = TestUtils.generateTestMatrixBlock(rows, 1, 1, 1, 0.05, 321); + removeEmptyOperations(true, true, s); + } + + @Test + public void removeEmptyOperationsSelectColsEmptyReturn() { + assumeTrue(rows < 5000); + MatrixBlock s = TestUtils.generateTestMatrixBlock(1, cols, 1, 1, 0.5, 321); + removeEmptyOperations(false, true, s); + } + + @Test + public void removeEmptyOperationsSelectRowsDense() { + assumeTrue(rows < 5000); + MatrixBlock s = TestUtils.generateTestMatrixBlock(rows, 1, 1, 1, 0.6, 654); + removeEmptyOperations(true, false, s); + } + + @Test + public void removeEmptyOperationsSelectAllRows() { + assumeTrue(rows < 5000); + MatrixBlock s = TestUtils.generateTestMatrixBlock(rows, 1, 1, 1, 1.0, 13); + removeEmptyOperations(true, false, s); + } + + @Test + public void removeEmptyOperationsSelectAllCols() { + assumeTrue(rows < 5000); + MatrixBlock s = TestUtils.generateTestMatrixBlock(1, cols, 1, 1, 1.0, 13); + removeEmptyOperations(false, false, s); + } + + @Test + public void removeEmptyOperationsSelectNoRows() { + assumeTrue(rows < 5000); + removeEmptyOperations(true, false, new MatrixBlock(rows, 1, true)); + } + + @Test + public void 
removeEmptyOperationsSelectNoRowsEmptyReturn() { + assumeTrue(rows < 5000); + removeEmptyOperations(true, true, new MatrixBlock(rows, 1, true)); + } + + @Test + public void removeEmptyOperationsSelectNoCols() { + assumeTrue(rows < 5000); + removeEmptyOperations(false, false, new MatrixBlock(1, cols, true)); + } + + @Test + public void removeEmptyOperationsSelectNoColsEmptyReturn() { + assumeTrue(rows < 5000); + removeEmptyOperations(false, true, new MatrixBlock(1, cols, true)); + } + + public void removeEmptyOperations(boolean rows, boolean emptyReturn, MatrixBlock select) { + try { + MatrixBlock a = cmb.removeEmptyOperations(null, rows, emptyReturn, select); + MatrixBlock b = mb.removeEmptyOperations(null, rows, emptyReturn, select); + compareResultMatrices(b, a, 0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } + } diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedRemoveEmptyColSubsetTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedRemoveEmptyColSubsetTest.java new file mode 100644 index 00000000000..e6e28016ee2 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedRemoveEmptyColSubsetTest.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.compress; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; + +import java.util.Collections; +import java.util.EnumSet; +import java.util.Random; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; +import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +/** + * Exercises {@code removeEmptyOperations} with a column-selection vector that keeps a strict subset of a + * multi-column column group. This drives the per-encoding {@code removeEmptyColsSubset} dictionary slicing + * paths (SDC / SDC-single) and the {@link org.apache.commons.lang3.NotImplementedException} fallback for encodings that + * do not implement index-only column removal (RLE/OLE). + */ +public class CompressedRemoveEmptyColSubsetTest { + + private static final int ROWS = 500; + private static final int COLS = 3; + + @Test + public void colSubsetSDC() { + // Multiple distinct non-default tuples -> ColGroupSDC. + runColSubset(buildIdentical(new double[] {3.0, 5.0, 9.0}, 1), CompressionType.SDC); + } + + @Test + public void colSubsetSDCSingle() { + // A single distinct non-default tuple -> ColGroupSDCSingle. 
+ runColSubset(buildIdentical(new double[] {3.0}, 2), CompressionType.SDC); + } + + @Test + public void colSubsetDDC() { + runColSubset(buildIdentical(new double[] {3.0, 5.0, 9.0}, 3), CompressionType.DDC); + } + + @Test + public void colSubsetRLEFallback() { + runColSubset(buildIdentical(new double[] {3.0, 5.0, 9.0}, 4), CompressionType.RLE); + } + + @Test + public void colSubsetOLEFallback() { + runColSubset(buildIdentical(new double[] {3.0, 5.0, 9.0}, 5), CompressionType.OLE); + } + + @Test + public void colSubsetSDCFOR() { + // SDCFOR cannot be forced at the planner level, so build a multi-column SDC group and sparsify it to the + // frame-of-reference variant (the production path) before slicing a strict column subset. + MatrixBlock mb = buildIdentical(new double[] {3.0, 5.0, 9.0}, 8); + CompressedMatrixBlock sdc = compressForced(mb, CompressionType.SDC); + AColGroup g = sdc.getColGroups().get(0); + assumeTrue("Expected a multi-column ColGroupSDC to sparsify", g instanceof ColGroupSDC && g.getNumCols() > 1); + AColGroup forGroup = ((ColGroupSDC) g).sparsifyFOR(); + assumeTrue("Expected an SDCFOR group after sparsify", forGroup.getCompType() == CompressionType.SDCFOR); + + CompressedMatrixBlock cmb = new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), -1, false, + Collections.singletonList(forGroup)); + + MatrixBlock select = new MatrixBlock(1, COLS, false); + select.set(0, 0, 1); + select.set(0, 2, 1); + MatrixBlock actual = cmb.removeEmptyOperations(null, false, false, select); + select = new MatrixBlock(1, COLS, false); + select.set(0, 0, 1); + select.set(0, 2, 1); + MatrixBlock expected = mb.removeEmptyOperations(null, false, false, select); + TestUtils.compareMatrices(expected, actual, 0.0, "removeEmpty col subset for SDCFOR"); + } + + /** Column selection vector with unknown (-1) non-zero count, forcing the recompute branch. 
*/ + @Test + public void colSubsetUnknownNnz() { + MatrixBlock mb = buildIdentical(new double[] {3.0, 5.0, 9.0}, 6); + CompressedMatrixBlock cmb = compressForced(mb, CompressionType.SDC); + MatrixBlock select = new MatrixBlock(1, COLS, false); + select.set(0, 0, 1); + select.set(0, 2, 1); + select.setNonZeros(-1); + MatrixBlock actual = cmb.removeEmptyOperations(null, false, false, select); + select = new MatrixBlock(1, COLS, false); + select.set(0, 0, 1); + select.set(0, 2, 1); + MatrixBlock expected = mb.removeEmptyOperations(null, false, false, select); + TestUtils.compareMatrices(expected, actual, 0.0, "removeEmpty cols unknown-nnz select"); + } + + /** Row selection vector with unknown (-1) non-zero count, forcing the recompute branch. */ + @Test + public void rowsUnknownNnz() { + MatrixBlock mb = buildIdentical(new double[] {3.0, 5.0, 9.0}, 7); + CompressedMatrixBlock cmb = compressForced(mb, CompressionType.SDC); + MatrixBlock select = rowSelect(); + select.setNonZeros(-1); + MatrixBlock actual = cmb.removeEmptyOperations(null, true, false, select); + MatrixBlock expected = mb.removeEmptyOperations(null, true, false, rowSelect()); + TestUtils.compareMatrices(expected, actual, 0.0, "removeEmpty rows unknown-nnz select"); + } + + private void runColSubset(MatrixBlock mb, CompressionType ct) { + CompressedMatrixBlock cmb = compressForced(mb, ct); + assertTrue("Expected a multi-column " + ct + " group to reach the subset path", + cmb.getColGroups().stream().anyMatch(g -> g.getNumCols() > 1)); + + // Keep a strict subset (drop the middle column) so removeEmptyColsSubset is hit instead of copyAndSet. 
+ MatrixBlock select = new MatrixBlock(1, COLS, false); + select.set(0, 0, 1); + select.set(0, 2, 1); + + MatrixBlock actual = cmb.removeEmptyOperations(null, false, false, select); + select = new MatrixBlock(1, COLS, false); + select.set(0, 0, 1); + select.set(0, 2, 1); + MatrixBlock expected = mb.removeEmptyOperations(null, false, false, select); + TestUtils.compareMatrices(expected, actual, 0.0, "removeEmpty col subset for " + ct); + } + + private static MatrixBlock rowSelect() { + MatrixBlock select = new MatrixBlock(ROWS, 1, false); + for(int i = 0; i < ROWS; i += 2) + select.set(i, 0, 1); + return select; + } + + private static CompressedMatrixBlock compressForced(MatrixBlock mb, CompressionType ct) { + CompressionSettingsBuilder csb = new CompressionSettingsBuilder().setMinimumCompressionRatio(0.0) + .setValidCompressions(EnumSet.of(ct)); + MatrixBlock c = CompressedMatrixBlockFactory.compress(mb, 1, csb).getLeft(); + assertTrue("Expected the input to compress into a " + ct + " backed block", c instanceof CompressedMatrixBlock); + return (CompressedMatrixBlock) c; + } + + /** + * Builds a {@code ROWS x COLS} matrix whose columns are identical so column co-coding merges them into a single + * multi-column group, with one dominant value plus a few off-values. 
+ */ + private static MatrixBlock buildIdentical(double[] others, int seed) { + MatrixBlock mb = new MatrixBlock(ROWS, COLS, false); + mb.allocateDenseBlock(); + Random r = new Random(seed); + for(int i = 0; i < ROWS; i++) { + double v = 7.0; + if(r.nextDouble() < 0.2) + v = others[r.nextInt(others.length)]; + for(int j = 0; j < COLS; j++) + mb.set(i, j, v); + } + mb.recomputeNonZeros(); + return mb; + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedRemoveEmptyForcedTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedRemoveEmptyForcedTest.java new file mode 100644 index 00000000000..f08cbd92f42 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedRemoveEmptyForcedTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.compress; + +import static org.junit.Assert.assertTrue; + +import java.util.EnumSet; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +/** + * Verifies that {@code removeEmptyOperations} with a selection vector degrades gracefully (decompresses) for + * column-group encodings that do not implement index-only row/column removal (e.g. OLE/RLE), rather than throwing + * {@link org.apache.commons.lang3.NotImplementedException}. + */ +public class CompressedRemoveEmptyForcedTest { + + private static final int ROWS = 500; + private static final int COLS = 4; + + @Test + public void removeEmptyRowsFallbackOLE() { + runFallback(CompressionType.OLE, true); + } + + @Test + public void removeEmptyRowsFallbackRLE() { + runFallback(CompressionType.RLE, true); + } + + @Test + public void removeEmptyColsFallbackRLE() { + runFallback(CompressionType.RLE, false); + } + + private void runFallback(CompressionType ct, boolean rows) { + MatrixBlock mb = CompressibleInputGenerator.getInput(ROWS, COLS, ct, 10, 0.6, 7); + + CompressionSettingsBuilder csb = new CompressionSettingsBuilder().setMinimumCompressionRatio(0.0) + .setValidCompressions(EnumSet.of(ct)); + MatrixBlock compressed = CompressedMatrixBlockFactory.compress(mb, 1, csb).getLeft(); + assertTrue("Expected the input to compress into a " + ct + " backed block", + compressed instanceof CompressedMatrixBlock); + CompressedMatrixBlock cmb = (CompressedMatrixBlock) compressed; + assertTrue("Expected at least one " + ct + " column group to exercise the fallback path", 
+ containsType(cmb, ct)); + + // Use a strict subset selection so the column path reaches removeEmptyColsSubset (which throws + // NotImplementedException for OLE/RLE) rather than the copyAndSet all-selected shortcut. + final MatrixBlock select; + if(rows) { + select = new MatrixBlock(ROWS, 1, false); + for(int i = 0; i < ROWS; i += 2) + select.set(i, 0, 1); + } + else { + select = new MatrixBlock(1, COLS, false); + select.set(0, 0, 1); + } + + // Must not throw NotImplementedException; must match the uncompressed reference via decompression fallback. + MatrixBlock actual = cmb.removeEmptyOperations(null, rows, false, select); + MatrixBlock expected = mb.removeEmptyOperations(null, rows, false, select); + TestUtils.compareMatrices(expected, actual, 0.0, "removeEmpty fallback for " + ct + " rows=" + rows); + } + + private static boolean containsType(CompressedMatrixBlock cmb, CompressionType ct) { + for(AColGroup g : cmb.getColGroups()) + if(g.getCompType() == ct) + return true; + return false; + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java index e6e41755dd9..af21b14206a 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java @@ -49,6 +49,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -468,6 +469,18 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E // TODO 
Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'splitReshapePushDown'"); } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyColsSubset'"); + } } private class FakeDictBasedColGroup extends ADictBasedColGroup { @@ -777,5 +790,17 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'splitReshapePushDown'"); } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyColsSubset'"); + } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java b/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java index 2e901eeb14d..3755365c018 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java @@ -28,13 +28,14 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import 
org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.junit.Test; public class CustomOffsetTest { protected static final Log LOG = LogFactory.getLog(CustomOffsetTest.class.getName()); - static{ + static { CompressedMatrixBlock.debug = true; } @@ -96,4 +97,95 @@ public void printCache() { String s = off.toString(); assertTrue(s.contains("CacheRow")); } + + @Test + public void removeEmptyRows1() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, true, false, false, false, false}, 0); + assertEquals(1, t.select.size()); + assertEquals(0, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0}), t.retOffset); + } + + @Test + public void removeEmptyRows2() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, true, false, false, false}, 0); + assertEquals(1, t.select.size()); + assertEquals(1, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0}), t.retOffset); + } + + @Test + public void removeEmptyRows3() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, true, true, false, false, false}, 0); + assertEquals(2, t.select.size()); + assertEquals(0, t.select.get(0)); + assertEquals(1, t.select.get(1)); + assertEquals(2, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0, 1}), t.retOffset); + } + + @Test + public void removeEmptyRows4() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, true, true, false, false, true}, 0); + assertEquals(3, t.select.size()); + 
assertEquals(0, t.select.get(0)); + assertEquals(1, t.select.get(1)); + assertEquals(4, t.select.get(2)); + assertEquals(3, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0, 1, 2}), t.retOffset); + } + + @Test + public void removeEmptyRows5() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, false, false, false, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(4, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0}), t.retOffset); + } + + @Test + public void removeEmptyRows6() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, false, true, true, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(2, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {2}), t.retOffset); + } + + @Test + public void removeEmptyRows7() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {true, false, false, true, true, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(2, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {3}), t.retOffset); + } + + @Test + public void removeEmptyRows8() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {true, false, false, false, false, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(4, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {1}), t.retOffset); + } + + @Test + public void removeEmptyRowsEmpty() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + 
RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, false, false, false, false}, 0); + assertEquals(0, t.select.size()); + assertEquals(OffsetFactory.createOffset(new int[] {}), t.retOffset); + } }