From b5261a425024ddea5f050bda4bcdfc960c270c51 Mon Sep 17 00:00:00 2001 From: shengxinhu Date: Tue, 11 Apr 2023 11:21:14 +0800 Subject: [PATCH 1/3] [Relay] Enhance type infer for dynamic shape Support type_infer to enable unify PrimExpr such as tir.IndexMod(relay.Any(), 5) --- src/relay/analysis/type_solver.cc | 22 +++++++++++++++++++++- src/relay/analysis/type_solver.h | 1 + tests/python/relay/test_type_infer.py | 8 ++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index d40eb8a17c06..8d9af018506a 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -76,6 +77,20 @@ class TypeSolver::Reporter : public TypeReporterNode { TypeSolver* solver_; }; +class TypeSolver::AnyChecker : public tir::ExprVisitor { + public: + void VisitExpr_(const AnyNode* op) final { + found_ = true; + } + + bool Check(const PrimExpr& expr) { + tir::ExprVisitor::VisitExpr(expr); + return found_; + } + private: + bool found_{false}; +}; + class TypeSolver::OccursChecker : public TypeVisitor { public: explicit OccursChecker(TypeSolver* solver, TypeNode* var) @@ -146,6 +161,11 @@ class TypeSolver::Unifier : public TypeFunctor { } } + bool HasAny(const PrimExpr& expr) { + AnyChecker ac; + return ac.Check(expr); + } + // Checks whether lhs (taken to be a type var) occurs in t, meaning // there is a recursive equality constraint, which should be rejected. // N.b.: A tautology like ?a = ?a is okay and should be checked for @@ -186,7 +206,7 @@ class TypeSolver::Unifier : public TypeFunctor { if (ulhs.same_as(urhs)) { return ulhs; } - if (ulhs.as() || urhs.as()) { + if (HasAny(ulhs) || HasAny(urhs)) { return Any(); } diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 7940e347b3ea..5d32afab6442 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -97,6 +97,7 @@ class TypeSolver { void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); } private: + class AnyChecker; class OccursChecker; class Unifier; class Resolver; diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 187455570216..b4c6c140cdaa 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -417,6 +417,14 @@ def test_dynamic_function(): mod = transform.InferType()(mod) assert mod["main"].params[0].checked_type == s_tt + data = relay.var( + "data", shape=(relay.Any(), relay.Any(), relay.Any(), relay.Any()), dtype="float32" + ) + weigth = relay.const(np.full((1, 16, 3, 3), 0.25), dtype="float32") + x = relay.nn.conv2d(data, weigth, kernel_size=(3, 3), channels=16, groups=2) + mod = tvm.IRModule.from_expr(x) + mod = transform.InferType()(mod) + def test_custom_op_infer(): """Tests infer type for custom_op""" From b045239cac3c0be02236f0298e1d5f03f4c286c9 Mon Sep 17 00:00:00 2001 From: shengxinhu Date: Thu, 13 Apr 2023 09:56:17 +0800 Subject: [PATCH 2/3] fix a bug --- tests/python/relay/test_type_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index b4c6c140cdaa..7fbb656b367a 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -420,7 +420,7 @@ def test_dynamic_function(): data = relay.var( "data", shape=(relay.Any(), relay.Any(), relay.Any(), relay.Any()), dtype="float32" ) - weigth = relay.const(np.full((1, 16, 3, 3), 0.25), dtype="float32") + weigth = relay.const(np.full((16, 16, 3, 3), 0.25), dtype="float32") x = relay.nn.conv2d(data, weigth, kernel_size=(3, 3), channels=16, groups=2) mod = tvm.IRModule.from_expr(x) mod = transform.InferType()(mod) From d94b29222fb60a9c82082fd70942d549c60435c9 Mon Sep 17 00:00:00 2001 From: shengxinhu Date: Thu, 13 Apr 2023 16:54:04 +0800 Subject: [PATCH 3/3] fix lint --- src/relay/analysis/type_solver.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 8d9af018506a..47f96281348b 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -25,8 +25,8 @@ #include #include -#include #include +#include #include #include @@ -79,14 +79,13 @@ class TypeSolver::Reporter : public TypeReporterNode { class TypeSolver::AnyChecker : public tir::ExprVisitor { public: - void VisitExpr_(const AnyNode* op) final { - found_ = true; - } + void VisitExpr_(const AnyNode* op) final { found_ = true; } bool Check(const PrimExpr& expr) { tir::ExprVisitor::VisitExpr(expr); return found_; } + private: bool found_{false}; };