diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py index febb401bc0d1..287b18b4409a 100644 --- a/python/tvm/relax/backend/contrib/cublas.py +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -134,7 +134,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: isinstance(lhs_batches, tvm.tir.Var) or isinstance(rhs_batches, tvm.tir.Var) or (analyzer.can_prove_equal(lhs_batches, rhs_batches)) - or (lhs_batches >= 1 and rhs_batches == 1) + or (analyzer.can_prove(lhs_batches >= 1) and analyzer.can_prove(rhs_batches == 1)) )