From 07133428b94c42047084d2d4d8597dab0776128a Mon Sep 17 00:00:00 2001 From: philipportner Date: Mon, 7 Jul 2025 20:35:51 +0200 Subject: [PATCH] [SYSTEMDS-3708] Fix permutation-matrix method Adds a test case with inputs that have multiple groups with varying row counts. This pattern comes from a `lineorder.csv` example dataset that currently causes a runtime exception for the `permutation-matrix` approach but works for the `nested-loop` approach. Why this happened: - `permutation-matrix` approach allocated space assuming every group has `maxRowsInGroup` rows, which is not always the case - groups may have variable sizes resulting in `Y_temp_reduce` having fewer rows than the reshape expects Changes: - correctly pads the matrix when groups do not all have `maxRowsInGroup` rows - adds test cases that cover this pattern --- scripts/builtin/raGroupby.dml | 17 ++++++- .../builtin/part2/BuiltinRaGroupbyTest.java | 45 +++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/scripts/builtin/raGroupby.dml b/scripts/builtin/raGroupby.dml index 7d7035c0ff8..0a23bf51ef0 100644 --- a/scripts/builtin/raGroupby.dml +++ b/scripts/builtin/raGroupby.dml @@ -132,7 +132,22 @@ m_raGroupby = function (Matrix[Double] X, Integer col, String method) # Set value of final output Y = matrix(0, rows=numGroups, cols=totalCells) Y[,1] = key_unique - Y[,2:ncol(Y)] = matrix(Y_temp_reduce, rows=numGroups, cols=totalCells-1) + + # The permutation matrix creates a structure where each group's data + # may not fill exactly maxRowsInGroup rows. + # If needed, we need to pad to the expected size first. 
+ expectedRows = numGroups * maxRowsInGroup + actualRows = nrow(Y_temp_reduce) + + if(actualRows < expectedRows) { + # Pad Y_temp_reduce with zeros to match expected structure + Y_tmp_padded = matrix(0, rows=expectedRows, cols=ncol(Y_temp_reduce)) + Y_tmp_padded[1:actualRows,] = Y_temp_reduce + } else { + Y_tmp_padded = Y_temp_reduce + } + + Y[,2:ncol(Y)] = matrix(Y_tmp_padded, rows=numGroups, cols=totalCells-1) } } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaGroupbyTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaGroupbyTest.java index 18ad933e77b..bbdc9fcad03 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaGroupbyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaGroupbyTest.java @@ -80,6 +80,16 @@ public void testRaGroupbyTestwithOneGroup2() { testRaGroupbyTestwithOneGroup("permutation-matrix"); } + @Test + public void testRaGroupbyTestwithMultipleGroupRows1() { + testRaGroupbyTestwithMultipleGroupRows("nested-loop"); + } + + @Test + public void testRaGroupbyTestwithMultipleGroupRows2() { + testRaGroupbyTestwithMultipleGroupRows("permutation-matrix"); + } + public void testRaGroupbyTest(String method) { //generate actual dataset and variables double[][] X = { @@ -160,6 +170,41 @@ public void testRaGroupbyTestwithOneGroup(String method) { runRaGroupbyTest(X, select_col, Y, method); } + public void testRaGroupbyTestwithMultipleGroupRows(String method) { + // Test case with multiple groups of different sizes, so groups smaller than the largest one are zero-padded in the expected output + // 10 rows x 5 columns, grouping by column 2 + // Group sizes by key in column 2: 1->3 rows, 2->2 rows, 3->2 rows, 4->2 rows, 5->1 row + double[][] X = { + {1, 1, 11, 12, 13}, + {1, 2, 21, 22, 23}, + {1, 3, 31, 32, 33}, + {1, 4, 41, 42, 43}, + {2, 1, 14, 15, 16}, + {2, 2, 24, 25, 26}, + {2, 3, 34, 35, 36}, + {2, 4, 44, 45, 46}, + {2, 5, 54, 55, 56}, + {3, 1, 17, 18, 19}}; + int select_col = 2; + + // Expected output matrix (grouping 
by column 2, removing column 2) + // Note: Groups are ordered as they appear in the unique() function output (here 1, 2, 4, 5, 3), not by sorted key + // Group 1: 3 rows -> [1,11,12,13], [2,14,15,16], [3,17,18,19] + // Group 2: 2 rows -> [1,21,22,23], [2,24,25,26] + // Group 4: 2 rows -> [1,41,42,43], [2,44,45,46] + // Group 5: 1 row -> [2,54,55,56] + // Group 3: 2 rows -> [1,31,32,33], [2,34,35,36] + double[][] Y = { + {1, 1, 11, 12, 13, 2, 14, 15, 16, 3, 17, 18, 19}, + {2, 1, 21, 22, 23, 2, 24, 25, 26, 0, 0, 0, 0}, + {4, 1, 41, 42, 43, 2, 44, 45, 46, 0, 0, 0, 0}, + {5, 2, 54, 55, 56, 0, 0, 0, 0, 0, 0, 0, 0}, + {3, 1, 31, 32, 33, 2, 34, 35, 36, 0, 0, 0, 0} + }; + + runRaGroupbyTest(X, select_col, Y, method); + } + private void runRaGroupbyTest(double [][] X, int col, double [][] Y, String method) { ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);