diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 044e5d6f6ca9..7358394d3a25 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -154,6 +154,7 @@ class ConstIntBoundAnalyzer { * \param allow_override whether we allow override of existing information. */ TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false); + /*! * \brief Bind variable to a range. * @@ -163,6 +164,13 @@ class ConstIntBoundAnalyzer { */ TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false); + /*! + * \brief Check if a variable is bound to a range. + * \param var The variable. + * \return Whether the variable is bound to a range. + */ + TVM_DLL bool IsBound(const Var& var) const; + private: friend class Analyzer; friend class ConstraintContext; diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 919272a2734b..434e2a3e65c6 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -20,7 +20,8 @@ from typing import Union import tvm.ffi -from tvm import tir, ir +from tvm import ir, tir +from tvm.arith import IntSet from tvm.runtime import Object from . import _ffi_api @@ -109,6 +110,7 @@ def __init__(self): _mod = _ffi_api.CreateAnalyzer() self._const_int_bound = _mod("const_int_bound") self._const_int_bound_update = _mod("const_int_bound_update") + self._const_int_bound_is_bound = _mod("const_int_bound_is_bound") self._bind = _mod("bind") self._modular_set = _mod("modular_set") self._simplify = _mod("Simplify") @@ -123,7 +125,7 @@ def __init__(self): self._get_enabled_extensions = _mod("get_enabled_extensions") self._set_enabled_extensions = _mod("set_enabled_extensions") - def const_int_bound(self, expr): + def const_int_bound(self, expr: tir.PrimExpr) -> ConstIntBound: """Find constant integer bound for expr. Parameters @@ -138,7 +140,22 @@ def const_int_bound(self, expr): """ return self._const_int_bound(expr) - def modular_set(self, expr): + def const_int_bound_is_bound(self, var: tir.Var) -> bool: + """Check if a variable is bound to a range. + + Parameters + ---------- + var : tvm.tir.Var + The variable. + + Returns + ------- + result : bool + Whether the variable is bound to a range. + """ + return self._const_int_bound_is_bound(var) + + def modular_set(self, expr: tir.PrimExpr) -> ModularSet: """Find a modular set that expr belongs to. Parameters @@ -153,7 +170,7 @@ def modular_set(self, expr): """ return self._modular_set(expr) - def simplify(self, expr, steps=2): + def simplify(self, expr: tir.PrimExpr, steps: int = 2) -> tir.PrimExpr: """Simplify expression via both rewrite and canonicalization. Parameters @@ -173,7 +190,7 @@ def simplify(self, expr, steps=2): """ return self._simplify(expr, steps) - def rewrite_simplify(self, expr): + def rewrite_simplify(self, expr: tir.PrimExpr) -> tir.PrimExpr: """Simplify expression via rewriting rules. Parameters @@ -195,7 +212,7 @@ def rewrite_simplify_stats(self): def reset_rewrite_simplify_stats(self): self._reset_rewrite_simplify_stats() - def canonical_simplify(self, expr): + def canonical_simplify(self, expr: tir.PrimExpr) -> tir.PrimExpr: """Simplify expression via canonicalization. Parameters @@ -210,7 +227,7 @@ def canonical_simplify(self, expr): """ return self._canonical_simplify(expr) - def int_set(self, expr, dom_map): + def int_set(self, expr: tir.PrimExpr, dom_map: dict[tir.Var, IntSet]) -> IntSet: """Compute a symbolic IntSet that covers expr for all values in dom_map. Parameters @@ -228,7 +245,9 @@ def int_set(self, expr, dom_map): """ return self._int_set(expr, dom_map) - def can_prove(self, expr, strength=ProofStrength.DEFAULT): + def can_prove( + self, expr: tir.PrimExpr, strength: ProofStrength = ProofStrength.DEFAULT + ) -> bool: """Check whether we can prove expr to be true. Parameters @@ -246,7 +265,7 @@ def can_prove(self, expr, strength=ProofStrength.DEFAULT): """ return self._can_prove(expr, strength) - def bind(self, var: tir.Var, expr: Union[tir.PrimExpr, ir.Range]): + def bind(self, var: tir.Var, expr: Union[tir.PrimExpr, ir.Range]) -> None: """Bind a variable to the expression. Parameters @@ -259,7 +278,7 @@ def bind(self, var: tir.Var, expr: Union[tir.PrimExpr, ir.Range]): """ return self._bind(var, expr) - def constraint_scope(self, constraint): + def constraint_scope(self, constraint: tir.PrimExpr) -> ConstraintScope: """Create a constraint scope. Parameters @@ -290,7 +309,7 @@ def _fenter(): return ConstraintScope(_fenter) - def update(self, var, info, override=False): + def update(self, var: tir.Var, info: ConstIntBound, override: bool = False) -> None: """Update infomation about var Parameters @@ -309,7 +328,7 @@ def update(self, var, info, override=False): else: raise TypeError("Do not know how to handle type {}".format(type(info))) - def can_prove_equal(self, lhs: "PrimExpr", rhs: "PrimExpr"): + def can_prove_equal(self, lhs: tir.PrimExpr, rhs: tir.PrimExpr) -> bool: """Whether we can prove that lhs == rhs Parameters diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index f0a317659d3a..89cdf1c27876 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -288,6 +288,10 @@ TVM_FFI_REGISTER_GLOBAL("arith.CreateAnalyzer") self->const_int_bound.Update(args[0].cast(), args[1].cast(), args[2].cast()); }); + } else if (name == "const_int_bound_is_bound") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->const_int_bound.IsBound(args[0].cast()); + }); } else if (name == "Simplify") { return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index a440b52074e8..5078e5013865 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -109,6 +109,8 @@ class ConstIntBoundAnalyzer::Impl BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {} }; + bool IsBound(const Var& var) const { return var_map_.find(var) != var_map_.end(); } + void Bind(const Var& var, const Range& range, bool allow_override) { Entry a = VisitExpr(range->min); Entry b = VisitExpr(range->extent); @@ -793,6 +795,8 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool allow_ impl_->Bind(var, range, allow_override); } +bool ConstIntBoundAnalyzer::IsBound(const Var& var) const { return impl_->IsBound(var); } + std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { return impl_->EnterConstraint(constraint); } diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index e9b764c5f402..14bfec2328f2 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -51,6 +51,7 @@ def test_const_bounds(self, test_case): for var, bounds in test_case.known_bounds.items(): analyzer.update(var, ConstIntBound(*bounds)) + assert analyzer.const_int_bound_is_bound(var) with contextlib.ExitStack() as stack: if test_case.constraint is not None: