Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,29 @@ TVM_DLL Pass DynamicToStatic();
/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
* The result of type checking is a new expression with unambiguous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \return The pass.
*/
TVM_DLL Pass InferType();

/*!
* \brief Infer the type of an expression, reusing existing type information.
*
* The result of type checking is a new expression with unambiguous
* type information filled in for the given node only. The local
* version can use existing type information populated throughout
* the expression and assumes this information is correct. The local
* version also avoids examining large amounts of the graph assuming
* type information is filled in properly which makes it much faster if we
* iteratively call type inference.
*
* \return The type of the expression.
*/
TVM_DLL Type InferTypeLocal(const Expr& expr);

/*!
* \brief Search and eliminate common subexpression. For example, if there are
* two expressions evaluated to an identical value, a single variable is created
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ def InferType():
return _ffi_api.InferType()


def InferTypeLocal(expr):
"""Infer the type of a single expr, reusing type information to do so.

This populates the checked_type field in expr. We assume existing type information
in the graph is correct!

Parameters
----------
expr: relay.Expr
The expression we want to know the type of

Returns
-------
type: relay.Type
The type of the expression
"""
return _ffi_api.InferTypeLocal(expr)


def FoldScaleAxis():
"""Fold the scaling of axis into weights of conv2d/dense. This pass will
invoke both forward and backward scale folding.
Expand Down
28 changes: 22 additions & 6 deletions src/relay/transforms/to_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,17 @@ class MixedPrecisionPass : public MixedModeMutator {
}

Type GetType(const Expr& expr) const {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
if (expr.as<FunctionNode>()) {
return mod->Lookup("main")->checked_type();
} else {
return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
// The expression has not been changed AND it's existing type
// is known to still be valid. (See special handling for tuples etc
// below for where we null out checked_type_ when we can not
// sure it is still valid.
Type checked_type = expr->checked_type_;
if (checked_type.defined()) {
return checked_type;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// The expression has not been changed AND it's existing type
// is known to still be valid. (See special handling for tuples etc
// below for where we null out checked_type_ when we can not
// sure it is still valid.

(though see my comment below)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

// This also populates the checked_type_ field for expr
return transform::InferTypeLocal(expr);
}

bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
Expand Down Expand Up @@ -381,6 +385,18 @@ class MixedPrecisionPass : public MixedModeMutator {
return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
}

Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
// The old checked type in the expression may not be valid so clear it
post->checked_type_ = Type(nullptr);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

am I missing something or will checked_type_ = null iff some sub-expression of post has been rewritten and thus it's type has changed?
ie checked_type_ is non-null only if pre == post.get() ??

Copy link
Copy Markdown
Contributor Author

@AndrewZhaoLuo AndrewZhaoLuo Dec 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm so you would think so, but it looks like the mutator does not by default invalidate the checked_type (and appears to reuse the reference? giving us this problem).

I can dig a little deeper, but if I remove this line for TupleGetItemNode the checked type will be wrong (it will be fp32 instead of fp16)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/apache/tvm/blob/main/src/relay/ir/expr_functor.cc#L248

Here is the behavior for generating post, there is some Copy on write stuff which i don't quite understand the full mechanics of so 🤷

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! It's the COW, that makes sense. I think that means we should be clearing checked_type_ on COW but let's not dig ourselves any deeper until we've thought about incremental type inference a bit more carefully.

return post;
}

Expr Rewrite_(const TupleNode* pre, const Expr& post) {
// The old checked type in the expression may not be valid so clear it
post->checked_type_ = Type(nullptr);
return post;
}

Expr VisitExpr_(const FunctionNode* func) final {
// Erase the ret_type annotation and let the normal pass recalculate
const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
Expand Down
106 changes: 106 additions & 0 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -824,8 +824,114 @@ void AddGlobalTypes(IRModule mod) {
}
}

/*!
* \brief Returns a possibly much smaller subgraph whose inner nodes have the same type.
*
* Returns the largest sub-graph who's inner nodes need types and leaves are vars standing in
* for already typed sub-expressions. This creates a graph whose inner nodes have the same
* type as the original graph and when running type inference, we can avoid copying and
* recursing through most of the expression graph when running type inference. Note, this assumes
* that current populated type information is correct!
*
* ExprMutator is sufficient over MixedModemutator since we will not recurse much.
*/
class SameTypedSubgraphExtractor : public ExprMutator {
Expr VisitExpr_(const VarNode* op) { return Var(op->vid, op->type_annotation, op->span); }
Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, op->span); }
Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
Expr VisitExpr_(const TupleNode* op) {
return Tuple(GetAnalogousExpression(op->fields), op->span);
}
Expr VisitExpr_(const FunctionNode* op) {
// Unfortunately our strategy of inserting variables as dummies would change the signature of
// existing function nodes so we have to copy all used functions always :/
return Function(op->params, op->body, op->ret_type, op->type_params, op->attrs, op->span);
}
Expr VisitExpr_(const CallNode* op) {
return Call(op->op, GetAnalogousExpression(op->args), op->attrs, op->type_args, op->span);
}
Expr VisitExpr_(const LetNode* op) {
return Let(op->var, GetAnalogousExpression(op->value), GetAnalogousExpression(op->body),
op->span);
}
Expr VisitExpr_(const IfNode* op) {
return If(GetAnalogousExpression(op->cond), GetAnalogousExpression(op->true_branch),
GetAnalogousExpression(op->false_branch), op->span);
}
Expr VisitExpr_(const TupleGetItemNode* op) {
return TupleGetItem(GetAnalogousExpression(op->tuple), op->index, op->span);
}
Expr VisitExpr_(const RefCreateNode* op) {
return RefCreate(GetAnalogousExpression(op->value), op->span);
}
Expr VisitExpr_(const RefReadNode* op) {
return RefRead(GetAnalogousExpression(op->ref), op->span);
}
Expr VisitExpr_(const RefWriteNode* op) {
return RefWrite(GetAnalogousExpression(op->ref), GetAnalogousExpression(op->value), op->span);
}
Expr VisitExpr_(const ConstructorNode* op) {
return Constructor(op->name_hint, op->inputs, op->belong_to);
}
Expr VisitExpr_(const MatchNode* op) {
return Match(GetAnalogousExpression(op->data), op->clauses, op->complete, op->span);
}

private:
Expr GetAnalogousExpression(const Expr& expr) {
// Replace the expression with a potentially simpler expression of the same type
if (expr->checked_type_.defined()) {
// Since the expression already has a checked_type which we assume is correct we don't need
// full type inference to enter it. So stub it out with a dummy var of the same type.
return Var("dummy_var", expr->checked_type(), expr->span);
}

return VisitExpr(expr);
}
Array<Expr> GetAnalogousExpression(const Array<Expr>& fields) {
Array<Expr> new_fields;
for (Expr expr : fields) {
new_fields.push_back(GetAnalogousExpression(expr));
}
return new_fields;
}
};

namespace transform {

Type InferTypeLocal(const Expr& expr) {
/*
This type inference differs from InferType in that it uses existing type information
to avoid recursing over much of the graph, and it only examines the type of the input
node. This makes it faster if you need to run type inference iteratively throughout
a pass for example.

However, it assumes any existing populated type inference is correct! If some populated
type inference is incorrect, an incorrect type may be returned or a type error will be
raised. If you know not all populated type fields are correct with the current graph,
you should use InferType() instead.
*/
SameTypedSubgraphExtractor subgraph_extractor;
Expr sub_graph = subgraph_extractor(expr);
auto mod = IRModule::FromExpr(sub_graph);
mod = transform::InferType()(mod);

Type result_type;
if (expr.as<FunctionNode>()) {
result_type = mod->Lookup("main")->checked_type();
} else {
result_type = mod->Lookup("main").as<FunctionNode>()->body->checked_type();
}

expr->checked_type_ = result_type;
return result_type;
}

TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const Expr& expr) {
return InferTypeLocal(expr);
});

Pass InferType() {
auto pass_info = PassInfo(0, "InferType", {});
return tvm::transform::CreateModulePass(
Expand Down
27 changes: 13 additions & 14 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
"""
import pytest
import tvm

from tvm import IRModule, te, relay, parser
from tvm.relay import op, transform, analysis
from tvm import IRModule, parser, relay, te
from tvm.relay import analysis, op, transform
from tvm.relay.op import op as _op


Expand All @@ -33,12 +32,9 @@ def infer_mod(mod, annotate_spans=True):
return mod


def infer_expr(expr, annotate_spans=True):
mod = IRModule.from_expr(expr)
mod = infer_mod(mod, annotate_spans)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def infer_expr(expr):
transform.InferTypeLocal(expr)
return expr


def assert_has_type(expr, typ, mod=None):
Expand Down Expand Up @@ -68,7 +64,7 @@ def test_monomorphic_let():
# TODO(@jroesch): this seems whack.
sb = relay.ScopeBuilder()
x = relay.var("x", dtype="float64", shape=())
x = sb.let("x", relay.const(1.0, "float64"))
x = sb.let(x, relay.const(1.0, "float64"))
sb.ret(x)
xchecked = infer_expr(sb.get())
assert xchecked.checked_type == relay.scalar_type("float64")
Expand Down Expand Up @@ -165,11 +161,11 @@ def @f(%n: int32, %data: float32) -> float32 {
def test_incomplete_call():
tt = relay.scalar_type("int32")
x = relay.var("x", tt)
f_type = relay.FuncType([tt], tt)
f = relay.var("f")
func = relay.Function([x, f], relay.Call(f, [x]), tt)

ft = infer_expr(func)
f_type = relay.FuncType([tt], tt)
assert ft.checked_type == relay.FuncType([tt, f_type], tt)


Expand Down Expand Up @@ -245,7 +241,7 @@ def test_ref():
def test_free_expr():
x = relay.var("x", "float32")
y = relay.add(x, x)
yy = infer_expr(y, annotate_spans=False)
yy = infer_expr(y)
assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True)
assert yy.checked_type == relay.scalar_type("float32")
assert x.vid.same_as(yy.args[0].vid)
Expand All @@ -255,8 +251,11 @@ def test_type_args():
x = relay.var("x", shape=(10, 10))
y = relay.var("y", shape=(1, 10))
z = relay.add(x, y)
ty_z = infer_expr(z)
ty_args = ty_z.type_args

# InferTypeLocal does not support populating the type_args field
mod = infer_mod(IRModule.from_expr(z))
mod = infer_mod(mod, annotate_spans=False)
ty_args = mod["main"].body.type_args
assert len(ty_args) == 2
assert ty_args[0].dtype == "float32"
assert ty_args[1].dtype == "float32"
Expand Down