[Bugfix] Make ThreadAllReduce pass compatible with int64 #14991

Merged
junrushao merged 2 commits into apache:main from yzh119:fix-allreduce-dtype
May 31, 2023

@yzh119 yzh119 commented May 30, 2023

The Issue

Currently, the ThreadAllReduce pass throws an error when the mask's data type is uint32 and the data type of group_index is int64:

  1: tvm::tir::ThreadAllreduceBuilder::MakeAllreduce(tvm::tir::CallNode const*)
        at /home/zhye/repos/relax/src/tir/transforms/lower_thread_allreduce.cc:362
  0: tvm::tir::BufferStore::BufferStore(tvm::tir::Buffer, tvm::PrimExpr, tvm::runtime::Array<tvm::PrimExpr, void>, tvm::Span)
        at /home/zhye/repos/relax/src/tir/ir/stmt.cc:477
  File "/home/zhye/repos/relax/src/tir/ir/stmt.cc", line 477
TypeError: dtype mismatch on BufferStore: buffer's dtype is `uint32`, the lanes of indexing are: `1`, but RHS's dtype is `int64`

Since int64 is becoming the standard index data type for large models, this mismatch needs to be fixed.

The Fix

This PR resolves the issue by casting group_index to the data type used by the mask.
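
To illustrate the idea, here is a minimal sketch in plain Python. It is a toy model, not TVM code: `Expr`, `Buffer`, `cast`, and `buffer_store` are hypothetical stand-ins written for this example, and the sketch stores the index directly, whereas the real pass may use group_index inside a larger expression. It only shows why a dtype-checked store rejects an int64 value for a uint32 buffer and how a cast to the buffer's dtype resolves that.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Expr:
    """Toy stand-in for a typed scalar expression (not TVM's PrimExpr)."""
    name: str
    dtype: str


@dataclass(frozen=True)
class Buffer:
    """Toy stand-in for a typed buffer (not TVM's Buffer)."""
    name: str
    dtype: str


def cast(dtype: str, value: Expr) -> Expr:
    """Return an expression of the requested dtype; equal dtypes pass through."""
    if value.dtype == dtype:
        return value
    return Expr(f"cast({dtype}, {value.name})", dtype)


def buffer_store(buffer: Buffer, value: Expr) -> str:
    """Reject stores whose value dtype differs from the buffer dtype."""
    if value.dtype != buffer.dtype:
        raise TypeError(
            f"dtype mismatch: buffer is {buffer.dtype}, value is {value.dtype}"
        )
    return f"{buffer.name}[0] = {value.name}"


mask = Buffer("mask", "uint32")
group_index = Expr("group_index", "int64")

# Without the cast, the store is rejected, mirroring the reported failure.
try:
    buffer_store(mask, group_index)
    rejected = False
except TypeError:
    rejected = True

# With the cast, the dtypes agree and the store is accepted.
stored = buffer_store(mask, cast(mask.dtype, group_index))
```

Casting to `mask.dtype` rather than to a hard-coded type keeps the sketch correct if the mask's dtype changes, which matches the description above of casting to the data type used by the mask.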

cc @vinx13 @junrushao @Hzfengsy

tvm-bot commented May 30, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: bugfix See #10317 for details

Generated by tvm-bot

@github-actions github-actions Bot requested review from junrushao and vinx13 May 30, 2023 23:48
@github-actions github-actions Bot requested a review from Hzfengsy May 31, 2023 06:55
@junrushao junrushao merged commit c98e29b into apache:main May 31, 2023
mei-ye pushed a commit to mei-ye/tvm that referenced this pull request Jun 1, 2023