-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[Onnx] Fix NLL Loss tests #8971
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
593b188
7c19abb
160cbb4
1da783d
bb8594b
a611b64
e81dfc9
260ba96
10c2913
e4c2e91
fef2cd6
7949a39
02f1870
a184f7c
56650da
73d3d55
d7e24f8
5ff388b
dbbd42e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1219,11 +1219,13 @@ inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim, | |
| * \param indices The indices of values to gather. | ||
| * \param name The name of the operation. | ||
| * \param tag The tag to mark the operation. | ||
| * \param support_negative_indices If negative indices are supported | ||
| * | ||
| * \return A Tensor whose op member is the gather operation | ||
| */ | ||
| inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, | ||
| std::string name = "T_gather", std::string tag = kInjective) { | ||
| bool support_negative_indices = false, std::string name = "T_gather", | ||
| std::string tag = kInjective) { | ||
| size_t ndim_d = data->shape.size(); | ||
| size_t ndim_i = indices->shape.size(); | ||
| ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar."; | ||
|
|
@@ -1242,6 +1244,8 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, | |
| out_shape.push_back(indices->shape[i]); | ||
| } | ||
|
|
||
| PrimExpr axis_size = data->shape[axis]; | ||
|
|
||
| return compute( | ||
| out_shape, | ||
| [&](const Array<Var>& out_index) { | ||
|
|
@@ -1252,12 +1256,13 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, | |
| Array<PrimExpr> real_indices; | ||
| for (size_t i = 0; i < ndim_i; ++i) { | ||
| if (i == static_cast<size_t>(axis)) { | ||
| real_indices.push_back(indices(indices_position)); | ||
| PrimExpr index = indices(indices_position); | ||
| real_indices.push_back(index); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. remove this diff |
||
| } else { | ||
| real_indices.push_back(indices_position[i]); | ||
| } | ||
| } | ||
| return data(real_indices); | ||
| return data(real_indices, support_negative_indices); | ||
| }, | ||
| name, tag); | ||
| } | ||
|
|
@@ -1270,11 +1275,13 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, | |
| * \param batch_dims The number of batch dimensions. | ||
| * \param name The name of the operation. | ||
| * \param tag The tag to mark the operation. | ||
| * \param support_negative_indices If negative indices are supported | ||
| * | ||
| * \return A Tensor whose op member is the gather_nd operation | ||
| */ | ||
| inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0, | ||
| std::string name = "T_gather_nd", std::string tag = kInjective) { | ||
| bool support_negative_indices = false, std::string name = "T_gather_nd", | ||
| std::string tag = kInjective) { | ||
| size_t ndim_d = data->shape.size(); | ||
| size_t ndim_i = indices->shape.size(); | ||
| ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions"; | ||
|
|
@@ -1302,19 +1309,20 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim | |
| } | ||
| for (size_t i = 0; i < indices_dim0; ++i) { | ||
| indices_position.Set(0, make_const(DataType::Int(32), i)); | ||
| if (indices->dtype.is_int()) { | ||
| real_indices.push_back(indices(indices_position)); | ||
| } else { | ||
| real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position))); | ||
| PrimExpr index = indices(indices_position); | ||
|
|
||
| if (!indices->dtype.is_int()) { | ||
| index = tvm::cast(tvm::DataType::Int(32), index); | ||
| } | ||
| real_indices.push_back(index); | ||
| } | ||
| if (real_indices.size() == ndim_d) { | ||
| return data(real_indices); | ||
| return data(real_indices, support_negative_indices); | ||
| } | ||
| for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { | ||
| real_indices.push_back(out_index[i]); | ||
| } | ||
| return data(real_indices); | ||
| return data(real_indices, support_negative_indices); | ||
| }, | ||
| name, tag); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -39,15 +39,26 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name) | |||||||||||||||
| Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } | ||||||||||||||||
|
|
||||||||||||||||
| // Tensor | ||||||||||||||||
| PrimExpr Tensor::operator()(Array<Var> indices) const { | ||||||||||||||||
| PrimExpr Tensor::operator()(Array<Var> indices, bool support_negative_indices) const { | ||||||||||||||||
| Array<PrimExpr> arr(indices.begin(), indices.end()); | ||||||||||||||||
| return operator()(arr); | ||||||||||||||||
| return operator()(arr, support_negative_indices); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| PrimExpr Tensor::operator()(Array<PrimExpr> indices) const { | ||||||||||||||||
| if (ndim() != 0) { | ||||||||||||||||
| ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read " | ||||||||||||||||
| << "ndim = " << ndim() << ", indices.size=" << indices.size(); | ||||||||||||||||
| PrimExpr Tensor::operator()(Array<PrimExpr> indices, bool support_negative_indices) const { | ||||||||||||||||
| Array<PrimExpr> shape = (*this)->shape; | ||||||||||||||||
|
|
||||||||||||||||
| if (shape.size() != 0) { | ||||||||||||||||
| ICHECK_EQ(shape.size(), indices.size()) | ||||||||||||||||
| << "Tensor dimension mismatch in read " | ||||||||||||||||
| << "ndim = " << ndim() << ", indices.size=" << indices.size(); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| if (support_negative_indices) { | ||||||||||||||||
| for (size_t i = 0; i < shape.size(); i++) { | ||||||||||||||||
| PrimExpr new_index = if_then_else(indices[i] < make_const(indices[i]->dtype, 0), | ||||||||||||||||
| indices[i] + shape[i], indices[i]); | ||||||||||||||||
| indices.Set(i, new_index); | ||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Negative indices handling is also done in tvm/python/tvm/relay/op/transform.py Lines 926 to 927 in d9fe672
tvm/include/tvm/topi/detail/strided_slice.h Lines 45 to 48 in cbe3dca
tvm/include/tvm/topi/detail/strided_slice.h Line 105 in cbe3dca
I believe there are other cases like this spread across the code base. Maybe we should revisit all index-taking ops and centralize negative indices handling. Generally I think people prefer not making a change down the stack.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm this is a good point. I think pushing down the stack is the right choice personally since I expect the most basic indexing op to work with negative indices. Since all of the other operations will use these basic indexing ops, we should therefore get these things for free. In our case, we add a flag to a basic indexing operation which turns on this feature. Otherwise we'll get a lot of copies of the same code everywhere.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah I agree that implementation-wise, this is more convenient. Since this is a fundamental data structure change, how about we open a separate PR for negative indexing support to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that's a fair point. I'll refactor this to use |
||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| return ProducerLoad((*this), indices); | ||||||||||||||||
|
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove it