From 11e11b4cb7eacc907290139fcd930661afa44dcb Mon Sep 17 00:00:00 2001 From: ramesesz Date: Mon, 11 Dec 2023 16:43:13 +0100 Subject: [PATCH 01/10] [MINOR] Performance improvement of the builtin dist function This patch improves the builtin dist function by removing the outer product operator. For 100 function calls on an arbitrary matrix with 4000 rows and 800 cols, the new dist function shortens the runtime from 66.541s to 60.268s. --- scripts/builtin/dist.dml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/builtin/dist.dml b/scripts/builtin/dist.dml index 26ded9a1976..7f34b21e8c8 100644 --- a/scripts/builtin/dist.dml +++ b/scripts/builtin/dist.dml @@ -32,7 +32,8 @@ # ----------------------------------------------------------------------------------------------- m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) { - G = X %*% t(X); - Y = sqrt(-2 * G + outer(diag(G), t(diag(G)), "+")); + n = nrow(X) + s = rowSums(X * X) + Y = - 2*X %*% t(X) + s + t(s) Y = replace(target = Y, pattern=NaN, replacement = 0); } From daf25e24334f7dcc0bb47851ce220c507ec8243c Mon Sep 17 00:00:00 2001 From: ramesesz Date: Sun, 17 Dec 2023 14:00:51 +0100 Subject: [PATCH 02/10] Added missing sqrt and simplified notation. --- scripts/builtin/dist.dml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/builtin/dist.dml b/scripts/builtin/dist.dml index 7f34b21e8c8..f296fd717bc 100644 --- a/scripts/builtin/dist.dml +++ b/scripts/builtin/dist.dml @@ -33,7 +33,7 @@ m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) { n = nrow(X) - s = rowSums(X * X) - Y = - 2*X %*% t(X) + s + t(s) + s = rowSums(X^2) + Y = sqrt(-2 * X %*% t(X) + s + t(s)) Y = replace(target = Y, pattern=NaN, replacement = 0); } From 779627e46bd5809dc414716cb0ea29269bdc5e04 Mon Sep 17 00:00:00 2001 From: ramesesz Date: Fri, 26 Jan 2024 10:18:22 +0100 Subject: [PATCH 03/10] [MINOR] Added early-stopping mechanism to tSNE --- scripts/builtin/tSNE.dml | 50 ++++++++- .../builtin/part2/BuiltinTSNETest.java | 104 ++++++++++++++++++ 2 files changed, 150 insertions(+), 4 deletions(-) diff --git a/scripts/builtin/tSNE.dml b/scripts/builtin/tSNE.dml index 131ab1013cb..0813d7508c9 100644 --- a/scripts/builtin/tSNE.dml +++ b/scripts/builtin/tSNE.dml @@ -31,6 +31,7 @@ # lr Learning rate # momentum Momentum Parameter # max_iter Number of iterations +# tol Tolerance for early stopping in gradient descent # seed The seed used for initial values. # If set to -1 random seeds are selected. # is_verbose Print debug information @@ -42,7 +43,7 @@ # ------------------------------------------------------------------------------------------- m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity = 30, - Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Integer seed = -1, Boolean is_verbose = FALSE) + Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Double tol = 1e-5, Integer seed = -1, Boolean is_verbose = FALSE) return(Matrix[Double] Y) { d = reduced_dims @@ -63,7 +64,40 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity if(is_verbose) print("starting loop....") - for (itr in 1:max_iter) { + itr = 1 + + # Start first iteration out of loop as benchmark for early stopping + D = distance_matrix(Y) + Z = 1/(D + 1) + Z = Z * ZERODIAG + Q = Z/sum(Z) + W = (P - Q)*Z + sumW = rowSums(W) + g = Y * sumW - W %*% Y + dY = momentum*dY - lr*g + + norm = sum(dY^2) + norm_initial = norm + norm_target = norm_initial * tol + + if(is_verbose){ + print("L1 Norm initial : " + norm_initial) + print("L1 Norm target : " + norm_target) + } + + Y = Y + dY + Y = Y - colMeans(Y) + + if (itr%%100 == 0) { + C[itr/100,] = sum(P * log(pmax(P, 1e-12) / pmax(Q, 1e-12))) + } + if (itr == 100) { + P = P/4 + } + itr = itr + 1 + # End of first iteration + + while (itr <= max_iter & norm > norm_target) { D = distance_matrix(Y) Z = 1/(D + 1) Z = Z * ZERODIAG @@ -72,6 +106,13 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity sumW = rowSums(W) g = Y * sumW - W %*% Y dY = momentum*dY - lr*g + + norm = sum(dY^2) + if(is_verbose & itr %%10 ==0){ + print("Iteration: " + itr) + print("L1 Norm: " + norm) + } + Y = Y + dY Y = Y - colMeans(Y) @@ -81,6 +122,7 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity if (itr == 100) { P = P/4 } + itr = itr + 1 } } @@ -119,7 +161,7 @@ return(matrix[double] P) while (mean(abs(Hdiff)) > tol & itr < 50) { P = exp(-D * beta) P = P * ZERODIAG - sum_Pi = rowSums(P) + sum_Pi = rowSums(P) = 1e-12 W = rowSums(P * D) Ws = W/sum_Pi H = log(sum_Pi) + beta * Ws @@ -141,4 +183,4 @@ return(matrix[double] P) P = P / sum(P) if(is_verbose) print("x2p finishing....") -} +} \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java index 41d2d4606fb..61e220e2eae 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java @@ -403,4 +403,108 @@ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr, rtplatform = platformOld; } } + + @Test + public void testTSNEEarlyStopping() throws IOException { + // Test setup guarantees early stopping. + runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-3, 1, "FALSE", ExecType.CP); + } + + @SuppressWarnings("unused") + private void runTSNEEarlyStoppingTest( + Integer reduced_dims, + Integer perplexity, + Double lr, + Double momentum, + Integer max_iter, + Double tol, + Integer seed, + String is_verbose, + ExecType instType) throws IOException { + + ExecMode platformOld = setExecMode(instType); + try + { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{ + "-nvargs", "X=" + input("X"), "Y=" + output("Y"), + "reduced_dims=" + reduced_dims, + "perplexity=" + perplexity, + "lr=" + lr, + "momentum=" + momentum, + "max_iter=" + max_iter, + "tol= " + tol, + "seed=" + seed, + "is_verbose=" + is_verbose}; + + // The Input values are calculated using the following dml script: + // X = rand(rows=50, cols=2, min=0, max=5, seed=1) + + // Input + double[][] X = { + {2.271495063217468, 3.917376227330574}, + {1.8027277767886734, 0.5602638182708702}, + {1.8307117742445955, 3.080459273928752}, + {3.960945849944864, 1.836254625202115}, + {3.150073237716238, 1.9323395742833234}, + {2.6433200695475314, 4.1796071244359805}, + {3.837027652831461, 1.8984862827162574}, + {3.3620223450187448, 4.221502623882378}, + {2.3282372847390254, 4.602696981602351}, + {1.063050408038052, 3.049136203059148}, + {4.945108528303021, 3.290728762588105}, + {0.03997874419356229, 4.78972783775991}, + {3.219940877253892, 0.4546090824785526}, + {3.661862179707895, 4.9115252981693205}, + {2.006763020664273, 1.6504573252270927}, + {4.802896313025078, 3.7058196696317185}, + {4.989560263975035, 3.3590878579410233}, + {0.2881957805129115, 2.7235626348446864}, + {4.205473623958116, 0.7513651648333092}, + {1.5030599075316982, 1.9059965151083047}, + {4.111690819873698, 4.38922550887249}, + {3.55235293843559, 2.7707785045249262}, + {3.5421273466628604, 3.218473690489352}, + {2.0021344008348603, 3.293397607562143}, + {0.6236309437993054, 4.690911049840824}, + {4.28743111141226, 3.058259024138692}, + {1.351324262063277, 1.4910437726755477}, + {2.328053099817537, 2.844624510685577}, + {2.058835681566319, 3.1365678249943336}, + {3.758610361307626, 1.0596733909061373}, + {4.4615463190110205, 4.67202160391804}, + {1.44939799230235, 0.3342638523743646}, + {4.299621130384286, 3.781441439604645}, + {4.671573038039089, 1.1565494768485123}, + {0.8624668449657552, 1.9085522899983942}, + {0.34305466410947616, 0.6344221672215061}, + {4.837399879571096, 4.391970748711334}, + {4.280838563730712, 3.3498259946465705}, + {0.9926830544799081, 4.198090879512748}, + {0.2809217637487471, 2.7963040315556564}, + {0.17872992178431912, 3.565772551292108}, + {4.148793911769612, 1.0757141044759506}, + {2.0111513617190186, 2.7646430913923767}, + {0.5114578168532041, 1.3708650661139115}, + {0.38545762498678526, 0.21277125305527278}, + {2.356200617781426, 2.20790000896965}, + {3.665608219962555, 3.399666975542729}, + {1.7618442622801385, 4.675570729512945}, + {4.987236193552888, 0.41700477957766546}, + {0.21496074278985922, 3.5781179414157616} + }; + + + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + } + finally { + rtplatform = platformOld; + } + + } } From f702dc70831105e2b819332057994527acf32796 Mon Sep 17 00:00:00 2001 From: ramesesz Date: Sun, 28 Jan 2024 15:32:22 +0100 Subject: [PATCH 04/10] Revert dist() changes --- scripts/builtin/dist.dml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/builtin/dist.dml b/scripts/builtin/dist.dml index f296fd717bc..26ded9a1976 100644 --- a/scripts/builtin/dist.dml +++ b/scripts/builtin/dist.dml @@ -32,8 +32,7 @@ # ----------------------------------------------------------------------------------------------- m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) { - n = nrow(X) - s = rowSums(X^2) - Y = sqrt(-2 * X %*% t(X) + s + t(s)) + G = X %*% t(X); + Y = sqrt(-2 * G + outer(diag(G), t(diag(G)), "+")); Y = replace(target = Y, pattern=NaN, replacement = 0); } From 91dcf78d114c35d2b85deb2b86fed00b7fe24cbb Mon Sep 17 00:00:00 2001 From: ramesesz Date: Sun, 28 Jan 2024 15:32:43 +0100 Subject: [PATCH 05/10] Fixed symbol error --- scripts/builtin/tSNE.dml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/builtin/tSNE.dml b/scripts/builtin/tSNE.dml index 0813d7508c9..c90c8802a5f 100644 --- a/scripts/builtin/tSNE.dml +++ b/scripts/builtin/tSNE.dml @@ -161,7 +161,7 @@ return(matrix[double] P) while (mean(abs(Hdiff)) > tol & itr < 50) { P = exp(-D * beta) P = P * ZERODIAG - sum_Pi = rowSums(P) = 1e-12 + sum_Pi = rowSums(P) + 1e-12 W = rowSums(P * D) Ws = W/sum_Pi H = log(sum_Pi) + beta * Ws @@ -183,4 +183,4 @@ return(matrix[double] P) P = P / sum(P) if(is_verbose) print("x2p finishing....") -} \ No newline at end of file +} From e80f209718cf931e65fadfc0d84045dc4062e61f Mon Sep 17 00:00:00 2001 From: ramesesz Date: Sun, 28 Jan 2024 17:44:30 +0100 Subject: [PATCH 06/10] Added print_iter variable --- scripts/builtin/tSNE.dml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/builtin/tSNE.dml b/scripts/builtin/tSNE.dml index c90c8802a5f..0fb802d019b 100644 --- a/scripts/builtin/tSNE.dml +++ b/scripts/builtin/tSNE.dml @@ -35,6 +35,8 @@ # seed The seed used for initial values. # If set to -1 random seeds are selected. # is_verbose Print debug information +# print_iter Intervals of printing out the L1 norm values. Parameter not relevant if +# is_verbose = FALSE. # ------------------------------------------------------------------------------------------- # # OUTPUT: @@ -43,7 +45,8 @@ # ------------------------------------------------------------------------------------------- m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity = 30, - Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Double tol = 1e-5, Integer seed = -1, Boolean is_verbose = FALSE) + Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Double tol = 1e-5, + Integer seed = -1, Boolean is_verbose = FALSE, Integer print_iter = 10) return(Matrix[Double] Y) { d = reduced_dims @@ -108,7 +111,7 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity dY = momentum*dY - lr*g norm = sum(dY^2) - if(is_verbose & itr %%10 ==0){ + if(is_verbose & itr %% print_iter == 0){ print("Iteration: " + itr) print("L1 Norm: " + norm) } From 024d823c88b47fbf0ce59f880172779d43ef5287 Mon Sep 17 00:00:00 2001 From: ramesesz Date: Sun, 28 Jan 2024 17:45:42 +0100 Subject: [PATCH 07/10] Removed use of distance_matrix within script --- scripts/builtin/tSNE.dml | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/scripts/builtin/tSNE.dml b/scripts/builtin/tSNE.dml index 0fb802d019b..299ecbd9b07 100644 --- a/scripts/builtin/tSNE.dml +++ b/scripts/builtin/tSNE.dml @@ -70,7 +70,7 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity itr = 1 # Start first iteration out of loop as benchmark for early stopping - D = distance_matrix(Y) + D = dist(Y) Z = 1/(D + 1) Z = Z * ZERODIAG Q = Z/sum(Z) @@ -101,7 +101,7 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity # End of first iteration while (itr <= max_iter & norm > norm_target) { - D = distance_matrix(Y) + D = dist(Y) Z = 1/(D + 1) Z = Z * ZERODIAG Q = Z/sum(Z) @@ -129,17 +129,6 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity } } -distance_matrix = function(matrix[double] X) - return (matrix[double] out) -{ - # TODO consolidate with dist() builtin, but with - # better way of obtaining the diag from - n = nrow(X) - s = rowSums(X * X) - out = - 2*X %*% t(X) + s + t(s) -} - - x2p = function(matrix[double] X, double perplexity, Boolean is_verbose = FALSE) return(matrix[double] P) { @@ -150,7 +139,7 @@ return(matrix[double] P) n = nrow(X) if(is_verbose) print(n) - D = distance_matrix(X) + D = dist(X) P = matrix(0, rows=n, cols=n) beta = matrix(1, rows=n, cols=1) From acf847ba560a662a5bb9e45da98c1a1bb1446fdd Mon Sep 17 00:00:00 2001 From: ramesesz Date: Sun, 28 Jan 2024 20:57:47 +0100 Subject: [PATCH 08/10] Changed test case --- .../builtin/part2/BuiltinTSNETest.java | 134 ++++++++++-------- 1 file changed, 78 insertions(+), 56 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java index 61e220e2eae..f7d8fd12652 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java @@ -25,6 +25,9 @@ import org.apache.sysds.test.TestConfiguration; import org.junit.Test; +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.io.UnsupportedEncodingException; import java.io.IOException; public class BuiltinTSNETest extends AutomatedTestBase @@ -404,10 +407,11 @@ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr, } } + @Test public void testTSNEEarlyStopping() throws IOException { // Test setup guarantees early stopping. - runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-3, 1, "FALSE", ExecType.CP); + runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-1, 1, "TRUE", ExecType.CP); } @SuppressWarnings("unused") @@ -444,64 +448,82 @@ private void runTSNEEarlyStoppingTest( // X = rand(rows=50, cols=2, min=0, max=5, seed=1) // Input - double[][] X = { - {2.271495063217468, 3.917376227330574}, - {1.8027277767886734, 0.5602638182708702}, - {1.8307117742445955, 3.080459273928752}, - {3.960945849944864, 1.836254625202115}, - {3.150073237716238, 1.9323395742833234}, - {2.6433200695475314, 4.1796071244359805}, - {3.837027652831461, 1.8984862827162574}, - {3.3620223450187448, 4.221502623882378}, - {2.3282372847390254, 4.602696981602351}, - {1.063050408038052, 3.049136203059148}, - {4.945108528303021, 3.290728762588105}, - {0.03997874419356229, 4.78972783775991}, - {3.219940877253892, 0.4546090824785526}, - {3.661862179707895, 4.9115252981693205}, - {2.006763020664273, 1.6504573252270927}, - {4.802896313025078, 3.7058196696317185}, - {4.989560263975035, 3.3590878579410233}, - {0.2881957805129115, 2.7235626348446864}, - {4.205473623958116, 0.7513651648333092}, - {1.5030599075316982, 1.9059965151083047}, - {4.111690819873698, 4.38922550887249}, - {3.55235293843559, 2.7707785045249262}, - {3.5421273466628604, 3.218473690489352}, - {2.0021344008348603, 3.293397607562143}, - {0.6236309437993054, 4.690911049840824}, - {4.28743111141226, 3.058259024138692}, - {1.351324262063277, 1.4910437726755477}, - {2.328053099817537, 2.844624510685577}, - {2.058835681566319, 3.1365678249943336}, - {3.758610361307626, 1.0596733909061373}, - {4.4615463190110205, 4.67202160391804}, - {1.44939799230235, 0.3342638523743646}, - {4.299621130384286, 3.781441439604645}, - {4.671573038039089, 1.1565494768485123}, - {0.8624668449657552, 1.9085522899983942}, - {0.34305466410947616, 0.6344221672215061}, - {4.837399879571096, 4.391970748711334}, - {4.280838563730712, 3.3498259946465705}, - {0.9926830544799081, 4.198090879512748}, - {0.2809217637487471, 2.7963040315556564}, - {0.17872992178431912, 3.565772551292108}, - {4.148793911769612, 1.0757141044759506}, - {2.0111513617190186, 2.7646430913923767}, - {0.5114578168532041, 1.3708650661139115}, - {0.38545762498678526, 0.21277125305527278}, - {2.356200617781426, 2.20790000896965}, - {3.665608219962555, 3.399666975542729}, - {1.7618442622801385, 4.675570729512945}, - {4.987236193552888, 0.41700477957766546}, - {0.21496074278985922, 3.5781179414157616} - }; - + double[][] X = { + {-0.45700987356506406, 2.834752454661148}, + {-1.3945444464226533, -3.8794723634582597}, + {-1.338576451510809, 1.160918547857504}, + {2.921891699889728, -1.32749074959577}, + {1.3001464754324763, -1.1353208514333533}, + {0.2866401390950628, 3.359214248871961}, + {2.6740553056629217, -1.2030274345674852}, + {1.7240446900374895, 3.4430052477647557}, + {-0.3435254305219493, 4.205393963204703}, + {-2.873899183923896, 1.098272406118296}, + {4.890217056606042, 1.5814575251762104}, + {-4.920042511612875, 4.579455675519821}, + {1.439881754507784, -4.090781835042895}, + {2.32372435941579, 4.823050596338641}, + {-0.9864739586714544, -1.6990853495458147}, + {4.605792626050157, 2.411639339263437}, + {4.979120527950069, 1.7181757158820465}, + {-4.423608438974177, 0.44712526968937283}, + {3.4109472479162317, -3.497269670333382}, + {-1.9938801849366037, -1.1880069697833906}, + {3.223381639747396, 3.7784510177449793}, + {2.10470587687118, 0.5415570090498525}, + {2.084254693325721, 1.4369473809787037}, + {-0.9957311983302795, 1.586795215124286}, + {-3.7527381124013894, 4.3818220996816475}, + {3.5748622228245193, 1.116518048277384}, + {-2.297351475873446, -2.0179124546489047}, + {-0.3438938003649259, 0.689249021371154}, + {-0.8823286368673617, 1.2731356499886672}, + {2.517220722615252, -2.8806532181877254}, + {3.923092638022041, 4.34404320783608}, + {-2.1012040153953, -4.33147229525127}, + {3.5992422607685715, 2.5628828792092904}, + {4.3431460760781775, -2.6869010463029754}, + {-3.27506631006849, -1.1828954200032116}, + {-4.3138906717810475, -3.7311556655569875}, + {4.674799759142193, 3.783941497422669}, + {3.561677127461424, 1.699651989293141}, + {-3.0146338910401838, 3.3961817590254952}, + {-4.438156472502506, 0.5926080631113129}, + {-4.6425401564313615, 2.131545102584216}, + {3.2975878235392244, -2.8485717910480988}, + {-0.9776972765619627, 0.5292861827847535}, + {-3.9770843662935915, -2.258269867772177}, + {-4.22908475002643, -4.574457493889454}, + {-0.28759876443714827, -0.5841999820607002}, + {2.33121643992511, 1.7993339510854582}, + {-1.476311475439723, 4.3511414590258894}, + {4.974472387105775, -4.165990440844669}, + {-4.570078514420281, 2.156235882831523} + }; - writeInputMatrixWithMTD("X", X, true); + // Capture console output + ByteArrayOutputStream outContent = new ByteArrayOutputStream(); + System.setOut(new PrintStream(outContent)); runTest(true, false, null, -1); - } + + // Parse and check L1 norm values + String[] lines = outContent.toString().split(System.lineSeparator()); + double prevL1Norm = Double.POSITIVE_INFINITY; + boolean decreasing = true; + for (String line : lines) { + if (line.startsWith("L1 Norm:")) { + double l1Norm = Double.parseDouble(line.substring(9).trim()); + if (l1Norm >= prevL1Norm) { + decreasing = false; + break; + } + prevL1Norm = l1Norm; + } + } + + assertTrue("L1 norm should decrease each time it is printed out", decreasing); + } finally { rtplatform = platformOld; } From 8f9bb2b4c86a05be4e1574b2f516cab7bda6b7e4 Mon Sep 17 00:00:00 2001 From: ramesesz Date: Sun, 28 Jan 2024 21:07:06 +0100 Subject: [PATCH 09/10] Added missing import --- .../sysds/test/functions/builtin/part2/BuiltinTSNETest.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java index f7d8fd12652..1c7260a6df2 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java @@ -24,10 +24,11 @@ import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.junit.Test; +import static org.junit.Assert.assertTrue; + import java.io.ByteArrayOutputStream; import java.io.PrintStream; -import java.io.UnsupportedEncodingException; import java.io.IOException; public class BuiltinTSNETest extends AutomatedTestBase From a8ad3c3ee99e720cc5f56b112d06e53e51ad3910 Mon Sep 17 00:00:00 2001 From: ramesesz Date: Wed, 31 Jan 2024 19:07:37 +0100 Subject: [PATCH 10/10] Fixed early-stopping test --- .../builtin/part2/BuiltinTSNETest.java | 43 +++++++++++-------- src/test/scripts/functions/builtin/tSNE.dml | 2 +- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java index 1c7260a6df2..44d06ffaf9a 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java @@ -26,9 +26,6 @@ import org.junit.Test; import static org.junit.Assert.assertTrue; - -import java.io.ByteArrayOutputStream; -import java.io.PrintStream; import java.io.IOException; public class BuiltinTSNETest extends AutomatedTestBase @@ -45,12 +42,12 @@ public void setUp() { @Test public void testTSNECP() throws IOException { runTSNETest(2, 30, 300., - 0.9, 1000, 42, "FALSE", ExecType.CP); + 0.9, 1000, 1e-5d, 42, "FALSE", 10, ExecType.CP); } @SuppressWarnings("unused") - private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr, - Double momentum, Integer max_iter, Integer seed, String is_verbose, ExecType instType) + private void runTSNETest(int reduced_dims, int perplexity, double lr, + double momentum, int max_iter, double tol, int seed, String is_verbose, Integer print_iter, ExecType instType) throws IOException { ExecMode platformOld = setExecMode(instType); @@ -68,8 +65,11 @@ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr, "lr=" + lr, "momentum=" + momentum, "max_iter=" + max_iter, + "tol=" + tol, "seed=" + seed, - "is_verbose=" + is_verbose}; + "is_verbose=" + is_verbose, + "print_iter=" + print_iter + }; // The Input values are calculated using the following R script: // TODO create via dml operations, avoid inlining data @@ -412,7 +412,7 @@ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr, @Test public void testTSNEEarlyStopping() throws IOException { // Test setup guarantees early stopping. - runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-1, 1, "TRUE", ExecType.CP); + runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-1, 1, "TRUE", 10, ExecType.CP); } @SuppressWarnings("unused") @@ -425,6 +425,7 @@ private void runTSNEEarlyStoppingTest( Double tol, Integer seed, String is_verbose, + Integer print_iter, ExecType instType) throws IOException { ExecMode platformOld = setExecMode(instType); @@ -441,9 +442,11 @@ private void runTSNEEarlyStoppingTest( "lr=" + lr, "momentum=" + momentum, "max_iter=" + max_iter, - "tol= " + tol, + "tol=" + tol, "seed=" + seed, - "is_verbose=" + is_verbose}; + "is_verbose=" + is_verbose, + "print_iter=" + print_iter + }; // The Input values are calculated using the following dml script: // X = rand(rows=50, cols=2, min=0, max=5, seed=1) @@ -502,22 +505,28 @@ private void runTSNEEarlyStoppingTest( {-4.570078514420281, 2.156235882831523} }; + writeInputMatrixWithMTD("X", X, true); + // Capture console output - ByteArrayOutputStream outContent = new ByteArrayOutputStream(); - System.setOut(new PrintStream(outContent)); + setOutputBuffering(true); + String out = runTest(true, false, null, -1).toString(); - runTest(true, false, null, -1); - // Parse and check L1 norm values - String[] lines = outContent.toString().split(System.lineSeparator()); + String[] lines = out.split(System.lineSeparator()); double prevL1Norm = Double.POSITIVE_INFINITY; boolean decreasing = true; + int notDecreasingCount = 0; // Counter to track consecutive non-decreasing values for (String line : lines) { if (line.startsWith("L1 Norm:")) { double l1Norm = Double.parseDouble(line.substring(9).trim()); if (l1Norm >= prevL1Norm) { - decreasing = false; - break; + notDecreasingCount++; + if (notDecreasingCount >= 3) { + decreasing = false; + break; // Exit the loop once we've seen 3 consecutive non-decreasing values + } + } else { + notDecreasingCount = 0; // Reset the counter if the current value is decreasing } prevL1Norm = l1Norm; } diff --git a/src/test/scripts/functions/builtin/tSNE.dml b/src/test/scripts/functions/builtin/tSNE.dml index 8310f75a39c..88e7c039102 100644 --- a/src/test/scripts/functions/builtin/tSNE.dml +++ b/src/test/scripts/functions/builtin/tSNE.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- X = read($X); -Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $seed, $is_verbose) +Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $tol, $seed, $is_verbose, $print_iter) write(Y, $Y)