-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[AMP][Pass][Typing] Add faster type inference #9735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f6747be
5d3932f
3220c80
7dee27b
08b391a
2021f23
5136b85
dbf3cf6
e9a5f55
f8c5012
5960c5c
f294f63
4f0b03b
8301057
1cb38f1
5aae167
d6f73f2
09fbbe0
faeed08
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -176,13 +176,17 @@ class MixedPrecisionPass : public MixedModeMutator { | |
| } | ||
|
|
||
| Type GetType(const Expr& expr) const { | ||
| auto mod = IRModule::FromExpr(expr); | ||
| mod = transform::InferType()(mod); | ||
| if (expr.as<FunctionNode>()) { | ||
| return mod->Lookup("main")->checked_type(); | ||
| } else { | ||
| return mod->Lookup("main").as<FunctionNode>()->body->checked_type(); | ||
| // The expression has not been changed AND it's existing type | ||
| // is known to still be valid. (See special handling for tuples etc | ||
| // below for where we null out checked_type_ when we can not | ||
| // sure it is still valid. | ||
| Type checked_type = expr->checked_type_; | ||
| if (checked_type.defined()) { | ||
| return checked_type; | ||
| } | ||
|
|
||
| // This also populates the checked_type_ field for expr | ||
| return transform::InferTypeLocal(expr); | ||
| } | ||
|
|
||
| bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const { | ||
|
|
@@ -381,6 +385,18 @@ class MixedPrecisionPass : public MixedModeMutator { | |
| return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span); | ||
| } | ||
|
|
||
| Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { | ||
| // The old checked type in the expression may not be valid so clear it | ||
| post->checked_type_ = Type(nullptr); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. am I missing something or will checked_type_ = null iff some sub-expression of post has been rewritten and thus it's type has changed?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm so you would think so, but it looks like the mutator does not by default invalidate the checked_type (and appears to reuse the reference? giving us this problem). I can dig a little deeper, but if I remove this line for TupleGetItemNode the checked type will be wrong (it will be fp32 instead of fp16)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://github.com/apache/tvm/blob/main/src/relay/ir/expr_functor.cc#L248 Here is the behavior for generating post, there is some Copy on write stuff which i don't quite understand the full mechanics of so 🤷
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah! It's the COW, that makes sense. I think that means we should be clearing checked_type_ on COW but let's not dig ourselves any deeper until we've thought about incremental type inference a bit more carefully. |
||
| return post; | ||
| } | ||
|
|
||
| Expr Rewrite_(const TupleNode* pre, const Expr& post) { | ||
| // The old checked type in the expression may not be valid so clear it | ||
| post->checked_type_ = Type(nullptr); | ||
| return post; | ||
| } | ||
|
|
||
| Expr VisitExpr_(const FunctionNode* func) final { | ||
| // Erase the ret_type annotation and let the normal pass recalculate | ||
| const_cast<FunctionNode*>(func)->ret_type = Type(nullptr); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// The expression has not been changed AND it's existing type
// is known to still be valid. (See special handling for tuples etc
// below for where we null out checked_type_ when we can not
// sure it is still valid.
(though see my comment below)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done