Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,22 @@ struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {

/*! \brief Attributes used in the gather operator */
struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
  Integer axis;
  // When true, the lowered compute rewrites negative indices as index + dim
  // (Python-style indexing); defaults to off because it adds a select per read.
  Bool support_negative_indices = Bool(false);

  TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") {
    TVM_ATTR_FIELD(axis)
        .set_default(NullValue<Integer>())
        .describe("The axis over which to select values.");
    TVM_ATTR_FIELD(support_negative_indices)
        .set_default(Bool(false))
        .describe("If negative indices are supported.");
  }
};

/*! \brief Attributes used in the gather_nd operator */
struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
  Integer batch_dims;
  Optional<Integer> index_rank;
  // When true, the lowered compute rewrites negative indices as index + dim
  // (Python-style indexing); defaults to off because it adds a select per read.
  Bool support_negative_indices = Bool(false);

  TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
    TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
    TVM_ATTR_FIELD(index_rank)
        .set_default(NullValue<Integer>())
        .describe(
            "The size of an indexing tuple, which is a fixed value. Only needed when the number of "
            "indexing tuples is dynamic.");
    TVM_ATTR_FIELD(support_negative_indices)
        .set_default(Bool(false))
        .describe("If negative indices are supported.");
  }
};

Expand Down
6 changes: 4 additions & 2 deletions include/tvm/te/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,17 @@ class Tensor : public DataProducer {
/*!
* \brief Take elements from the tensor
* \param indices the indices.
* \param support_negative_indices whether we support negative indexing which is slightly slower.
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
TVM_DLL PrimExpr operator()(Array<PrimExpr> indices, bool support_negative_indices = false) const;
/*!
* \brief Take elements from the tensor
* \param indices the indices.
* \param support_negative_indices whether we support negative indexing which is slightly slower.
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr operator()(Array<Var> indices) const;
TVM_DLL PrimExpr operator()(Array<Var> indices, bool support_negative_indices = false) const;
/*!
* \brief data structure to represent a slice that fixes first k coordinates.
* This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
Expand Down
28 changes: 18 additions & 10 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1219,11 +1219,13 @@ inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
* \param indices The indices of values to gather.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \param support_negative_indices If negative indices are supported
*
* \return A Tensor whose op member is the gather operation
*/
inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
std::string name = "T_gather", std::string tag = kInjective) {
bool support_negative_indices = false, std::string name = "T_gather",
std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
Expand All @@ -1242,6 +1244,8 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
out_shape.push_back(indices->shape[i]);
}

PrimExpr axis_size = data->shape[axis];
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove it


return compute(
out_shape,
[&](const Array<Var>& out_index) {
Expand All @@ -1252,12 +1256,13 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
Array<PrimExpr> real_indices;
for (size_t i = 0; i < ndim_i; ++i) {
if (i == static_cast<size_t>(axis)) {
real_indices.push_back(indices(indices_position));
PrimExpr index = indices(indices_position);
real_indices.push_back(index);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this diff

} else {
real_indices.push_back(indices_position[i]);
}
}
return data(real_indices);
return data(real_indices, support_negative_indices);
},
name, tag);
}
Expand All @@ -1270,11 +1275,13 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
* \param batch_dims The number of batch dimensions.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \param support_negative_indices If negative indices are supported
*
* \return A Tensor whose op member is the gather_nd operation
*/
inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
std::string name = "T_gather_nd", std::string tag = kInjective) {
bool support_negative_indices = false, std::string name = "T_gather_nd",
std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions";
Expand Down Expand Up @@ -1302,19 +1309,20 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim
}
for (size_t i = 0; i < indices_dim0; ++i) {
indices_position.Set(0, make_const(DataType::Int(32), i));
if (indices->dtype.is_int()) {
real_indices.push_back(indices(indices_position));
} else {
real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
PrimExpr index = indices(indices_position);

if (!indices->dtype.is_int()) {
index = tvm::cast(tvm::DataType::Int(32), index);
}
real_indices.push_back(index);
}
if (real_indices.size() == ndim_d) {
return data(real_indices);
return data(real_indices, support_negative_indices);
}
for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
real_indices.push_back(out_index[i]);
}
return data(real_indices);
return data(real_indices, support_negative_indices);
},
name, tag);
}
Expand Down
26 changes: 20 additions & 6 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,8 +1544,7 @@ def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
indices = inputs[1]
axis = attr.get("axis", 0)
indices = normalize_gather_indices(data, indices, axis)
return _op.gather(data, axis, indices)
return _op.gather(data, axis, indices, support_negative_indices=True)


class GatherND(OnnxOpConverter):
Expand All @@ -1557,7 +1556,13 @@ def _impl_common(cls, data, indices, batch_dims=0):
indices_shape = infer_shape(indices)
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
index_rank = indices_shape[-1]
return _op.gather_nd(data, indices, batch_dims, index_rank)
return _op.gather_nd(
data,
indices,
batch_dims=batch_dims,
support_negative_indices=True,
index_rank=index_rank,
)

@classmethod
def _impl_v1(cls, inputs, attr, params):
Expand Down Expand Up @@ -3550,12 +3555,19 @@ def _impl_v13(cls, inputs, attr, params):
dtype=input_tensor.type_annotation.dtype,
)

loss = -relay.gather(input_tensor, axis=1, indices=relay.expand_dims(target_tensor, 1))
loss = -relay.gather(
input_tensor,
axis=1,
indices=relay.expand_dims(target_tensor, 1),
support_negative_indices=True,
)
loss = relay.squeeze(loss, axis=[1])

expanded_target_tensor = relay.expand_dims(target_tensor, 0)
expanded_target_tensor = relay.nn.batch_flatten(expanded_target_tensor)
flattened_weights = relay.gather_nd(weight_tensor, expanded_target_tensor)
flattened_weights = relay.gather_nd(
weight_tensor, expanded_target_tensor, support_negative_indices=True
)
select_weights = relay.reshape_like(flattened_weights, loss)
loss *= select_weights

Expand All @@ -3565,7 +3577,9 @@ def _impl_v13(cls, inputs, attr, params):
target_tensor, relay.const(ignore_index, dtype=target_tensor.type_annotation.dtype)
)
mask_tensor = relay.const(1, dtype="int8") - relay.cast(mask_tensor, "int8")
loss *= relay.cast_like(mask_tensor, loss)
loss = relay.where(
mask_tensor, loss, relay.const(0, infer_type(loss).checked_type.dtype)
)

# This is not explained super clearly in the onnx spec, but masked values don't
# contribute toward the final value in reduction
Expand Down
16 changes: 12 additions & 4 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,7 @@ def reverse_reshape(data, newshape):
return _make.contrib_reverse_reshape(data, list(newshape))


def gather(data, axis, indices):
def gather(data, axis, indices, support_negative_indices=False):
"""Gather values along given axis from given indices.

E.g. for a 3D tensor, output is computed as:
Expand All @@ -1071,6 +1071,10 @@ def gather(data, axis, indices):
indices: relay.Expr
The indices of values to gather.

support_negative_indices: bool
If True, support indices being negative. This is slower than supporting only
positive indices.

Examples
--------
.. code-block:: python
Expand All @@ -1080,10 +1084,10 @@ def gather(data, axis, indices):
indices = [[0, 0], [1, 0]]
relay.gather(data, axis, indices) = [[1, 1], [4, 3]]
"""
return _make.gather(data, axis, indices)
return _make.gather(data, axis, indices, support_negative_indices)


def gather_nd(data, indices, batch_dims=0, index_rank=None):
def gather_nd(data, indices, batch_dims=0, support_negative_indices=False, index_rank=None):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.

Expand All @@ -1102,6 +1106,10 @@ def gather_nd(data, indices, batch_dims=0, index_rank=None):
The size of an indexing tuple, which is a fixed value and the same as indices.shape[0]
Only needed when other dimensions of indices are dynamic.

support_negative_indices: bool
If True, support indices being negative. This is slower than supporting only
positive indices.

Returns
-------
ret : relay.Expr
Expand All @@ -1123,7 +1131,7 @@ def gather_nd(data, indices, batch_dims=0, index_rank=None):
indices = [[1, 0]]
relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
"""
return _make.gather_nd(data, indices, batch_dims, index_rank)
return _make.gather_nd(data, indices, batch_dims, support_negative_indices, index_rank)


def sequence_mask(data, valid_length, mask_value=0, axis=0):
Expand Down
10 changes: 7 additions & 3 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3274,12 +3274,13 @@ bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// Lower relay.gather to topi::gather, forwarding the support_negative_indices
// attribute so the generated compute can wrap negative indices.
Array<te::Tensor> GatherCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  const auto* param = attrs.as<GatherAttrs>();
  ICHECK(param);  // null-check the attrs, consistent with GatherNDCompute
  return {topi::gather(inputs[0], param->axis, inputs[1], param->support_negative_indices)};
}

// Build a relay.gather call node; axis and support_negative_indices are
// carried on GatherAttrs rather than as call arguments.
Expr MakeGather(Expr data, Integer axis, Expr indices, Bool support_negative_indices) {
  auto attrs = make_object<GatherAttrs>();
  attrs->axis = std::move(axis);
  attrs->support_negative_indices = std::move(support_negative_indices);
  static const Op& op = Op::Get("gather");
  return Call(op, {data, indices}, Attrs(attrs), {});
}
Expand Down Expand Up @@ -3353,15 +3354,18 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
const Type& out_type) {
const auto* param = attrs.as<GatherNDAttrs>();
ICHECK(param);
return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
return {
topi::gather_nd(inputs[0], inputs[1], param->batch_dims, param->support_negative_indices)};
}

// Build a relay.gather_nd call node. batch_dims, support_negative_indices and
// index_rank are stored on GatherNDAttrs; index_rank is only needed when the
// indices shape is dynamic.
Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0,
                  Bool support_negative_indices = Bool(0),
                  Optional<Integer> index_rank = NullValue<Integer>()) {
  static const Op& op = Op::Get("gather_nd");
  auto attrs = make_object<GatherNDAttrs>();
  attrs->batch_dims = batch_dims;
  attrs->index_rank = index_rank;
  attrs->support_negative_indices = support_negative_indices;
  return Call(op, {data, indices}, Attrs(attrs));
}

Expand Down
23 changes: 17 additions & 6 deletions src/te/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,26 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name)
Var var(std::string name_hint, DataType t) { return Var(name_hint, t); }

// Tensor
// Read an element of the tensor at `indices` (Var overload); converts the
// indices to PrimExpr and delegates to the PrimExpr overload.
PrimExpr Tensor::operator()(Array<Var> indices, bool support_negative_indices) const {
  Array<PrimExpr> arr(indices.begin(), indices.end());
  return operator()(arr, support_negative_indices);
}

// Read an element of the tensor at `indices`. When support_negative_indices is
// true, each index is rewritten to (i < 0 ? i + shape[dim] : i) so Python-style
// negative indexing works, at the cost of one extra select per axis.
PrimExpr Tensor::operator()(Array<PrimExpr> indices, bool support_negative_indices) const {
  Array<PrimExpr> shape = (*this)->shape;

  if (shape.size() != 0) {
    ICHECK_EQ(shape.size(), indices.size())
        << "Tensor dimension mismatch in read "
        << "ndim = " << ndim() << ", indices.size=" << indices.size();
  }

  if (support_negative_indices) {
    for (size_t i = 0; i < shape.size(); i++) {
      // Wrap negative indices; make_const matches the index dtype to avoid
      // mixed-dtype comparisons.
      PrimExpr new_index = if_then_else(indices[i] < make_const(indices[i]->dtype, 0),
                                        indices[i] + shape[i], indices[i]);
      indices.Set(i, new_index);
    }
  }

  return ProducerLoad((*this), indices);
Expand Down
5 changes: 0 additions & 5 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4787,11 +4787,6 @@ def verify_eyelike(indata):
"test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded",
"test_nllloss_NCd1d2d3d4d5_mean_weight_expanded",
"test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded",
# These nllloss tests are flaky and sometimes gives NaNs
# Investigate it here: https://github.com/apache/tvm/issues/8918
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii",
# Investigate it here: https://github.com/apache/tvm/issues/8964
"test_nllloss_NCd1d2d3_sum_weight_high_ii",
"test_qlinearmatmul_2D",
"test_qlinearmatmul_3D",
"test_range_float_type_positive_delta_expanded",
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from tvm.relay.loops import while_loop
from tvm.relay.testing import run_infer_type as infer_type

from utils.assert_diagnostic import DiagnosticTesting
from utils import ref_funcs
from utils.assert_diagnostic import DiagnosticTesting


def int32(val):
Expand Down Expand Up @@ -2046,7 +2046,7 @@ def test_gather_nd():
def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0):
x = relay.var("x", relay.TensorType(data_shape, "float32"))
y = relay.var("y", relay.TensorType(indices_shape, "int32"))
z = relay.gather_nd(x, y, batch_dims, indices_shape[0])
z = relay.gather_nd(x, y, batch_dims=batch_dims, index_rank=indices_shape[0])

mod = tvm.IRModule()
mod["main"] = relay.Function([x, y], z)
Expand Down
Loading