Skip to content

Commit 8246972

Browse files
committed
Remove PrimExpr from String (apache#5311)
1 parent 9938a66 commit 8246972

6 files changed

Lines changed: 28 additions & 32 deletions

File tree

include/tvm/ir/expr.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,6 @@ class PrimExpr : public BaseExpr {
108108
*/
109109
TVM_DLL PrimExpr(float value); // NOLINT(*)
110110

111-
/*!
112-
* \brief construct from runtime String.
113-
* \param value The value to be constructed.
114-
*/
115-
TVM_DLL PrimExpr(runtime::String value); // NOLINT(*)
116-
117111
/*! \return the data type of this expression. */
118112
DataType dtype() const {
119113
return static_cast<const PrimExprNode*>(get())->dtype;

src/ir/expr.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ PrimExpr::PrimExpr(int32_t value)
4040
PrimExpr::PrimExpr(float value)
4141
: PrimExpr(FloatImm(DataType::Float(32), value)) {}
4242

43-
PrimExpr::PrimExpr(runtime::String value)
44-
: PrimExpr(tir::StringImmNode::make(value)) {}
45-
4643
PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
4744
using runtime::ObjectTypeChecker;
4845
if (auto* ptr = ref.as<tir::IterVarNode>()) {

src/target/target.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ Target CreateTarget(const std::string& target_name,
137137
} else if (target_name == "hybrid") {
138138
t->device_type = kDLCPU;
139139
} else if (target_name == "hexagon") {
140-
t->keys_array.push_back(runtime::String("hexagon"));
140+
t->keys_array.push_back("hexagon");
141141
t->device_type = kDLHexagon;
142142
} else {
143143
LOG(ERROR) << "Unknown target name " << target_name;

src/tir/ir/stmt.cc

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ Stmt AttrStmtNode::make(ObjectRef node,
5858
TVM_REGISTER_GLOBAL("tir.AttrStmt")
5959
.set_body_typed(AttrStmtNode::make);
6060

61-
6261
Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
6362
CHECK(condition.defined());
6463
CHECK(message.dtype() == DataType::Int(32) ||
@@ -74,8 +73,14 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
7473
}
7574

7675
TVM_REGISTER_GLOBAL("tir.AssertStmt")
77-
.set_body_typed(AssertStmtNode::make);
78-
76+
.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
77+
if (const auto* str = message.as<StringObj>()) {
78+
auto msg = StringImmNode::make(str->data);
79+
return AssertStmtNode::make(condition, msg, body);
80+
} else {
81+
return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
82+
}
83+
});
7984

8085
Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
8186
CHECK(body.defined());
@@ -92,11 +97,11 @@ TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
9297

9398

9499
Stmt ForNode::make(Var loop_var,
95-
PrimExpr min,
96-
PrimExpr extent,
97-
ForType for_type,
98-
DeviceAPI device_api,
99-
Stmt body) {
100+
PrimExpr min,
101+
PrimExpr extent,
102+
ForType for_type,
103+
DeviceAPI device_api,
104+
Stmt body) {
100105
CHECK(min.defined());
101106
CHECK(extent.defined());
102107
CHECK(min.dtype().is_scalar());
@@ -119,11 +124,11 @@ TVM_REGISTER_GLOBAL("tir.For")
119124
Var loop_var, PrimExpr min, PrimExpr extent,
120125
int for_type, int device_api, Stmt body) {
121126
return ForNode::make(loop_var,
122-
min,
123-
extent,
124-
static_cast<ForType>(for_type),
125-
static_cast<DeviceAPI>(device_api),
126-
body);
127+
min,
128+
extent,
129+
static_cast<ForType>(for_type),
130+
static_cast<DeviceAPI>(device_api),
131+
body);
127132
});
128133

129134

@@ -176,12 +181,12 @@ TVM_REGISTER_GLOBAL("tir.Provide")
176181

177182

178183
Stmt AllocateNode::make(Var buffer_var,
179-
DataType dtype,
180-
Array<PrimExpr> extents,
181-
PrimExpr condition,
182-
Stmt body,
183-
PrimExpr new_expr,
184-
std::string free_function) {
184+
DataType dtype,
185+
Array<PrimExpr> extents,
186+
PrimExpr condition,
187+
Stmt body,
188+
PrimExpr new_expr,
189+
std::string free_function) {
185190
for (size_t i = 0; i < extents.size(); ++i) {
186191
CHECK(extents[i].defined());
187192
CHECK(extents[i].dtype().is_scalar());

topi/include/topi/contrib/cublas.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
5353
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
5454
[&](Array<Buffer> ins, Array<Buffer> outs) {
5555
return call_packed({
56-
runtime::String("tvm.contrib.cublas.matmul"),
56+
StringImmNode::make("tvm.contrib.cublas.matmul"),
5757
pack_buffer(ins[0]),
5858
pack_buffer(ins[1]),
5959
pack_buffer(outs[0]),
@@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
8585
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
8686
[&](Array<Buffer> ins, Array<Buffer> outs) {
8787
return call_packed({
88-
runtime::String("tvm.contrib.cublas.batch_matmul"),
88+
StringImmNode::make("tvm.contrib.cublas.batch_matmul"),
8989
pack_buffer(ins[0]),
9090
pack_buffer(ins[1]),
9191
pack_buffer(outs[0]),

topi/include/topi/contrib/rocblas.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
5252
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
5353
[&](Array<Buffer> ins, Array<Buffer> outs) {
5454
return call_packed({
55-
runtime::String("tvm.contrib.rocblas.matmul"),
55+
StringImmNode::make("tvm.contrib.rocblas.matmul"),
5656
pack_buffer(ins[0]),
5757
pack_buffer(ins[1]),
5858
pack_buffer(outs[0]),

0 commit comments

Comments
 (0)