From cacb737c89b0153c6d952fc4609ffc61ec50936b Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 8 Jul 2024 11:15:49 +0800 Subject: [PATCH] [Relax] Fix cublas dispatch for corner cases Fix case when `lhs_batches` and `rhs_batches` are symbolic expressions, but not standalone variables. --- python/tvm/relax/backend/contrib/cublas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py index febb401bc0d1..287b18b4409a 100644 --- a/python/tvm/relax/backend/contrib/cublas.py +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -134,7 +134,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: isinstance(lhs_batches, tvm.tir.Var) or isinstance(rhs_batches, tvm.tir.Var) or (analyzer.can_prove_equal(lhs_batches, rhs_batches)) - or (lhs_batches >= 1 and rhs_batches == 1) + or (analyzer.can_prove(lhs_batches >= 1) and analyzer.can_prove(rhs_batches == 1)) )