diff --git a/scripts/builtin/tSNE.dml b/scripts/builtin/tSNE.dml
index 131ab1013cb..299ecbd9b07 100644
--- a/scripts/builtin/tSNE.dml
+++ b/scripts/builtin/tSNE.dml
@@ -31,9 +31,13 @@
 # lr             Learning rate
 # momentum       Momentum Parameter
 # max_iter       Number of iterations
+# tol            Tolerance for early stopping in gradient descent: iteration stops once
+#                sum(dY^2) is no larger than tol times its value in the first iteration.
 # seed           The seed used for initial values.
 #                If set to -1 random seeds are selected.
 # is_verbose     Print debug information
+# print_iter     Print the update norm sum(dY^2) (labelled "L1 Norm") every print_iter
+#                iterations. Parameter not relevant if is_verbose = FALSE.
 # -------------------------------------------------------------------------------------------
 #
 # OUTPUT:
@@ -42,7 +46,8 @@
 # -------------------------------------------------------------------------------------------
 
 m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity = 30,
-  Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Integer seed = -1, Boolean is_verbose = FALSE)
+  Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Double tol = 1e-5,
+  Integer seed = -1, Boolean is_verbose = FALSE, Integer print_iter = 10)
   return(Matrix[Double] Y)
 {
   d = reduced_dims
@@ -63,8 +68,16 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity
   if(is_verbose)
     print("starting loop....")
 
-  for (itr in 1:max_iter) {
-    D = distance_matrix(Y)
+  # Early stopping: iterate until sum(dY^2) is no larger than tol times its
+  # value in the first iteration, or until max_iter iterations were executed.
+  # norm_target is fixed in the first iteration; the start values below only
+  # ensure that the loop is entered whenever max_iter >= 1.
+  norm = 1.0
+  norm_target = 0.0
+  itr = 1
+  while (itr <= max_iter & norm > norm_target) {
+    # dist() returns unsquared Euclidean distances, the Student-t kernel needs squared ones
+    D = dist(Y)^2
     Z = 1/(D + 1)
     Z = Z * ZERODIAG
     Q = Z/sum(Z)
@@ -72,6 +85,22 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity
     sumW = rowSums(W)
     g = Y * sumW - W %*% Y
     dY = momentum*dY - lr*g
+
+    # Sum of squared entries of the update; the log label "L1 Norm" is kept as is
+    norm = sum(dY^2)
+    if (itr == 1) {
+      # The first update is the reference for the relative stopping criterion
+      norm_target = norm * tol
+      if(is_verbose){
+        print("L1 Norm initial : " + norm)
+        print("L1 Norm target : " + norm_target)
+      }
+    }
+    else if(is_verbose & itr %% print_iter == 0){
+      print("Iteration: " + itr)
+      print("L1 Norm: " + norm)
+    }
+
     Y = Y + dY
     Y = Y - colMeans(Y)
 
@@ -81,20 +110,10 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity
     if (itr == 100) {
       P = P/4
     }
+    itr = itr + 1
   }
 }
 
-distance_matrix = function(matrix[double] X)
-  return (matrix[double] out)
-{
-  # TODO consolidate with dist() builtin, but with
-  # better way of obtaining the diag from
-  n = nrow(X)
-  s = rowSums(X * X)
-  out = - 2*X %*% t(X) + s + t(s)
-}
-
-
 x2p = function(matrix[double] X, double perplexity, Boolean is_verbose = FALSE)
 return(matrix[double] P)
 {
@@ -105,7 +124,8 @@ return(matrix[double] P)
   n = nrow(X)
   if(is_verbose)
     print(n)
-  D = distance_matrix(X)
+  # Squared Euclidean distances for the kernel exp(-D * beta); dist() itself is unsquared
+  D = dist(X)^2
 
   P = matrix(0, rows=n, cols=n)
   beta = matrix(1, rows=n, cols=1)
@@ -119,7 +139,8 @@ return(matrix[double] P)
   while (mean(abs(Hdiff)) > tol & itr < 50) {
     P = exp(-D * beta)
     P = P * ZERODIAG
-    sum_Pi = rowSums(P)
+    # Small offset keeps log(sum_Pi) and W/sum_Pi finite when a row of P is all zero
+    sum_Pi = rowSums(P) + 1e-12
     W = rowSums(P * D)
     Ws = W/sum_Pi
     H = log(sum_Pi) + beta * Ws
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
index 41d2d4606fb..44d06ffaf9a 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
@@ -24,6 +24,7 @@
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.junit.Test;
+import static org.junit.Assert.assertTrue;
 
 import java.io.IOException;
 
@@ -41,12 +42,12 @@ public void setUp() {
@Test public void testTSNECP() throws IOException { runTSNETest(2, 30, 300., - 0.9, 1000, 42, "FALSE", ExecType.CP); + 0.9, 1000, 1e-5d, 42, "FALSE", 10, ExecType.CP); } @SuppressWarnings("unused") - private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr, - Double momentum, Integer max_iter, Integer seed, String is_verbose, ExecType instType) + private void runTSNETest(int reduced_dims, int perplexity, double lr, + double momentum, int max_iter, double tol, int seed, String is_verbose, Integer print_iter, ExecType instType) throws IOException { ExecMode platformOld = setExecMode(instType); @@ -64,8 +65,11 @@ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr, "lr=" + lr, "momentum=" + momentum, "max_iter=" + max_iter, + "tol=" + tol, "seed=" + seed, - "is_verbose=" + is_verbose}; + "is_verbose=" + is_verbose, + "print_iter=" + print_iter + }; // The Input values are calculated using the following R script: // TODO create via dml operations, avoid inlining data @@ -403,4 +407,136 @@ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr, rtplatform = platformOld; } } + + + @Test + public void testTSNEEarlyStopping() throws IOException { + // Test setup guarantees early stopping. 
+ runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-1, 1, "TRUE", 10, ExecType.CP); + } + + @SuppressWarnings("unused") + private void runTSNEEarlyStoppingTest( + Integer reduced_dims, + Integer perplexity, + Double lr, + Double momentum, + Integer max_iter, + Double tol, + Integer seed, + String is_verbose, + Integer print_iter, + ExecType instType) throws IOException { + + ExecMode platformOld = setExecMode(instType); + try + { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{ + "-nvargs", "X=" + input("X"), "Y=" + output("Y"), + "reduced_dims=" + reduced_dims, + "perplexity=" + perplexity, + "lr=" + lr, + "momentum=" + momentum, + "max_iter=" + max_iter, + "tol=" + tol, + "seed=" + seed, + "is_verbose=" + is_verbose, + "print_iter=" + print_iter + }; + + // The Input values are calculated using the following dml script: + // X = rand(rows=50, cols=2, min=0, max=5, seed=1) + + // Input + double[][] X = { + {-0.45700987356506406, 2.834752454661148}, + {-1.3945444464226533, -3.8794723634582597}, + {-1.338576451510809, 1.160918547857504}, + {2.921891699889728, -1.32749074959577}, + {1.3001464754324763, -1.1353208514333533}, + {0.2866401390950628, 3.359214248871961}, + {2.6740553056629217, -1.2030274345674852}, + {1.7240446900374895, 3.4430052477647557}, + {-0.3435254305219493, 4.205393963204703}, + {-2.873899183923896, 1.098272406118296}, + {4.890217056606042, 1.5814575251762104}, + {-4.920042511612875, 4.579455675519821}, + {1.439881754507784, -4.090781835042895}, + {2.32372435941579, 4.823050596338641}, + {-0.9864739586714544, -1.6990853495458147}, + {4.605792626050157, 2.411639339263437}, + {4.979120527950069, 1.7181757158820465}, + {-4.423608438974177, 0.44712526968937283}, + {3.4109472479162317, -3.497269670333382}, + {-1.9938801849366037, -1.1880069697833906}, + {3.223381639747396, 3.7784510177449793}, + {2.10470587687118, 
0.5415570090498525}, + {2.084254693325721, 1.4369473809787037}, + {-0.9957311983302795, 1.586795215124286}, + {-3.7527381124013894, 4.3818220996816475}, + {3.5748622228245193, 1.116518048277384}, + {-2.297351475873446, -2.0179124546489047}, + {-0.3438938003649259, 0.689249021371154}, + {-0.8823286368673617, 1.2731356499886672}, + {2.517220722615252, -2.8806532181877254}, + {3.923092638022041, 4.34404320783608}, + {-2.1012040153953, -4.33147229525127}, + {3.5992422607685715, 2.5628828792092904}, + {4.3431460760781775, -2.6869010463029754}, + {-3.27506631006849, -1.1828954200032116}, + {-4.3138906717810475, -3.7311556655569875}, + {4.674799759142193, 3.783941497422669}, + {3.561677127461424, 1.699651989293141}, + {-3.0146338910401838, 3.3961817590254952}, + {-4.438156472502506, 0.5926080631113129}, + {-4.6425401564313615, 2.131545102584216}, + {3.2975878235392244, -2.8485717910480988}, + {-0.9776972765619627, 0.5292861827847535}, + {-3.9770843662935915, -2.258269867772177}, + {-4.22908475002643, -4.574457493889454}, + {-0.28759876443714827, -0.5841999820607002}, + {2.33121643992511, 1.7993339510854582}, + {-1.476311475439723, 4.3511414590258894}, + {4.974472387105775, -4.165990440844669}, + {-4.570078514420281, 2.156235882831523} + }; + + writeInputMatrixWithMTD("X", X, true); + + // Capture console output + setOutputBuffering(true); + String out = runTest(true, false, null, -1).toString(); + + // Parse and check L1 norm values + String[] lines = out.split(System.lineSeparator()); + double prevL1Norm = Double.POSITIVE_INFINITY; + boolean decreasing = true; + int notDecreasingCount = 0; // Counter to track consecutive non-decreasing values + for (String line : lines) { + if (line.startsWith("L1 Norm:")) { + double l1Norm = Double.parseDouble(line.substring(9).trim()); + if (l1Norm >= prevL1Norm) { + notDecreasingCount++; + if (notDecreasingCount >= 3) { + decreasing = false; + break; // Exit the loop once we've seen 3 consecutive non-decreasing values + } + } else { + 
notDecreasingCount = 0; // Reset the counter if the current value is decreasing + } + prevL1Norm = l1Norm; + } + } + + assertTrue("L1 norm should decrease each time it is printed out", decreasing); + } + finally { + rtplatform = platformOld; + } + + } } diff --git a/src/test/scripts/functions/builtin/tSNE.dml b/src/test/scripts/functions/builtin/tSNE.dml index 8310f75a39c..88e7c039102 100644 --- a/src/test/scripts/functions/builtin/tSNE.dml +++ b/src/test/scripts/functions/builtin/tSNE.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- X = read($X); -Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $seed, $is_verbose) +Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $tol, $seed, $is_verbose, $print_iter) write(Y, $Y)