Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
53d5ada
support select_last_index for argmin/max
AndrewZhaoLuo Aug 23, 2021
95a6517
reverse conditions which were made by accident
AndrewZhaoLuo Aug 23, 2021
5e1f06a
forward args in reduce.py
AndrewZhaoLuo Aug 23, 2021
962e38a
make proper nodes for reduction ops
AndrewZhaoLuo Aug 23, 2021
f92089b
remove complicated nested lambdas
AndrewZhaoLuo Aug 23, 2021
9edc8e6
fix lambda capture for conversion
AndrewZhaoLuo Aug 23, 2021
9e4e69a
forward more arguments
AndrewZhaoLuo Aug 23, 2021
5cf4772
forward more args
AndrewZhaoLuo Aug 23, 2021
4f5a662
enable onnx tests
AndrewZhaoLuo Aug 23, 2021
75cb608
wrapping casts to remove ambiguity
Aug 24, 2021
7a60353
revert extraneous changes
AndrewZhaoLuo Aug 24, 2021
6a9e82f
correct incorrect attrs being used for ops
AndrewZhaoLuo Aug 24, 2021
0fb5db5
change attributes
AndrewZhaoLuo Aug 24, 2021
55a412d
remove old impl
Aug 24, 2021
93173fc
register new attribute node
AndrewZhaoLuo Aug 24, 2021
47b7eed
clean up test
AndrewZhaoLuo Aug 24, 2021
e62513b
reformat
AndrewZhaoLuo Aug 24, 2021
e9ea784
reformat
AndrewZhaoLuo Aug 24, 2021
587e94a
coolio
AndrewZhaoLuo Aug 24, 2021
d048e25
stable comparison
AndrewZhaoLuo Aug 24, 2021
71ab1f3
casts to avoid ambiguity
AndrewZhaoLuo Aug 24, 2021
aecf630
casting more
AndrewZhaoLuo Aug 24, 2021
423d092
correct arg passing
AndrewZhaoLuo Aug 26, 2021
2faf06d
support select_last_index for argmin/max
AndrewZhaoLuo Aug 23, 2021
edbc0f1
reverse conditions which made on accident
AndrewZhaoLuo Aug 23, 2021
ba7f57c
forward args in reduce.py
AndrewZhaoLuo Aug 23, 2021
dbf6dc1
make proper nodes for reduction ops
AndrewZhaoLuo Aug 23, 2021
fa4dd43
remove complicated nested lambdas
AndrewZhaoLuo Aug 23, 2021
78cc734
fix lambda capture for conversion
AndrewZhaoLuo Aug 23, 2021
0979f4d
forward more arguments
AndrewZhaoLuo Aug 23, 2021
647413e
forward more args
AndrewZhaoLuo Aug 23, 2021
f694e58
enable onnx tests
AndrewZhaoLuo Aug 23, 2021
576c56b
wrapping casts to remove ambiguity
Aug 24, 2021
67b5762
revert changes extraneous
AndrewZhaoLuo Aug 24, 2021
6d59d1c
correct incorrect attrs being used for ops
AndrewZhaoLuo Aug 24, 2021
d7a595f
change attributes
AndrewZhaoLuo Aug 24, 2021
6b645de
remove old impl
Aug 24, 2021
0faf5b6
register new attribute node
AndrewZhaoLuo Aug 24, 2021
96d85c2
clean up test
AndrewZhaoLuo Aug 24, 2021
8a6a4bc
reformat
AndrewZhaoLuo Aug 24, 2021
29a2660
reformat
AndrewZhaoLuo Aug 24, 2021
3a2a38d
coolio
AndrewZhaoLuo Aug 24, 2021
296ac2e
stable comparison
AndrewZhaoLuo Aug 24, 2021
12f7213
casts to avoid ambiguity
AndrewZhaoLuo Aug 24, 2021
20cdd36
casting more
AndrewZhaoLuo Aug 24, 2021
49b6322
correct arg passing
AndrewZhaoLuo Aug 26, 2021
fcc420e
Merge branch 'aluo/onnx/argmin_and_argmax' of github.com:AndrewZhaoLu…
AndrewZhaoLuo Aug 27, 2021
8f37f89
fix broken input
AndrewZhaoLuo Aug 27, 2021
2db29ca
OneElementReduceAttrs --> ArgReduceAttrs
Aug 30, 2021
4055190
reduce boilerplate
Aug 30, 2021
1f56147
change names
Aug 30, 2021
d4cbfcc
remove log statement
Aug 30, 2021
c5f308b
jostle ci
AndrewZhaoLuo Aug 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
support select_last_index for argmin/max
  • Loading branch information
AndrewZhaoLuo committed Aug 26, 2021
commit 53d5adaaa38b126bb5eb9a866fe20cd45a556c3f
36 changes: 36 additions & 0 deletions include/tvm/relay/attrs/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,42 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
}
};

/*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */
struct OneElementReduceAttrs : public tvm::AttrsNode<OneElementReduceAttrs> {
  Array<Integer> axis;
  bool keepdims;
  bool select_last_index;
  bool exclude;

  // Fix: the CRTP base and TVM_DECLARE_ATTRS must name THIS struct, not ReduceAttrs.
  // Reusing ReduceAttrs here would mis-cast the node and re-register the
  // "relay.attrs.ReduceAttrs" type key, colliding with the existing ReduceAttrs.
  TVM_DECLARE_ATTRS(OneElementReduceAttrs, "relay.attrs.OneElementReduceAttrs") {
    TVM_ATTR_FIELD(axis)
        .set_default(NullValue<Array<Integer>>())
        .describe(R"code(The axis or axes along which to perform the reduction.

      The default, `axis=()`, will compute over all elements into a
      scalar array with shape `(1,)`.

      If `axis` is int, a reduction is performed on a particular axis.

      If `axis` is a tuple of ints, a reduction is performed on all the axes
      specified in the tuple.

      If `exclude` is true, reduction will be performed on the axes that are
      NOT in axis instead.)code");

    TVM_ATTR_FIELD(keepdims).set_default(false).describe(
        "If this is set to `True`, the reduced axes are left "
        "in the result as dimension with size one.");
    TVM_ATTR_FIELD(select_last_index)
        .set_default(false)
        .describe(
            "Whether to select the last index if the target element appears multiple times, else "
            "select the first index which the target element appears");
    TVM_ATTR_FIELD(exclude).set_default(false).describe(
        "Whether to perform reduction on axis that are NOT in axis instead.");
  }
};

struct VarianceAttrs : public tvm::AttrsNode<VarianceAttrs> {
Array<Integer> axis;
bool keepdims;
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
* \brief VisitExpr is finalized to preserve call expansion of dataflow regions
*/
void VisitExpr(const Expr& expr) final;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitExpr_(const CallNode* op) override;
virtual void VisitExpr_(const TupleNode* op) override;
virtual void VisitExpr_(const TupleGetItemNode* op) override;

protected:
/*!
Expand Down
86 changes: 55 additions & 31 deletions include/tvm/topi/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,40 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
}

/*!
 * \brief Build a commutative reducer over (index, value) pairs for single-element
 *        reductions such as argmin/argmax.
 * \param comparison_op Comparison applied to the two candidate values; when it
 *        returns true the left-hand candidate (index and value) is kept.
 * \param initial_value_generator Produces the identity value for a given dtype
 *        (e.g. max_value for argmin) so any real element wins against it.
 * \param name Name given to the generated reducer.
 * \return The combined FCommReduce reducer.
 */
inline FCommReduce MakeSinglePassReducer(
    std::function<PrimExpr(Var, Var)> comparison_op,
    std::function<PrimExpr(const DataType&)> initial_value_generator, String name) {
  // Fix: capture by value, not [&]. The lambdas outlive this function (they are
  // stored inside the returned reducer), so reference captures of the by-value
  // std::function parameters would dangle once this frame is gone.
  auto fcombine = [comparison_op](Array<Var> lhs, Array<Var> rhs) {
    Array<PrimExpr> result;
    result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[0], rhs[0]));  // idx
    result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[1], rhs[1]));  // val
    return result;
  };
  auto fidentity = [initial_value_generator](std::vector<DataType> types) {
    Array<PrimExpr> result;
    result.push_back(tvm::tir::make_const(types[0], -1));  // idx: -1 marks "no element seen yet"
    result.push_back(initial_value_generator(types[1]));   // val
    return result;
  };
  return MakeCommReducer(fcombine, fidentity, name);
}

/*!
 * \brief Construct the commutative reducer used by argmin.
 * \param select_last_index When true, ties on equal values keep the left-hand
 *        candidate (`<=`); otherwise the strict `<` keeps the right-hand one.
 *        NOTE(review): which of first/last index this yields depends on the
 *        traversal order inside CommReduceIdx — confirm against tests.
 * \return An FCommReduce suitable for CommReduceIdx.
 */
inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
  // Comparator deciding whether the left-hand candidate's value wins.
  std::function<PrimExpr(Var, Var)> wins;
  if (select_last_index) {
    wins = [](Var a, Var b) { return a <= b; };
  } else {
    wins = [](Var a, Var b) { return a < b; };
  }

  // Identity element: the largest representable value, so any real element wins.
  std::function<PrimExpr(const DataType&)> identity = [](const DataType& dtype) {
    return tvm::max_value(dtype);
  };

  return MakeSinglePassReducer(wins, identity, "argmin");
}

/*!
 * \brief Creates an operation that finds the indices of the minimum
 * values over a given axis.
 *
 * \param data The input tensor
 * \param axis The axis or axes along which the reduction is performed
 * \param keepdims If this is set to true, the axes which are reduced are
 * left in the result as dimensions with size one. This enables the result
 * to broadcast correctly against the input array.
 * \param atleast1d Whether the output need to be atleast1d.
 * \param select_last_index Whether to select the last index if the minimum element
 * appears multiple times, else select the first index.
 *
 * \return A Tensor whose op member is the argmin operation
 */
inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                     bool atleast1d = false, bool select_last_index = false) {
  // Delegate the tie-breaking behavior to the reducer; index bookkeeping is
  // handled by CommReduceIdx.
  auto reducer = MakeArgminReducer(select_last_index);
  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}

/*!
 * \brief Construct the commutative reducer used by argmax.
 *        (Rebuilt: the source span interleaved the old and new implementations
 *        from the rendered diff, leaving two bodies and a duplicated return.)
 * \param select_last_index When true, ties on equal values keep the left-hand
 *        candidate (`>=`); otherwise the strict `>` keeps the right-hand one.
 *        NOTE(review): which of first/last index this yields depends on the
 *        traversal order inside CommReduceIdx — confirm against tests.
 * \return An FCommReduce suitable for CommReduceIdx.
 */
inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
  std::function<PrimExpr(Var, Var)> comparison_op;
  if (select_last_index) {
    comparison_op = [](Var lhs, Var rhs) { return lhs >= rhs; };
  } else {
    comparison_op = [](Var lhs, Var rhs) { return lhs > rhs; };
  }

  // Identity element: the smallest representable value, so any real element wins.
  std::function<PrimExpr(const DataType&)> initial_value_generator = [](const DataType& data_type) {
    return tvm::min_value(data_type);
  };

  return MakeSinglePassReducer(comparison_op, initial_value_generator, "argmax");
}

/*!
 * \brief Creates an operation that finds the indices of the maximum
 * values over a given axis.
 *
 * \param data The input tensor
 * \param axis The axis or axes along which the reduction is performed
 * \param keepdims If this is set to true, the axes which are reduced are
 * left in the result as dimensions with size one. This enables the result
 * to broadcast correctly against the input array.
 * \param atleast1d Whether the output need to be atleast1d.
 * \param select_last_index Whether to select the last index if the maximum element
 * appears multiple times, else select the first index.
 * \return A Tensor whose op member is the argmax operation
 */
inline Tensor argmax(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                     bool atleast1d = false, bool select_last_index = false) {
  auto reducer = MakeArgmaxReducer(select_last_index);
  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}

Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,17 @@ def _impl_v1(cls, inputs, attr, params):
return _op.cast(AttrCvt("argmin")(inputs, attr), "int64")


@classmethod
def _impl_v13(cls, inputs, attr, params):
    """Convert opset-13 ArgMin, forwarding ONNX's select_last_index attribute.

    Fixes the draft version, which raised NotImplementedError whenever
    select_last_index was present (making the lines that read it unreachable),
    left the extracted value unused, and carried dead commented-out code.
    Forwarding assumes the relay argmin op accepts select_last_index — the
    attribute this PR adds.
    """
    axis = attr.get("axis", 0)
    keepdims = attr.get("keepdims", True)
    select_last_index = attr.get("select_last_index", False)
    attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index}
    # Cast to int64: ONNX specifies int64 output for ArgMin.
    return _op.cast(AttrCvt("argmin")(inputs, attr), "int64")

class Softmax(OnnxOpConverter):
"""Operator converter for Softmax."""

Expand Down
35 changes: 28 additions & 7 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,34 @@ Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inp
auto axes = param->axis;
if (param->exclude) {
axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
if (axes.size() == 0) {
return {topi::identity(inputs[0])};
}
}

if (axes.size() == 0) {
return {topi::identity(inputs[0])};
}

return {f(inputs[0], axes, param->keepdims, false)};
}

/*!
 * \brief Compute helper for single-element reductions (argmin/argmax) that
 *        forward the select_last_index attribute to the topi implementation.
 *        (Rebuilt: review-UI text was embedded mid-signature in the source span.)
 * \param attrs Operator attributes; must be OneElementReduceAttrs.
 * \param inputs Input tensors; inputs[0] is the tensor being reduced.
 * \param out_type Output type (unused here).
 * \param f The topi reduction function to invoke.
 * \return A one-element array holding the computed tensor.
 */
template <typename F>
Array<te::Tensor> OneElementReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                          const Type& out_type, F f) {
  const OneElementReduceAttrs* param = attrs.as<OneElementReduceAttrs>();
  ICHECK(param != nullptr);
  // A rank-0 input has no axis to reduce over; pass it through unchanged.
  if (inputs[0]->shape.size() == 0) {
    return {topi::identity(inputs[0])};
  }
  auto axes = param->axis;
  if (param->exclude) {
    axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
  }
  // No axes left (possibly after exclusion) means no reduction to perform.
  if (axes.size() == 0) {
    return {topi::identity(inputs[0])};
  }
  return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)};
}

/*!
* \brief ReduceShapeImpl get the outshape for the reduction operator
* \param in_shape Shape of input data.
Expand Down Expand Up @@ -333,31 +354,31 @@ Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, Str

// FTVMCompute for argmax; routes through OneElementReduceCompute so that
// select_last_index is forwarded to topi::argmax.
// (Rebuilt: the span interleaved the old ReduceCompute call with the new one.)
Array<te::Tensor> ArgMaxCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  return OneElementReduceCompute(attrs, inputs, out_type, topi::argmax);
}

// Register argmax with the new OneElementReduceAttrs so select_last_index is
// available. (Rebuilt: the span carried both the old ReduceAttrs line and the
// new one from the rendered diff.)
RELAY_REGISTER_REDUCE_OP("argmax")
    .describe(R"code(Creates an operation that finds the indices of the maximum
values over a given axis.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<OneElementReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("ArgReduce", ArgReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

// FTVMCompute for argmin; routes through OneElementReduceCompute so that
// select_last_index is forwarded to topi::argmin.
// (Rebuilt: the span interleaved the old ReduceCompute call with the new one.)
Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  return OneElementReduceCompute(attrs, inputs, out_type, topi::argmin);
}

RELAY_REGISTER_REDUCE_OP("argmin")
.describe(R"code(Creates an operation that finds the indices of the minimum
values over a given axis.

)code" TVM_ADD_FILELINE)
.set_attrs_type<ReduceAttrs>()
.set_attrs_type<OneElementReduceAttrs>()
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
Expand Down