From d79d8acb9d22ababc79124ef5c041444a23d9269 Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Mon, 26 Feb 2024 15:29:03 +0100 Subject: [PATCH 1/3] Interface should be public for external usage --- .../src/main/java/org/tensorflow/framework/metrics/Metric.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index c8c1df607c2..c2982e9b0b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -22,7 +22,7 @@ import org.tensorflow.types.family.TNumber; /** Interface for metrics */ -interface Metric { +public interface Metric { /** * Creates a List of Operations to update the metric state based on input values. From 17951df8ef650b0727c548c835491777b2841776 Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Tue, 27 Feb 2024 11:42:00 +0100 Subject: [PATCH 2/3] Fix https://github.com/tensorflow/java/issues/523 --- .../framework/losses/impl/LossesHelper.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 3e635b0d957..451706af1d1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -282,11 +282,12 @@ private static Operand reduceWeightedLoss( if (reduction == Reduction.NONE) { loss = weightedLoss; } else { - loss = - tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) { - loss = safeMean(tf, loss, weightedLoss.shape().size()); + loss = safeMean(tf, weightedLoss); } + else + loss = tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); + } return loss; } @@ -302,9 +303,9 @@ private static Operand reduceWeightedLoss( * zero, then zero is returned. */ public static Operand safeMean( - Ops tf, Operand losses, long numElements) { - Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); - return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type())); + Ops tf, Operand losses) { + Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses),ReduceSum.keepDims(Boolean.FALSE)); + return tf.math.divNoNan(totalLoss, cast(tf,tf.shape.size(tf.shape(losses)),losses.type())); } /** From 2b985ce36fc69946df412b97d5fd21739cebd485 Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Wed, 28 Feb 2024 15:04:31 +0100 Subject: [PATCH 3/3] Fix google format --- .../framework/losses/impl/LossesHelper.java | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 451706af1d1..6c40149f3de 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -51,7 +51,7 @@ public class LossesHelper { * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction - * . + * . * @param the data type for the labels, predictions and result * @return LossTuple of prediction, label,sampleWeight will * be null. Each of them possibly has the last dimension squeezed, sampleWeight @@ -77,7 +77,7 @@ public static LossTuple squeezeOrExpandDimensions( * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction - * . + * . * @param sampleWeights Optional sample weight(s) Operand whose dimensions match * * prediction. @@ -179,7 +179,7 @@ private static Operand maybeExpandWeights( * * @param tf the TensorFlowOps * @param labels Label values, a Tensor whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. @@ -194,7 +194,7 @@ public static LossTuple removeSqueezableDimensions( * * @param tf the TensorFlowOps * @param labels Label values, a Operand whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). * @param the data type for the labels, predictions and result @@ -222,11 +222,13 @@ public static LossTuple removeSqueezableDimensions( // Use dynamic rank. // TODO: hold for lazy select feature, - // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + // Operand rankDiff = tf.math.sub(tf.rank(predictions), + // tf.rank(labels)); if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { /* - * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze - * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), + * TODO, if we ever get a select that does lazy evaluation, but for now do the + * tf.squeeze predictions = tf.select( + * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); * */ predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L))); @@ -284,10 +286,10 @@ private static Operand reduceWeightedLoss( } else { if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) { loss = safeMean(tf, weightedLoss); - } - else - loss = tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); - + } else + loss = + tf.reduceSum( + weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); } return loss; } @@ -302,10 +304,10 @@ private static Operand reduceWeightedLoss( * @return A scalar representing the mean of losses. If numElements is * zero, then zero is returned. */ - public static Operand safeMean( - Ops tf, Operand losses) { - Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses),ReduceSum.keepDims(Boolean.FALSE)); - return tf.math.divNoNan(totalLoss, cast(tf,tf.shape.size(tf.shape(losses)),losses.type())); + public static Operand safeMean(Ops tf, Operand losses) { + Operand totalLoss = + tf.reduceSum(losses, allAxes(tf, losses), ReduceSum.keepDims(Boolean.FALSE)); + return tf.math.divNoNan(totalLoss, cast(tf, tf.shape.size(tf.shape(losses)), losses.type())); } /** @@ -349,7 +351,8 @@ public static Operand rangeCheck( tf.math.logicalAnd( tf.reduceAll(tf.math.greaterEqual(values, minValue), allDims), tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims)); - // Graph and Eager mode need to be handled differently, control dependencies are not allowed in + // Graph and Eager mode need to be handled differently, control dependencies are + // not allowed in // Eager mode if (tf.scope().env().isGraph()) { AssertThat assertThat = @@ -399,7 +402,8 @@ public static Operand valueCheck( } else return values; } else { // use dynamic shape Operand cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0)); - // Graph and Eager mode need to be handled differently, control dependencies are not allowed + // Graph and Eager mode need to be handled differently, control dependencies are + // not allowed // in Eager mode if (tf.scope().env().isGraph()) { AssertThat assertThat =