Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion scripts/builtin/raGroupby.dml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,22 @@ m_raGroupby = function (Matrix[Double] X, Integer col, String method)
# Set value of final output
Y = matrix(0, rows=numGroups, cols=totalCells)
Y[,1] = key_unique
Y[,2:ncol(Y)] = matrix(Y_temp_reduce, rows=numGroups, cols=totalCells-1)

# The permutation matrix creates a structure where each group's data
# may not fill exactly maxRowsInGroup rows.
# When there are fewer rows than expected, zero-pad to the expected size before reshaping.
expectedRows = numGroups * maxRowsInGroup
actualRows = nrow(Y_temp_reduce)

if(actualRows < expectedRows) {
# Pad Y_temp_reduce with zeros to match expected structure
Y_tmp_padded = matrix(0, rows=expectedRows, cols=ncol(Y_temp_reduce))
Y_tmp_padded[1:actualRows,] = Y_temp_reduce
} else {
Y_tmp_padded = Y_temp_reduce
}

Y[,2:ncol(Y)] = matrix(Y_tmp_padded, rows=numGroups, cols=totalCells-1)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ public void testRaGroupbyTestwithOneGroup2() {
testRaGroupbyTestwithOneGroup("permutation-matrix");
}

/** Runs the multiple-group-rows scenario with the "nested-loop" method. */
@Test
public void testRaGroupbyTestwithMultipleGroupRows1() {
    final String method = "nested-loop";
    testRaGroupbyTestwithMultipleGroupRows(method);
}

/** Runs the multiple-group-rows scenario with the "permutation-matrix" method. */
@Test
public void testRaGroupbyTestwithMultipleGroupRows2() {
    final String method = "permutation-matrix";
    testRaGroupbyTestwithMultipleGroupRows(method);
}

public void testRaGroupbyTest(String method) {
//generate actual dataset and variables
double[][] X = {
Expand Down Expand Up @@ -160,6 +170,41 @@ public void testRaGroupbyTestwithOneGroup(String method) {
runRaGroupbyTest(X, select_col, Y, method);
}

/**
 * Group-by scenario in which the groups contain different numbers of rows.
 *
 * @param method method name forwarded to runRaGroupbyTest
 *               (the callers in this file pass "nested-loop" and "permutation-matrix")
 */
public void testRaGroupbyTestwithMultipleGroupRows(String method) {
    // The grouping key is the second column of the input.
    final int groupCol = 2;

    // Input: 10 rows x 5 columns.
    // Row count per key in column 2: 1 -> 3, 2 -> 2, 3 -> 2, 4 -> 2, 5 -> 1
    final double[][] input = {
        {1, 1, 11, 12, 13},
        {1, 2, 21, 22, 23},
        {1, 3, 31, 32, 33},
        {1, 4, 41, 42, 43},
        {2, 1, 14, 15, 16},
        {2, 2, 24, 25, 26},
        {2, 3, 34, 35, 36},
        {2, 4, 44, 45, 46},
        {2, 5, 54, 55, 56},
        {3, 1, 17, 18, 19}
    };

    // Expected: one output row per key. Each row starts with the key, followed by
    // the member rows of that group with the key column dropped, and is zero-filled
    // up to the width of the largest group (3 rows x 4 values).
    // Note: groups follow the order in which the unique() function returns them.
    //   key 1 (3 rows): [1,11,12,13], [2,14,15,16], [3,17,18,19]
    //   key 2 (2 rows): [1,21,22,23], [2,24,25,26]
    //   key 4 (2 rows): [1,41,42,43], [2,44,45,46]
    //   key 5 (1 row) : [2,54,55,56]
    //   key 3 (2 rows): [1,31,32,33], [2,34,35,36]
    final double[][] expected = {
        {1, 1, 11, 12, 13, 2, 14, 15, 16, 3, 17, 18, 19},
        {2, 1, 21, 22, 23, 2, 24, 25, 26, 0, 0, 0, 0},
        {4, 1, 41, 42, 43, 2, 44, 45, 46, 0, 0, 0, 0},
        {5, 2, 54, 55, 56, 0, 0, 0, 0, 0, 0, 0, 0},
        {3, 1, 31, 32, 33, 2, 34, 35, 36, 0, 0, 0, 0}
    };

    runRaGroupbyTest(input, groupCol, expected, method);
}

private void runRaGroupbyTest(double [][] X, int col, double [][] Y, String method)
{
ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
Expand Down
Loading