diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index 06f181030ee0..1b82e93eacf7 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -1391,7 +1391,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
   // First convert a < b into a - b < 0
   PrimExpr expr = this->CanonicalMutate(op->a - op->b);
   // Case: x0 * s0 + x1 * s1 + ... + xn + c < 0, let d = gcd(s0, s1, ..., s{n-1}, c)
-  // 1. if can prove -d < xn < d, then we can simplify
+  // 1. if can prove 0 <= xn < d, then we can simplify
   // the expression to x0 * (s0/d) + x1 * (s1/d) + ... + x{n-1} * (s{n-1}/d) < c/d,
   // e.g. `x * 8 + y < 16` where `y` \in [0, 8), we can simplify it to `x < 2`
   // 2. if xn is in pattern of yn % m, where m % d == 0, convert it to yn // d % (m/d)
@@ -1417,8 +1417,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
     ICHECK(extra->dtype == dtype);
     PrimExpr normal_extra = extra->Normalize();
     if (this->analyzer_->CanProve(normal_extra < make_const(dtype, gcd)) &&
-        this->analyzer_->CanProve(normal_extra > make_const(dtype, -gcd))) {
-      // Case 1. -d < xn < d
+        this->analyzer_->CanProve(normal_extra >= make_const(dtype, 0))) {
+      // Case 1. 0 <= xn < d
       divisible.CopyOnWrite()->DivideBy(gcd);
       return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
     } else if (extra->args.size() == 1 &&
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index b77adda4c9f4..0068db51d522 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -134,7 +134,7 @@ enum ArgConvertCode {
 };
 
 inline ArgConvertCode GetArgConvertCode(DLDataType t) {
-  ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now";
+  ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to device function for now";
   if (t.code == kDLInt) {
     if (t.bits == 64U) return INT64_TO_INT64;
     if (t.bits == 32U) return INT64_TO_INT32;
diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py
index 42f5b0ccd0b8..733d1d13b371 100644
--- a/tests/python/arith/test_arith_canonical_simplify.py
+++ b/tests/python/arith/test_arith_canonical_simplify.py
@@ -448,7 +448,6 @@ def test_simplify_le():
 
     ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x))
     ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)
-    ck.verify(x * 8 + y - z < 16, x < 2)
 
     n = te.size_var("n")
     ck.verify(x * 8 + y < n, x * 8 + y < n)
diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py
index 6341b7b0ae66..ae07a3b7318c 100644
--- a/tests/python/dlight/test_gpu_low_batch_gemv.py
+++ b/tests/python/dlight/test_gpu_low_batch_gemv.py
@@ -136,7 +136,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T
                 with T.block("NT_matmul_intermediate_pad"):
                     v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                     v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
-                    T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
+                    T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                     T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                     T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                     NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
@@ -240,7 +240,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float
                 with T.block("NT_matmul_pad"):
                     v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                     v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
-                    T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
+                    T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                     T.reads(NT_matmul_pad_local[v0, T.int64(0), v1])
                     T.writes(NT_matmul[v0, T.int64(0), v1])
                     NT_matmul[v0, T.int64(0), v1] = NT_matmul_pad_local[v0, T.int64(0), v1]
@@ -369,7 +369,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"
                 with T.block("C_pad"):
                     v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                     v1 = T.axis.spatial(T.int64(8), ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
-                    T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8))
+                    T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8))
                     T.reads(C_pad_local[v0, v1])
                     T.writes(C[v0, v1])
                     C[v0, v1] = C_pad_local[v0, v1]
@@ -516,7 +516,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo
                 with T.block("C_pad"):
                     v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
                     v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1)
-                    T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
+                    T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 * 4 + ax0 == 0) and ax0_0 * 4 + ax0 < batch_size)
                     T.reads(C_pad_local[v0, 0, v1])
                     T.writes(C[v0, 0, v1])
                     C[v0, 0, v1] = C_pad_local[v0, 0, v1]
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index 2fa61faf40f8..f27d9d370fce 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -695,7 +695,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
                         v0 = T.axis.spatial(T.int64(1), T.int64(0))
                         v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
                         v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
-                        T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m)
+                        T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m)
                        T.reads(matmul_pad_local[v0, v1, v2])
                        T.writes(matmul[v0, v1, v2])
                        matmul[v0, v1, v2] = matmul_pad_local[v0, v1, v2]
@@ -835,7 +835,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T
                         v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                         v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
                         v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
-                        T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len)
+                        T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len)
                         T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
                         T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
                         T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
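
Note on the arith change above: dividing the canonicalized inequality through by d = gcd(s0, ..., s{n-1}, c) is only sound when the leftover term xn is provably in [0, d), not merely in (-d, d). For example, `x * 8 + y - z < 16` with `y - z` in (-8, 8) does not imply `x < 2` (x = 2, y - z = -1 is a counterexample), which is why the `ck.verify(x * 8 + y - z < 16, x < 2)` test line is removed and why the dlight expected outputs now keep the full `ax0_0 * 4 + ax0 == 0` term instead of the previously folded `ax0_0 == 0`. Below is a minimal sketch of the post-patch behavior, assuming a local TVM build; `tvm.arith.Analyzer`, `ConstIntBound`, and `canonical_simplify` are the existing public API, while the variables and bounds are purely illustrative, not taken from the patch.

import tvm
from tvm import te

# Illustrative repro, not part of the patch: y and z are given the range
# [0, 7], so y is provably in [0, d) for d = 8, but y - z is only
# provably in (-8, 8).
x, y, z = te.var("x"), te.var("y"), te.var("z")
analyzer = tvm.arith.Analyzer()
analyzer.update(y, tvm.arith.ConstIntBound(0, 7))
analyzer.update(z, tvm.arith.ConstIntBound(0, 7))

# Still folds: 0 <= y < 8 = gcd(8, 16), so x * 8 + y < 16 becomes x < 2.
print(analyzer.canonical_simplify(x * 8 + y < 16))
# No longer folds to x < 2: non-negativity of y - z cannot be proven.
print(analyzer.canonical_simplify(x * 8 + y - z < 16))

Under the old -d < xn < d rule both calls would fold to `x < 2`; the second is exactly the unsound case the removed test asserted.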