@@ -60,6 +60,10 @@ def context():
6060 return context ()
6161
6262
63+ def _do_nothing (* args , ** kwargs ): # pylint: disable=unused-argument
64+ pass
65+
66+
6367class VarTableFrame :
6468 """The variable table frame.
6569 A frame of variable table stores the variables created in one block or scope.
@@ -259,6 +263,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
259263 node = self .diag .source .as_ast ()
260264 self .visit (node )
261265
266+ def get_dispatch_token (self , node : doc .FunctionDef ) -> str :
267+ if not isinstance (node , doc .FunctionDef ):
268+ self .report_error (node , "Only can get dispatch token for function." )
269+ if not node .decorator_list :
270+ self .report_error (node , "Function must be decorated" )
271+ # TODO: only the last decorator is parsed
272+ decorator = self .eval_expr (node .decorator_list [- 1 ])
273+ if not hasattr (decorator , "dispatch_token" ):
274+ self .report_error (node , "The parser does not understand the decorator" )
275+ return decorator .dispatch_token
276+
262277 def with_dispatch_token (self , token : str ):
263278 """Add a new dispatching token as with statement.
264279
@@ -388,6 +403,8 @@ def report_error(
388403 # Only take the last line of the error message
389404 if isinstance (err , TVMError ):
390405 msg = list (filter (None , str (err ).split ("\n " )))[- 1 ]
406+ elif isinstance (err , KeyError ):
407+ msg = "KeyError: " + str (err )
391408 else :
392409 msg = str (err )
393410 self .diag .error (node , msg )
@@ -457,30 +474,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any:
457474 """
458475 return _dispatch (self , "tvm_annotation" )(self , node )
459476
460- def visit_FunctionDef (self , node : doc .FunctionDef ) -> Any : # pylint: disable=invalid-name
461- """The general function definition visiting method.
477+ def visit_FunctionDef (self , node : doc .FunctionDef ) -> None : # pylint: disable=invalid-name
478+ """The general function definition visit method.
462479
463480 Parameters
464481 ----------
465482 node : doc.FunctionDef
466- The doc AST function definition node.
467-
468- Returns
469- -------
470- res : Any
471- The visiting result.
483+ The doc FunctionDef node.
472484 """
473- if not node .decorator_list :
474- self .report_error (node , "Function must be decorated" )
475- # TODO: only the last decorator is parsed
476- decorator = self .eval_expr (node .decorator_list [- 1 ])
477- if not hasattr (decorator , "dispatch_token" ):
478- self .report_error (node , "The parser does not understand the decorator" )
479- token = decorator .dispatch_token
485+ token = self .get_dispatch_token (node )
486+ current_token = self .dispatch_tokens [- 1 ]
480487 func = dispatch .get (token = token , type_name = "FunctionDef" , default = None )
481488 if func is None :
482489 self .report_error (node , "The parser does not understand the decorator" )
490+ pre_func = dispatch .get (
491+ token = current_token , type_name = "pre_token_switch" , default = _do_nothing
492+ )
493+ post_func = dispatch .get (
494+ token = current_token , type_name = "post_token_switch" , default = _do_nothing
495+ )
496+ pre_func (self , node )
483497 _dispatch_wrapper (func )(self , node )
498+ post_func (self , node )
499+
500+ def visit_tvm_declare_function (self , node : doc .FunctionDef ) -> None :
501+ token = self .get_dispatch_token (node )
502+ with self .with_dispatch_token (token ):
503+ _dispatch (self , "tvm_declare_function" )(self , node )
484504
485505 def visit_ClassDef (self , node : doc .ClassDef ) -> Any : # pylint: disable=invalid-name
486506 """The general class definition visiting method.
0 commit comments