Skip to content

[TensorIR][Schedule] New schedule primitive unsafe_hide_buffer_access#15144

Merged
yzh119 merged 5 commits intoapache:mainfrom
yzh119:hide-buffer-access
Jun 24, 2023
Merged

[TensorIR][Schedule] New schedule primitive unsafe_hide_buffer_access#15144
yzh119 merged 5 commits intoapache:mainfrom
yzh119:hide-buffer-access

Conversation

@yzh119
Copy link
Copy Markdown
Member

@yzh119 yzh119 commented Jun 22, 2023

Motivation

Currently, our tensorize schedule primitives rely on buffer read/write regions in the given block to perform pattern matching. However, for workloads such as block sparse operators, the read/write regions include some indices arraies that may fail tensorize primitive.
In SparseTIR we introduce a new schedule primitive called hide_buffer_access which allows us to hide certain buffer regions in a block so that the read/write buffer regions would be recognized by the tensorize primitive to further utilize tensor acceleration units.
This PR upstreams this schedule primitive to TensorIR mainline.

The schedule primitive interface

    def hide_buffer_access(self, block: BlockRV, buf_type: str, buf_index_array: List[int]) -> None:
        """Hide some buffer access in a given block.

        Parameters
        ----------
        block : BlockRV
            The block where we hide read access.
        buf_type : str
            The buffer type: "read"/"write".
        buf_index_array : List[int]
            The array of buffer indices we hide access.
        """
        pass

Example

@T.prim_func
def indirect_mem_access(a: T.handle, idx_a: T.handle, b: T.handle, idx_b: T.handle) -> None:
    A = T.match_buffer(a, [128], dtype="float32")
    IA = T.match_buffer(idx_a, [10], dtype="int32")
    B = T.match_buffer(b, [128], dtype="float32")
    IB = T.match_buffer(idx_b, [10], dtype="int32")

    for i in range(10):
        with T.block("B"):
            vi = T.axis.spatial(10, i)
            T.reads(A[IA[vi]], IA[vi])
            T.writes(B[IB[vi]], IB[vi])
            B[IB[vi]] = A[IA[vi]]

After we perform hiding buffer access to IA[vi] via:

sch = tir.Schedule(indirect_mem_access, debug_mask="all")
block_b = sch.get_block("B")
sch.hide_buffer_access(block_b, "write", [1]) 

the desired transformed IR would be:

@T.prim_func
def indirect_mem_access_hide_ia(a: T.handle, idx_a: T.handle, b: T.handle, idx_b: T.handle) -> None:
    A = T.match_buffer(a, [128], dtype="float32")
    IA = T.match_buffer(idx_a, [10], dtype="int32")
    B = T.match_buffer(b, [128], dtype="float32")
    IB = T.match_buffer(idx_b, [10], dtype="int32")

    for i in range(10):
        with T.block("B"):
            vi = T.axis.spatial(10, i)
            T.reads(A[IA[vi]])
            T.writes(B[IB[vi]], IB[vi])
            B[IB[vi]] = A[IA[vi]]

The existing passes/schedules would not be influenced by this PR.

cc @junrushao @MasterJH5574 @masahi

@tvm-bot
Copy link
Copy Markdown
Collaborator

tvm-bot commented Jun 22, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: tensorir, schedule See #10317 for details

Generated by tvm-bot

@yzh119 yzh119 requested a review from Hzfengsy June 22, 2023 23:17
Copy link
Copy Markdown
Contributor

@quic-sanirudh quic-sanirudh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, a nice addition :). As mentioned in my comment below, could we rename the primitive to unsafe_hide_buffer_access and perhaps add some comments in the docstring to indicate the chances of incorrect output resulting from using this primitive.

Comment thread src/tir/schedule/primitive/hide_buffer_access.cc
@yzh119
Copy link
Copy Markdown
Member Author

yzh119 commented Jun 24, 2023

Hi @quic-sanirudh, thank you for your suggestions, I have marked the schedule as unsafe and added some docstrings explaining it.

@yzh119 yzh119 changed the title [TensorIR][Schedule] New schedule primitive hide_buffer_access [TensorIR][Schedule] New schedule primitive unsafe_hide_buffer_access Jun 24, 2023
@quic-sanirudh
Copy link
Copy Markdown
Contributor

Hi @quic-sanirudh, thank you for your suggestions, I have marked the schedule as unsafe and added some docstrings explaining it.

Thanks for taking my suggestion. Looks good to me now

@yzh119 yzh119 merged commit 0a5f5f0 into apache:main Jun 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants