Skip to content
5 changes: 3 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,11 @@ TVM_DLL Pass ConvertBlocksToOpaque();
*
* \endcode
*
*
* \param is_strict ensure the compacted shape always smaller than the original shape.
* otherwise it allows to grow the shape to match actual accessed buffer regions.
* \return The pass.
*/
TVM_DLL Pass CompactBufferAllocation();
TVM_DLL Pass CompactBufferAllocation(bool is_strict = true);
Comment thread
Lunderberg marked this conversation as resolved.

/*!
* This pass legalizes packed calls by wrapping their arguments into TVMValues
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def ConvertBlocksToOpaque():
return _ffi_api.ConvertBlocksToOpaque() # type: ignore


def CompactBufferAllocation():
def CompactBufferAllocation(is_strict: bool = True):
"""Compact the buffer access region. by removing the buffer regions
that are not accessed, i.e. narrowing the buffer shape and adjust
the access region if necessary.
Expand Down Expand Up @@ -783,13 +783,19 @@ def CompactBufferAllocation():
for j in range(0, 16):
C[i, j] = B[0, j] + 1

Parameters
----------
is_strict : bool
Ensure the compacted shape to be always smaller than the original shape.
Otherwise it allows to grow the shape to match actual accessed buffer regions.

Returns
-------
fpass : tvm.transform.Pass
The result pass

"""
return _ffi_api.CompactBufferAllocation() # type: ignore
return _ffi_api.CompactBufferAllocation(is_strict) # type: ignore


def LowerMatchBuffer():
Expand Down
10 changes: 6 additions & 4 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
/*! \brief Extra iteration range hint for free vars */
std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
/*! \brief Unresolved conditions within current scope. */
std::vector<PrimExpr> pending_conditions_;
/*! \brief The buffers that the current block reads */
std::vector<Buffer> read_buffers_;
/*! \brief The buffers that the current block writes */
Expand Down Expand Up @@ -164,12 +166,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
VisitExpr(op->condition);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, &pending_conditions_);
Comment thread
wrongtest-intellif marked this conversation as resolved.
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case) {
// Visit else branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
With<ConditionalBoundsContext> ctx(!op->condition, &dom_map_, &hint_map_, &pending_conditions_);
StmtExprVisitor::VisitStmt(op->else_case.value());
}
}
Expand Down Expand Up @@ -207,12 +209,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
VisitExpr(op->args[0]);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, &pending_conditions_);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
With<ConditionalBoundsContext> ctx(!op->args[0], &dom_map_, &hint_map_, &pending_conditions_);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
Expand Down
5 changes: 1 addition & 4 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,7 @@ TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sr
*/
TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis);
/******** Schedule: Block annotation ********/
/*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */
using StorageAlignTuple = Array<Integer>;
/*! \brief A list of StorageAlignTuple, used by StorageAlign */
using StorageAlignAnnotation = Array<StorageAlignTuple>;

/*!
* \brief Set alignment requirement for specific dimension such that
* stride[axis] == k * factor + offset for some k. This is useful to set memory layout for
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/primitive/block_annotate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
#include <tvm/tir/expr.h>

#include "../../transforms/ir_utils.h"
#include "../utils.h"

namespace tvm {
Expand Down
Loading