From 2507acd9ef599ec8b4ce10b4673c96457214c459 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 9 Aug 2023 15:07:26 -0700 Subject: [PATCH 1/5] [TIR] Shuffle in PointerValueTypeRewrite for scalar reads Added an option `rewrite_scalar_read_to_vector_shuffle` in `PointerValueTypeRewrite` (currently only enabled for Vulkan). When the option is enabled and a buffer has both scalar and vector reads, the buffer is vectorized if possible and the scalar reads are implemented via T.Shuffle. Closes https://github.com/apache/tvm/issues/15463. --- python/tvm/tir/transform/transform.py | 13 +++ src/target/spirv/codegen_spirv.cc | 11 ++ src/target/spirv/codegen_spirv.h | 1 + src/tir/transforms/storage_rewrite.cc | 105 ++++++++++++------ ...ir_transform_pointer_value_type_rewrite.py | 64 +++++++++++ 5 files changed, 160 insertions(+), 34 deletions(-) create mode 100644 tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 0cd54064a7b5..2f564e4ec8e6 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -230,6 +230,19 @@ def StorageRewrite(): return _ffi_api.StorageRewrite() # type: ignore +def PointerValueTypeRewrite(): + """ + Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use + the most frequently accessed type for load/store to avoid pointer casting in backend when possible. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.PointerValueTypeRewrite() # type: ignore + + def UnrollLoop(): """Unroll the constant loop marked by unroll. 
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index ab9aec077542..5cc3f8f8ddc0 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -610,6 +610,17 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::functionvectors.size() == 1 && op->indices.size() == 1) + << "SPIR-V codegen only supports shuffle " + << "of one vector with one index"; + spirv::Value vector = MakeValue(op->vectors[0]); + int index = Downcast(op->indices[0])->value; + spirv::SType etype = builder_->GetSType(op->dtype); + spirv::Value element = builder_->MakeValue(spv::OpCompositeExtract, etype, vector, index); + return element; +} + void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; Var buffer_var = op->buffer->data; diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 1e7b53558508..8ea90a9c4b80 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -102,6 +102,7 @@ class CodeGenSPIRV : public ExprFunctor, spirv::Value VisitExpr_(const RampNode* op) override; spirv::Value VisitExpr_(const BroadcastNode* op) override; spirv::Value VisitExpr_(const BufferLoadNode* op) override; + spirv::Value VisitExpr_(const ShuffleNode* op) override; // stmt void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 3ecd0f64bb44..44315da9414a 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -33,6 +33,7 @@ #include #include +#include #include #include @@ -1083,17 +1084,14 @@ struct BufferVarInfo { DataType preferred_base_type = *base_access_dtype.begin(); - // If there is only one vectorizable size used to access the - // buffer, and if that access size is compatible with the array - // size, then the 
buffer is vectorizable. In the future, this - // could be improved to allow vectorized buffer access of size - // GCD(*lanes_used), if necessary. int preferred_lanes = element_dtype.lanes(); - if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) { + if (element_dtype.lanes() == 1) { + int lanes = access_dtype.begin()->lanes(); + for (auto dtype : access_dtype) { + lanes = std::gcd(lanes, dtype.lanes()); + } arith::Analyzer analyzer_; arith::ModularSet me = analyzer_.modular_set(extent); - - int lanes = access_dtype.begin()->lanes(); if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { preferred_lanes = lanes; } @@ -1120,8 +1118,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * type as it is later accessed, with scalar element types. */ VectorTypeAccessChecker(const Array& params, const Map& buffer_map, - bool allow_untyped_pointers = false) - : allow_untyped_pointers_(allow_untyped_pointers) { + bool allow_untyped_pointers = false, + bool detect_scalar_read_patterns = true) + : allow_untyped_pointers_(allow_untyped_pointers), + detect_scalar_read_patterns_(detect_scalar_read_patterns) { // If a parameter is in the buffer map, we want to track the // version in the map. 
for (auto it : buffer_map) { @@ -1145,12 +1145,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) final { - OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices); + OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices, /*is_buffer_load=*/true); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { - OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices); + OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices, /*is_buffer_load=*/false); StmtExprVisitor::VisitStmt_(op); } @@ -1159,7 +1159,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); PrimExpr index = op->args[2]; - OnArrayAccess(dtype, buffer, {index}); + OnArrayAccess(dtype, buffer, {index}, false); + } else if (op->op.same_as(builtin::address_of())) { + BufferLoad load = Downcast(op->args[0]); + OnArrayAccess(load->dtype, load->buffer->data.get(), load->indices, /*is_buffer_load=*/false); } StmtExprVisitor::VisitExpr_(op); } @@ -1226,8 +1229,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { if (element_dtype == DataType::Bool()) { element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes()); } - - info_map_[buffer.get()] = {buffer, element_dtype, extent, declaration_location}; + info_map_[buffer.get()] = BufferVarInfo{buffer, element_dtype, extent, declaration_location}; } /* Update the type map for a buffer based on its usage @@ -1237,11 +1239,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * * @param buffer The VarNode representing the buffer. * - * @param index The index at which the value is being stored/loaded. + * @param indices The index at which the value is being stored/loaded. * - * @param predicate The predicate used for the store/load. 
+ * @param is_buffer_load Whether the access is BufferLoad */ - void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array& indices) { + void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array& indices, + bool is_buffer_load) { auto it = info_map_.find(buffer); ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer << ") occurred before its declaration."; @@ -1304,6 +1307,18 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } } + if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) { + const PrimExpr last_dim_index = indices[indices.size() - 1]; + if (last_dim_index.dtype().lanes() == 1) { + arith::ModularSet me = analyzer_.modular_set(last_dim_index); + if (me->coeff > 0) { + // When coeff == 0, the index is constant and doesn't need to be recorded since it can + // always be rewritten to shuffle. + var_info.access_dtype.insert(access_dtype.with_lanes(me->coeff)); + } + return; + } + } var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used)); } @@ -1312,6 +1327,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // bool allow_untyped_pointers_{false}; + // Whether to detect scalar read patterns for rewriting to vector shuffle + bool detect_scalar_read_patterns_{true}; // internal analyzer arith::Analyzer analyzer_; @@ -1366,7 +1383,8 @@ class VectorTypeRewriter : public StmtExprMutator { VectorTypeRewriter(const std::unordered_map& info_map, bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, - bool rewrite_let_node = true, bool rewrite_allocate_const_node = true) + bool rewrite_let_node = true, bool rewrite_allocate_const_node = true, + bool rewrite_scalar_read_to_vector_shuffle = true) : rewrite_indices_(rewrite_indices) { int rewrite_mask = 0; if (rewrite_params) { @@ -1401,42 +1419,52 @@ class VectorTypeRewriter : public StmtExprMutator { } } + /*! 
+ * \brief Mutator for BufferLoad or BufferStore. + * \return The rewritten node and the shuffle index. (Only for BufferLoad) When the shuffle index + * is non-negative, the caller should generate Shuffle to extract the element from the vector. + */ template - Node VisitBufferAccess(Node node) { + std::pair VisitBufferAccess(Node node) { + int shuffle_index = -1; if (!rewrite_indices_) { - return node; + return {node, shuffle_index}; } auto it = rewrite_map_.find(node->buffer->data.get()); if (it == rewrite_map_.end()) { - return node; + return {node, shuffle_index}; } const auto& info = it->second; Array indices = node->indices; - - const RampNode* ramp_index = indices[indices.size() - 1].as(); - if (ramp_index && is_one(ramp_index->stride)) { + const PrimExpr& last_dim_index = indices[indices.size() - 1]; + if (const RampNode* ramp_index = last_dim_index.as(); + ramp_index && is_one(ramp_index->stride)) { PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); if (ramp_index->lanes != info.factor()) { new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes / info.factor(), ramp_index->span); } - + indices.Set(indices.size() - 1, new_index); + } else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) { + arith::ModularSet me = analyzer_.modular_set(last_dim_index); + ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0); + PrimExpr new_index = last_dim_index / make_const(last_dim_index.dtype(), info.factor()); + shuffle_index = me->base; indices.Set(indices.size() - 1, new_index); } auto writer = node.CopyOnWrite(); writer->buffer = RemapBuffer(node->buffer); writer->indices = indices; - - return node; + return {node, shuffle_index}; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); - auto modified = VisitBufferAccess(node); + auto [modified, shuffle_index] = VisitBufferAccess(node); // Not needed for BufferStoreNode, so we can't just call 
// LegalizeDtype() in VisitBufferAccess. @@ -1445,13 +1473,18 @@ class VectorTypeRewriter : public StmtExprMutator { } else { auto writer = modified.CopyOnWrite(); writer->LegalizeDType(); + if (shuffle_index >= 0) { + return Shuffle::ExtractElement(std::move(modified), shuffle_index); + } return std::move(modified); } } Stmt VisitStmt_(const BufferStoreNode* op) final { auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - return VisitBufferAccess(std::move(node)); + auto [modified, shuffle_index] = VisitBufferAccess(std::move(node)); + ICHECK(shuffle_index < 0); + return std::move(modified); } Stmt VisitStmt_(const LetStmtNode* op) final { @@ -1627,6 +1660,7 @@ class VectorTypeRewriter : public StmtExprMutator { bool rewrite_indices_{true}; std::unordered_map rewrite_map_; std::unordered_map buffer_map_; + arith::Analyzer analyzer_; }; // Rewrite allocates, pointer parameters, and buffer map into vectorized versions @@ -1635,13 +1669,15 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, bool rewrite_let_node = true, - bool rewrite_allocate_const_node = true) { - VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers); + bool rewrite_allocate_const_node = true, + bool rewrite_scalar_read_to_vector_shuffle = true) { + VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers, + rewrite_scalar_read_to_vector_shuffle); checker(f->body); VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map, rewrite_allocate_node, rewrite_indices, rewrite_let_node, - rewrite_allocate_const_node); + rewrite_allocate_const_node, rewrite_scalar_read_to_vector_shuffle); PrimFuncNode* n = f.CopyOnWrite(); n->body = rewriter(std::move(n->body)); rewriter.Finalize(&f); @@ -1661,7 +1697,8 @@ Pass StorageRewrite() { // padded out to 32 bits) would require 
either rewriting // AllocateConst::data, or would require the code generators to // handle vectorized constants. - return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false); + return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false, + false); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } diff --git a/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py b/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py new file mode 100644 index 000000000000..b7e948170a13 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +import tvm.testing +from tvm import te +from tvm.driver.build_module import schedule_to_module +from tvm.script import tir as T + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.PointerValueTypeRewrite() + + +class TestRewriteToShuffle(BaseCompare): + @T.prim_func + def before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): + A_local_data = T.allocate([16], "float32", scope="local") + A_local = T.Buffer((16,), "float32", data=A_local_data, scope="local") + for i in range(4): + A_local[i * 4 : i * 4 + 4] = A[i * 4 : i * 4 + 4] + for i in range(4): + B[i] = A_local[i * 4] + A_local[i * 4 + 1] + A_local[i * 4 + 2] + A_local[i * 4 + 3] + + @T.prim_func + def expected(A: T.Buffer((4,), "float32x4"), B: T.Buffer((4,), "float32")): + A_local_data = T.allocate([4], "float32x4", scope="local") + A_local = T.Buffer((4,), "float32x4", data=A_local_data, scope="local") + for i in range(4): + A_local[T.Div(i * 4, 4)] = A[T.Div(i * 4, 4)] + for i in range(4): + B[i] = ( + T.Shuffle([A_local[T.Div(i * 4, 4)]], [0]) + + T.Shuffle([A_local[T.Div(i * 4 + 1, 4)]], [1]) + + T.Shuffle([A_local[T.Div(i * 4 + 2, 4)]], [2]) + + T.Shuffle([A_local[T.Div(i * 4 + 3, 4)]], [3]) + ) + + +class TestAddressOf(BaseCompare): + @T.prim_func + def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + for i in range(4): + T.evaluate(T.address_of(A[i * 4])) + B[i * 4 : i * 4 + 4] = A[i * 4 : i * 4 + 4] + + @T.prim_func + def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32x4")): + for i in range(4): + T.evaluate(T.address_of(A[i * 4])) + B[T.Div(i * 4, 4)] = A[i * 4 : i * 4 + 4] From 9a90ec1cbe329eadc80da8c09893fdc045eab4c8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 14 Aug 2023 11:50:13 -0700 Subject: [PATCH 2/5] address comments --- src/tir/transforms/storage_rewrite.cc | 24 +++++++++++++------ ...ir_transform_pointer_value_type_rewrite.py | 9 +++++++ 2 files changed, 26 insertions(+), 7 
deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 44315da9414a..2e24b1f69777 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -33,10 +33,10 @@ #include #include -#include #include #include +#include "../../arith/int_operator.h" #include "../../runtime/thread_storage_scope.h" #include "../ir/buffer_common.h" #include "ir_utils.h" @@ -1067,12 +1067,18 @@ struct BufferVarInfo { // packing in StorageRewrite) or in number of lanes (e.g. float16* // cast to float16x4*). std::unordered_set access_dtype; + // Data types used for scalar reads. This is used to record vectorized read dtypes that can be + // shuffled for scalar reads when rewrite_scalar_read_to_vector_shuffle is enabled. + std::unordered_set scalar_read_dtype; DataType get_preferred_dtype() const { std::unordered_set base_access_dtype; for (auto dtype : access_dtype) { base_access_dtype.insert(dtype.element_of()); } + for (auto dtype : scalar_read_dtype) { + base_access_dtype.insert(dtype.element_of()); + } // If the array is accessed as multiple base types within a // function, no point in changing the declared type. CodeGenC can // handle this with a type-cast prior to indexing. Vulkan will @@ -1082,13 +1088,21 @@ struct BufferVarInfo { return element_dtype; } + // When there are scalar reads and no writes, access_dtype can be empty and we should avoid + // rewriting. 
+ if (access_dtype.empty()) { + return element_dtype; + } DataType preferred_base_type = *base_access_dtype.begin(); int preferred_lanes = element_dtype.lanes(); if (element_dtype.lanes() == 1) { int lanes = access_dtype.begin()->lanes(); for (auto dtype : access_dtype) { - lanes = std::gcd(lanes, dtype.lanes()); + lanes = arith::ZeroAwareGCD(lanes, dtype.lanes()); + } + for (auto dtype : scalar_read_dtype) { + lanes = arith::ZeroAwareGCD(lanes, dtype.lanes()); } arith::Analyzer analyzer_; arith::ModularSet me = analyzer_.modular_set(extent); @@ -1311,11 +1325,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { const PrimExpr last_dim_index = indices[indices.size() - 1]; if (last_dim_index.dtype().lanes() == 1) { arith::ModularSet me = analyzer_.modular_set(last_dim_index); - if (me->coeff > 0) { - // When coeff == 0, the index is constant and doesn't need to be recorded since it can - // always be rewritten to shuffle. - var_info.access_dtype.insert(access_dtype.with_lanes(me->coeff)); - } + var_info.scalar_read_dtype.emplace(access_dtype.with_lanes(me->coeff)); return; } } diff --git a/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py b/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py index b7e948170a13..7baa96c1a16e 100644 --- a/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py +++ b/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py @@ -62,3 +62,12 @@ def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32x4")): for i in range(4): T.evaluate(T.address_of(A[i * 4])) B[T.Div(i * 4, 4)] = A[i * 4 : i * 4 + 4] + + +class TestScalarReadWithoutWrite(BaseCompare): + @T.prim_func + def before(A: T.Buffer((16,), "float32")): + for i in range(4): + T.evaluate(A[i * 4]) + + expected = before From bd7a4bd3b95bc38c14ab76ceced0e9d3374127cb Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 15 Aug 2023 12:29:40 -0700 Subject: [PATCH 3/5] lint --- 
python/tvm/tir/transform/transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 2f564e4ec8e6..a46b2d10373f 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -233,7 +233,8 @@ def StorageRewrite(): def PointerValueTypeRewrite(): """ Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use - the most frequently accessed type for load/store to avoid pointer casting in backend when possible. + the most frequently accessed type for load/store to avoid pointer casting in backend when + possible. Returns ------- From 997b0c0b08446d7d6236770bb88c463821e6a4d1 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 17 Aug 2023 01:01:13 +0000 Subject: [PATCH 4/5] Avoid rewriting when there are multiple vector accesses --- src/tir/transforms/storage_rewrite.cc | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 2e24b1f69777..4a4a25ed560e 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1088,21 +1088,23 @@ struct BufferVarInfo { return element_dtype; } - // When there are scalar reads and no writes, access_dtype can be empty and we should avoid - // rewriting. - if (access_dtype.empty()) { - return element_dtype; - } DataType preferred_base_type = *base_access_dtype.begin(); + // If there is only one vectorizable size used to access the + // buffer, and if that access size is compatible with the array + // size, then the buffer is vectorizable. In the future, this + // could be improved to allow vectorized buffer access of size + // GCD(*lanes_used), if necessary. + // When there are scalar reads and no writes, access_dtype can be empty and we should avoid + // rewriting. 
int preferred_lanes = element_dtype.lanes(); - if (element_dtype.lanes() == 1) { + if (element_dtype.lanes() == 1 && (access_dtype.size() == 1)) { int lanes = access_dtype.begin()->lanes(); - for (auto dtype : access_dtype) { - lanes = arith::ZeroAwareGCD(lanes, dtype.lanes()); - } + // Check the scalar read dtypes are compatible with the vectorized access dtype. for (auto dtype : scalar_read_dtype) { - lanes = arith::ZeroAwareGCD(lanes, dtype.lanes()); + if (dtype.lanes() % lanes != 0) { + return element_dtype; + } } arith::Analyzer analyzer_; arith::ModularSet me = analyzer_.modular_set(extent); @@ -1454,7 +1456,9 @@ class VectorTypeRewriter : public StmtExprMutator { PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); if (ramp_index->lanes != info.factor()) { - new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes / info.factor(), + ICHECK(info.factor() && ramp_index->lanes % info.factor() == 0); + int new_lanes = ramp_index->lanes / info.factor(); + new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes, ramp_index->span); } indices.Set(indices.size() - 1, new_index); From d563055332bd5787b7bfcf104239f668221d2cf5 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 17 Aug 2023 20:29:15 +0000 Subject: [PATCH 5/5] lint --- src/tir/transforms/storage_rewrite.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 4a4a25ed560e..f271769c804b 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1458,8 +1458,7 @@ class VectorTypeRewriter : public StmtExprMutator { if (ramp_index->lanes != info.factor()) { ICHECK(info.factor() && ramp_index->lanes % info.factor() == 0); int new_lanes = ramp_index->lanes / info.factor(); - new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes, - ramp_index->span); + new_index = Ramp(new_index * new_lanes, 
ramp_index->stride, new_lanes, ramp_index->span); } indices.Set(indices.size() - 1, new_index); } else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) {