Skip to content

Error in stats.kendalls_tau as Keras Metric #1417

@gonzalesMK

Description

@gonzalesMK

I am trying to use TensorFlow probability as a metric in Keras. With respect to kendalls_tau, I get the following error:

import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np 

def kendalls_tau(y_true, y_pred):
    a = tf.reshape(y_true, shape=(-1,))
    b = tf.reshape(y_pred, shape=(-1,))
    kendall = tfp.stats.kendalls_tau(a, b)
    return kendall

inputs = tf.keras.layers.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer="Adam", loss="mse", metrics=kendalls_tau)

x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2) )
model.fit(x, y)

TypeError: in user code:

    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:864 train_function  *
        return step_function(self, iterator)
    <ipython-input-4-14a2210abe73>:5 kendalls_tau  *
        kendall = tfp.stats.kendalls_tau(a, b)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:196 kendalls_tau  **
        lexa = lexicographical_indirect_sort(y_true, y_pred)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:154 lexicographical_indirect_sort
        left, _, lexicographic = tf.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:614 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2531 while_loop_v2
        return while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2729 while_loop
        return while_v2.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:214 while_loop
        body_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:200 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:152 body
        tf.cond(not_equal, secondary_sort, lambda: lexicographic))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
        return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:546 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
        true_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:148 secondary_sort
        tensorshape_util.set_shape(x, [n])
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py:328 set_shape
        tensor.set_shape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:758 set_shape
        shape = tensor_shape.TensorShape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 __init__
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 <listcomp>
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:206 __init__
        six.raise_from(
    <string>:3 raise_from
        

    TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'kendalls_tau/lexicographical_indirect_sort/size0/strided_slice_1:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/data/Mestrado/Ensaios/drbc_tf.py in 
     257 x = np.random.random((2, 3))
     258 y = np.random.randint(0, 2, (2, 2) )
---> 259 model.fit(x, y)
     260 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1190                 _r=1):
   1191               callbacks.on_train_batch_begin(step)
-> 1192               tmp_logs = self.train_function(iterator)
   1193               if data_handler.should_sync:
   1194                 context.async_wait()

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    927       # This is the first call of __call__, so we have to initialize.
    928       initializers = []
--> 929       self._initialize(args, kwds, add_initializers_to=initializers)
    930     finally:
    931       # At this point we know that the initialization is complete (or less

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    757     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    758     self._concrete_stateful_fn = (
--> 759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    760             *args, **kwds))
    761 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3057       args, kwargs = None, None
   3058     with self._lock:
-> 3059       graph_function, _ = self._maybe_define_function(args, kwargs)
   3060     return graph_function
   3061 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3454 
   3455           self._function_cache.missed.add(call_context_key)
-> 3456           graph_function = self._create_graph_function(args, kwargs)
   3457           self._function_cache.primary[cache_key] = graph_function
   3458 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3289     arg_names = base_arg_names + missing_arg_names
   3290     graph_function = ConcreteFunction(
-> 3291         func_graph_module.func_graph_from_py_func(
   3292             self._name,
   3293             self._python_function,

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    992           except Exception as e:  # pylint:disable=broad-except
    993             if hasattr(e, "ag_error_metadata"):
--> 994               raise e.ag_error_metadata.to_exception(e)
    995             else:
    996               raise

TypeError: in user code:

    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:864 train_function  *
        return step_function(self, iterator)
    <ipython-input-4-14a2210abe73>:5 kendalls_tau  *
        kendall = tfp.stats.kendalls_tau(a, b)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:196 kendalls_tau  **
        lexa = lexicographical_indirect_sort(y_true, y_pred)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:154 lexicographical_indirect_sort
        left, _, lexicographic = tf.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:614 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2531 while_loop_v2
        return while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2729 while_loop
        return while_v2.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:214 while_loop
        body_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:200 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:152 body
        tf.cond(not_equal, secondary_sort, lambda: lexicographic))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
        return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:546 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
        true_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:148 secondary_sort
        tensorshape_util.set_shape(x, [n])
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py:328 set_shape
        tensor.set_shape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:758 set_shape
        shape = tensor_shape.TensorShape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 __init__
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 <listcomp>
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:206 __init__
        six.raise_from(
    <string>:3 raise_from
        

    TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'kendalls_tau/lexicographical_indirect_sort/size0/strided_slice_1:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'

How can I fix this?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions