-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[TIR][Schedule] Transform layout #10538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
50e1e92
ab3f83d
203ea83
b4f58df
1e0c445
022ddd3
21bcab5
dcdf22c
a83c413
d518267
d82bccf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,13 +15,15 @@ | |
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """The TensorIR schedule class""" | ||
| from typing import Dict, List, Optional, Union | ||
| import enum | ||
| from typing import Callable, Dict, List, Optional, Union | ||
|
|
||
| from tvm._ffi import register_object as _register_object | ||
| from tvm.error import TVMError, register_error | ||
| from tvm.ir import IRModule, PrimExpr | ||
| from tvm.runtime import Object, String | ||
| from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc | ||
| from ..function import IndexMap | ||
|
|
||
| from . import _ffi_api | ||
| from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod | ||
|
|
@@ -71,6 +73,13 @@ def __init__(self) -> None: | |
| } | ||
|
|
||
|
|
||
| class BufferType(enum.IntEnum): | ||
| """Type of buffer in access regions of a block""" | ||
|
|
||
| READ = 0 | ||
| WRITE = 1 | ||
|
|
||
|
Comment on lines
+76
to
+81
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @vinx13 Sorry I'm late and missed the code review. On this particular change, I'm 100% in favor of having enum type on C++ side; However, it could be preferable to "just use string" on the python side, i.e. use "read" and "write" to indicate BufferType. If you agree with my opinion, would you mind sending a quick patch? Thanks a lot!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sounds good to me as well, and would give the same readability benefits at the caller side. I kind of like the At some point, it would be nice to have a macro to define an enum in C++, along with its value/name mapping and an FFI interface, so that there would be a clear way to handle these. |
||
|
|
||
| def _parse_error_render_level(error_render_level: str) -> int: | ||
| if error_render_level not in _ERROR_RENDER_LEVEL: | ||
| raise ValueError( | ||
|
|
@@ -2111,6 +2120,82 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: | |
| self, block_or_loop, ann_key | ||
| ) | ||
|
|
||
| ########## Schedule: Layout transformation ########## | ||
|
|
||
| @type_checked | ||
| def transform_layout( | ||
|
vinx13 marked this conversation as resolved.
|
||
| self, | ||
| block: BlockRV, | ||
|
Lunderberg marked this conversation as resolved.
|
||
| buffer_index: int, | ||
| buffer_type: BufferType, | ||
| index_map: Union[IndexMap, Callable], | ||
| ) -> None: | ||
| """Apply a transformation represented by IndexMap to buffer | ||
| Parameters | ||
| ---------- | ||
| block : BlockRV | ||
| The block that accesses the target buffer | ||
| buffer_index: int | ||
| The index of the buffer in block's read or write region | ||
| buffer_type : BufferType | ||
| Type of the buffer, READ or WRITE. | ||
| index_map : Union[IndexMap, Callable] | ||
| The transformation to apply | ||
|
|
||
| Examples | ||
| -------- | ||
| Before transform_layout, in TensorIR, the IR is: | ||
|
vinx13 marked this conversation as resolved.
|
||
|
|
||
| .. code-block:: python | ||
|
|
||
| @T.prim_func | ||
| def before_transform_layout(a: T.handle, c: T.handle) -> None: | ||
| A = T.match_buffer(a, (128, 128), "float32") | ||
| B = T.alloc_buffer((128, 128), "float32") | ||
| C = T.match_buffer(c, (128, 128), "float32") | ||
| for i, j in T.grid(128, 128): | ||
| with T.block("B"): | ||
| vi, vj = T.axis.remap("SS", [i, j]) | ||
| B[vi, vj] = A[vi, vj] * 2.0 | ||
| for i, j in T.grid(128, 128): | ||
| with T.block("C"): | ||
| vi, vj = T.axis.remap("SS", [i, j]) | ||
| C[vi, vj] = B[vi, vj] + 1.0 | ||
|
|
||
| Create the schedule and do transform_layout: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| sch = tir.Schedule(before_transform_layout) | ||
| sch.transform_layout(sch.get_block("B"), buffer_index=0, buffer_type=BufferType.WRITE, | ||
| index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) | ||
| print(sch.mod["main"].script()) | ||
|
|
||
| After applying transform_layout, the IR becomes: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| @T.prim_func | ||
| def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None: | ||
| A = T.match_buffer(a, (128, 128), "float32") | ||
| B = T.alloc_buffer((8, 8, 16, 16), "float32") | ||
| C = T.match_buffer(c, (128, 128), "float32") | ||
| for i, j in T.grid(128, 128): | ||
| with T.block("B"): | ||
| vi, vj = T.axis.remap("SS", [i, j]) | ||
| B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0 | ||
| for i, j in T.grid(128, 128): | ||
| with T.block("C"): | ||
| vi, vj = T.axis.remap("SS", [i, j]) | ||
| C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 | ||
|
|
||
| """ | ||
| if callable(index_map): | ||
| index_map = IndexMap.from_func(index_map) | ||
|
Lunderberg marked this conversation as resolved.
|
||
| _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member | ||
| self, block, buffer_index, buffer_type, index_map | ||
| ) | ||
|
|
||
| ########## Schedule: Misc ########## | ||
|
|
||
| @type_checked | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.