@@ -58,7 +58,6 @@ Stmt AttrStmtNode::make(ObjectRef node,
5858TVM_REGISTER_GLOBAL (" tir.AttrStmt" )
5959.set_body_typed(AttrStmtNode::make);
6060
61-
6261Stmt AssertStmtNode::make (PrimExpr condition, PrimExpr message, Stmt body) {
6362 CHECK (condition.defined ());
6463 CHECK (message.dtype () == DataType::Int (32 ) ||
@@ -74,8 +73,14 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
7473}
7574
7675TVM_REGISTER_GLOBAL (" tir.AssertStmt" )
77- .set_body_typed(AssertStmtNode::make);
78-
76+ .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
77+ if (const auto * str = message.as <StringObj>()) {
78+ auto msg = StringImmNode::make (str->data );
79+ return AssertStmtNode::make (condition, msg, body);
80+ } else {
81+ return AssertStmtNode::make (condition, Downcast<PrimExpr>(message), body);
82+ }
83+ });
7984
8085Stmt ProducerConsumerNode::make (FunctionRef func, bool is_producer, Stmt body) {
8186 CHECK (body.defined ());
@@ -92,11 +97,11 @@ TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
9297
9398
9499Stmt ForNode::make (Var loop_var,
95- PrimExpr min,
96- PrimExpr extent,
97- ForType for_type,
98- DeviceAPI device_api,
99- Stmt body) {
100+ PrimExpr min,
101+ PrimExpr extent,
102+ ForType for_type,
103+ DeviceAPI device_api,
104+ Stmt body) {
100105 CHECK (min.defined ());
101106 CHECK (extent.defined ());
102107 CHECK (min.dtype ().is_scalar ());
@@ -119,11 +124,11 @@ TVM_REGISTER_GLOBAL("tir.For")
119124 Var loop_var, PrimExpr min, PrimExpr extent,
120125 int for_type, int device_api, Stmt body) {
121126 return ForNode::make (loop_var,
122- min,
123- extent,
124- static_cast <ForType>(for_type),
125- static_cast <DeviceAPI>(device_api),
126- body);
127+ min,
128+ extent,
129+ static_cast <ForType>(for_type),
130+ static_cast <DeviceAPI>(device_api),
131+ body);
127132});
128133
129134
@@ -176,12 +181,12 @@ TVM_REGISTER_GLOBAL("tir.Provide")
176181
177182
178183Stmt AllocateNode::make (Var buffer_var,
179- DataType dtype,
180- Array<PrimExpr> extents,
181- PrimExpr condition,
182- Stmt body,
183- PrimExpr new_expr,
184- std::string free_function) {
184+ DataType dtype,
185+ Array<PrimExpr> extents,
186+ PrimExpr condition,
187+ Stmt body,
188+ PrimExpr new_expr,
189+ std::string free_function) {
185190 for (size_t i = 0 ; i < extents.size (); ++i) {
186191 CHECK (extents[i].defined ());
187192 CHECK (extents[i].dtype ().is_scalar ());
0 commit comments