[LoopPartition] Fix a bug of LoopPartition in single point scenarioes #16104

Merged
tqchen merged 1 commit into apache:main from lightzhan-intellif:fix_loop_partition
Dec 15, 2023


@lightzhan-intellif

This PR fixes a bug in the LoopPartition pass. When one or more of the concatenated tensors has extent 1 along the concat dimension, the pass wrongly unrolls the loops after partitioning. For example:

@T.prim_func
def concat_func_single_point(
    placeholder: T.Buffer((28, 64), "int8"),
    placeholder_1: T.Buffer((28, 1), "int8"),
    placeholder_2: T.Buffer((28, 63), "int8"),
    T_concat: T.Buffer((28, 128), "int8"),
) -> None:
    for i0 in range(28):
        for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
            if i1 > 63:
                T_concat[i0, i1] = placeholder[i0, i1 - 64]
            elif i1 == 63:
                T_concat[i0, i1] = placeholder_1[i0, i1 - 63]
            else:
                T_concat[i0, i1] = placeholder_2[i0, i1]
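
To see where the "single point" comes from, it can help to group the 128 values of i1 by the branch they take. The sketch below is plain Python, not TVM; `branch_of` and `branch_intervals` are hypothetical helper names used only for illustration, and the branch conditions are copied from the loop body above.

```python
from itertools import groupby


def branch_of(i1: int) -> str:
    """Return the name of the buffer the original loop body reads for this i1."""
    if i1 > 63:
        return "placeholder"
    elif i1 == 63:
        return "placeholder_1"
    else:
        return "placeholder_2"


def branch_intervals(extent: int = 128):
    """Group consecutive i1 values taking the same branch into closed intervals."""
    intervals = []
    for name, run in groupby(range(extent), key=branch_of):
        run = list(run)
        intervals.append((name, run[0], run[-1]))
    return intervals


print(branch_intervals())
# [('placeholder_2', 0, 62), ('placeholder_1', 63, 63), ('placeholder', 64, 127)]
```

The middle interval contains exactly one value, i1 == 63, which is the single-point case the PR title refers to.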

After LoopPartition, the loops on either side of the single point come out unrolled:

@T.prim_func
def expected_partitioned_concat_single_point(
    placeholder: T.Buffer((28, 64), "int8"),
    placeholder_1: T.Buffer((28, 1), "int8"),
    placeholder_2: T.Buffer((28, 63), "int8"),
    T_concat: T.Buffer((28, 128), "int8"),
):
    for i0 in range(28):
        T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
        for i1 in T.unroll(63): # Note here, it is unrolled.
            placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
            T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1]
        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
        T_concat_1[i0 * 128 + 63] = placeholder_1_1[i0]
        for i1 in T.unroll(64): # here too.
            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
            T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]
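
The flattened index arithmetic in the partitioned listing can be checked against the original branchy loop without TVM. The following is a minimal sketch in plain Python under that assumption: the helper names (`flatten`, `concat_reference`, `concat_partitioned`, `make_inputs`) and the test data are hypothetical, and ordinary `range` loops stand in for the partitioned loops, since whether a loop is serial or unrolled does not change the values written.

```python
ROWS, COLS = 28, 128


def flatten(buf):
    """Row-major flattening of a 2-D nested list."""
    return [v for row in buf for v in row]


def make_inputs():
    """Arbitrary test data with the shapes used in the PR example."""
    placeholder = [[(7 * i + j) % 127 for j in range(64)] for i in range(ROWS)]
    placeholder_1 = [[(i + 100) % 127] for i in range(ROWS)]
    placeholder_2 = [[(3 * i + 5 * j + 1) % 127 for j in range(63)] for i in range(ROWS)]
    return placeholder, placeholder_1, placeholder_2


def concat_reference(p, p1, p2):
    """Original loop nest: one loop over i1 with three branches."""
    out = [[0] * COLS for _ in range(ROWS)]
    for i0 in range(ROWS):
        for i1 in range(COLS):
            if i1 > 63:
                out[i0][i1] = p[i0][i1 - 64]
            elif i1 == 63:
                out[i0][i1] = p1[i0][i1 - 63]
            else:
                out[i0][i1] = p2[i0][i1]
    return out


def concat_partitioned(p_flat, p1_flat, p2_flat):
    """Partitioned form on flat buffers, mirroring the indices in the listing."""
    out = [0] * (ROWS * COLS)
    for i0 in range(ROWS):
        for i1 in range(63):                      # i1 in [0, 63)
            out[i0 * 128 + i1] = p2_flat[i0 * 63 + i1]
        out[i0 * 128 + 63] = p1_flat[i0]          # the single point i1 == 63
        for i1 in range(64):                      # i1 + 64 in [64, 128)
            out[i0 * 128 + i1 + 64] = p_flat[i0 * 64 + i1]
    return out


p, p1, p2 = make_inputs()
expected = flatten(concat_reference(p, p1, p2))
actual = concat_partitioned(flatten(p), flatten(p1), flatten(p2))
print(expected == actual)
```

Both forms write the same values, so the partitioning itself is sound; the issue described in this PR concerns only the loop kind (unrolled instead of serial) chosen for the two outer partitions.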

cc @wrongtest-intellif @tqchen

lightzhan-intellif changed the title from "[LoopPartition] Fix a bug of LoopPartition in single point scenarioes." to "[LoopPartition] Fix a bug of LoopPartition in single point scenarioes" on Nov 10, 2023
@tqchen

tqchen commented Dec 7, 2023

Thanks @lightzhan-intellif, would you mind fixing the CI?

@lightzhan-intellif

@tvm-bot rerun

@lightzhan-intellif

> Thanks @lightzhan-intellif, would you mind fixing the CI?

done

tqchen merged commit 870246a into apache:main on Dec 15, 2023