Skip to content

Commit aa82b14

Browse files
Siyuan FengMasterJH5574junrushaotqchenYuchenJin
authored andcommitted
[TVMScript] IRModule TVMScript Parser.
This PR adds the TVMScript parser/ir_builder support based on the blockbuilder. This commit contains the non-relax portions from apache#13932. Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu> Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Tianqi Chen <tianqi.tchen@gmail.com> Co-authored-by: Yuchen Jin <yuchenj@cs.washington.edu> Co-authored-by: Steven S. Lyubomirsky <slyubomirsky@gmail.com> Co-authored-by: Yong Wu <yongcale@gmail.com>
1 parent d085dee commit aa82b14

17 files changed

Lines changed: 246 additions & 34 deletions

File tree

include/tvm/script/ir_builder/ir/frame.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,17 @@ namespace ir {
3838
*/
3939
class IRModuleFrameNode : public IRBuilderFrameNode {
4040
public:
41-
Array<GlobalVar> global_vars;
42-
Array<BaseFunc> functions;
41+
/*! \brief A map from string names to global variables that ensures global uniqueness. */
42+
Map<String, GlobalVar> global_var_map;
43+
/*!
44+
* \brief A map from GlobalVar to all global functions.
45+
* \note Only defined functions are in the map, while declared functions are not included.
46+
*/
47+
Map<GlobalVar, BaseFunc> functions;
4348

4449
void VisitAttrs(tvm::AttrVisitor* v) {
4550
IRBuilderFrameNode::VisitAttrs(v);
46-
v->Visit("global_vars", &global_vars);
51+
v->Visit("global_vars", &global_var_map);
4752
v->Visit("functions", &functions);
4853
}
4954

include/tvm/script/ir_builder/ir/ir.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,23 @@ namespace ir {
3737
*/
3838
TVM_DLL IRModuleFrame IRModule();
3939

40+
/*!
41+
* \brief Declare a Function without given the specific function implementation.
42+
* \note It is usually used in cross-function call. And we can specify the function by `DefFunction`
43+
* \param func_name The function unique name.
44+
* \param func_signature A Function w/o body, which used to specify the function signature
45+
* (i.e. func params and func return type/shape).
46+
* \return The corresponding GlobalVar.
47+
*/
48+
TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature);
49+
50+
/*!
51+
* \brief Define the function which is declared before.
52+
* \param func_name The function unique name.
53+
* \param func The given function implementation
54+
*/
55+
TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func);
56+
4057
} // namespace ir
4158
} // namespace ir_builder
4259
} // namespace script

python/tvm/script/ir_builder/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame":
6464
_ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member
6565
return self
6666

67-
def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
68-
_ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member
67+
def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument
68+
if exc_type is None and exc_value is None:
69+
# Do not execute `FrameExit` if the with scope exits because of exceptions
70+
_ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member
6971

7072
def add_callback(self, callback: Callable[[], None]) -> None:
7173
"""Add a callback method invoked when exiting the with-scope.

python/tvm/script/ir_builder/ir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
# under the License.
1717
"""Package tvm.script.ir_builder.ir"""
1818
from .frame import IRModuleFrame
19-
from .ir import ir_module
19+
from .ir import decl_function, def_function, ir_module

python/tvm/script/ir_builder/ir/ir.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,54 @@
1616
# under the License.
1717
"""Package tvm.script.ir_builder.ir.ir"""
1818

19+
from tvm.ir import BaseFunc, GlobalVar
20+
1921
from . import _ffi_api
2022
from .frame import IRModuleFrame
2123

2224

2325
def ir_module() -> IRModuleFrame:
26+
"""Start a ir_module frame.
27+
Returns
28+
-------
29+
frame: IRModuleFrame
30+
The constructed frame.
31+
"""
2432
return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member
33+
34+
35+
def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
36+
"""Declare a Function without given the specific function implementation.
37+
Parameters
38+
----------
39+
func_name : str
40+
The function unique name.
41+
42+
func_signature: Optional[BaseFunc]
43+
A Function w/o body, which used to specify the function signature
44+
(i.e. func params and func return type/shape).
45+
46+
Note
47+
----
48+
It is usually used in cross-function call. And we can specify the function by `DefFunction`
49+
Returns
50+
-------
51+
gv : GlobalVar
52+
The corresponding GlobalVar.
53+
"""
54+
55+
return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member
56+
func_name, func_signature
57+
)
58+
59+
60+
def def_function(func_name: str, func: BaseFunc) -> None:
61+
"""Define the function which is declared before.
62+
Parameters
63+
----------
64+
func_name : str
65+
The function unique name.
66+
func: BaseFunc
67+
The given function implementation
68+
"""
69+
return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member

python/tvm/script/parser/core/diagnostics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel)
220220
level : diagnostics.DiagnosticLevel
221221
The diagnostic level.
222222
"""
223-
lineno = node.lineno or self.source.start_line
223+
lineno = node.lineno or 1
224224
col_offset = node.col_offset or self.source.start_column
225225
end_lineno = node.end_lineno or lineno
226226
end_col_offset = node.end_col_offset or col_offset

python/tvm/script/parser/core/entry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
5151
"ir": ir,
5252
"T": tir,
5353
"tir": tir,
54+
"tvm": tvm,
5455
}
5556

5657
source = Source(program)

python/tvm/script/parser/core/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any:
203203
else:
204204
value = self._eval_expr(node.__class__(**fields))
205205
except Exception as e: # pylint: disable=broad-except,invalid-name
206-
self.parser.report_error(node, str(e))
206+
self.parser.report_error(node, e)
207207
return self._add_intermediate_result(value)
208208

209209
def _eval_lambda(self, node: doc.Lambda) -> Any:

python/tvm/script/parser/core/parser.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def context():
6060
return context()
6161

6262

63+
def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument
64+
pass
65+
66+
6367
class VarTableFrame:
6468
"""The variable table frame.
6569
A frame of variable table stores the variables created in one block or scope.
@@ -259,6 +263,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
259263
node = self.diag.source.as_ast()
260264
self.visit(node)
261265

266+
def get_dispatch_token(self, node: doc.FunctionDef) -> str:
267+
if not isinstance(node, doc.FunctionDef):
268+
self.report_error(node, "Only can get dispatch token for function.")
269+
if not node.decorator_list:
270+
self.report_error(node, "Function must be decorated")
271+
# TODO: only the last decorator is parsed
272+
decorator = self.eval_expr(node.decorator_list[-1])
273+
if not hasattr(decorator, "dispatch_token"):
274+
self.report_error(node, "The parser does not understand the decorator")
275+
return decorator.dispatch_token
276+
262277
def with_dispatch_token(self, token: str):
263278
"""Add a new dispatching token as with statement.
264279
@@ -388,6 +403,8 @@ def report_error(
388403
# Only take the last line of the error message
389404
if isinstance(err, TVMError):
390405
msg = list(filter(None, str(err).split("\n")))[-1]
406+
elif isinstance(err, KeyError):
407+
msg = "KeyError: " + str(err)
391408
else:
392409
msg = str(err)
393410
self.diag.error(node, msg)
@@ -457,30 +474,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any:
457474
"""
458475
return _dispatch(self, "tvm_annotation")(self, node)
459476

460-
def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name
461-
"""The general function definition visiting method.
477+
def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name
478+
"""The general function definition visit method.
462479
463480
Parameters
464481
----------
465482
node : doc.FunctionDef
466-
The doc AST function definition node.
467-
468-
Returns
469-
-------
470-
res : Any
471-
The visiting result.
483+
The doc FunctionDef node.
472484
"""
473-
if not node.decorator_list:
474-
self.report_error(node, "Function must be decorated")
475-
# TODO: only the last decorator is parsed
476-
decorator = self.eval_expr(node.decorator_list[-1])
477-
if not hasattr(decorator, "dispatch_token"):
478-
self.report_error(node, "The parser does not understand the decorator")
479-
token = decorator.dispatch_token
485+
token = self.get_dispatch_token(node)
486+
current_token = self.dispatch_tokens[-1]
480487
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
481488
if func is None:
482489
self.report_error(node, "The parser does not understand the decorator")
490+
pre_func = dispatch.get(
491+
token=current_token, type_name="pre_token_switch", default=_do_nothing
492+
)
493+
post_func = dispatch.get(
494+
token=current_token, type_name="post_token_switch", default=_do_nothing
495+
)
496+
pre_func(self, node)
483497
_dispatch_wrapper(func)(self, node)
498+
post_func(self, node)
499+
500+
def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None:
501+
token = self.get_dispatch_token(node)
502+
with self.with_dispatch_token(token):
503+
_dispatch(self, "tvm_declare_function")(self, node)
484504

485505
def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
486506
"""The general class definition visiting method.

python/tvm/script/parser/ir/parser.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
3232
node : doc.ClassDef
3333
The doc AST class definition node.
3434
"""
35+
3536
with self.var_table.with_frame():
3637
with I.ir_module():
38+
for stmt in node.body:
39+
if isinstance(stmt, doc.FunctionDef):
40+
self.visit_tvm_declare_function(stmt)
3741
with self.with_dispatch_token("ir"):
3842
self.visit_body(node.body)
3943

0 commit comments

Comments
 (0)