
[Unity][FX] Add support for PT2.0 scaled_dot_product_attention #14841

Merged: vinx13 merged 7 commits into apache:unity from masahi:fx-sdp on May 16, 2023
Conversation

@masahi (Member) commented May 12, 2023

diffusers started to use scaled_dot_product_attention in the SD VAE as of v0.16. See the doc: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html. For now we cannot support the attention mask, dropout, or the causal-mask optimization.

The PT attention op requires a different input layout than our attention op, so we need to transpose the inputs. Luckily, those transposes cancel out in practice, since diffusers transposes q/k/v before calling scaled_dot_product_attention anyway (a design mistake?): https://github.com/huggingface/diffusers/blob/909742dbd6873052995dc6cd5f4150ff238015d2/src/diffusers/models/attention_processor.py#L906-L908
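To make the layout point concrete, here is a small sketch (an illustration, not the actual importer code) of the transpose round-trip, assuming the relax attention op takes the [batch, seq_len, num_heads, head_dim] (BSND) layout described above:

```python
import torch
import torch.nn.functional as F

batch, seq, heads, dim = 2, 8, 4, 16

# Assumed relax-style attention layout: [batch, seq_len, num_heads, head_dim].
q = torch.randn(batch, seq, heads, dim)
k = torch.randn(batch, seq, heads, dim)
v = torch.randn(batch, seq, heads, dim)

# PT's scaled_dot_product_attention expects
# [batch, num_heads, seq_len, head_dim] (BNSD),
# so an importer has to transpose in and back out.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
```

When the caller (as diffusers does) has itself just transposed BSND tensors into BNSD right before the call, the back-to-back transposes cancel, and the importer can feed the original BSND tensors straight to the relax attention op.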

I'm also cleaning up the FX test cases. We shouldn't need @tvm.testing.requires_gpu since the tests don't execute anything on the GPU. I'll also try removing the local import torch if possible (CI should have PT installed, right?)

@MasterJH5574 @jinhongyii @cyx-6

@tvm-bot (Collaborator) commented May 12, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@yongwww (Member) commented May 12, 2023

@masahi torch is installed in the CI GPU container, but not in the CI CPU container.

@masahi (Member, Author) commented May 12, 2023

> @masahi torch is installed in the CI GPU container, but not in the CI CPU container.

Oh, I thought the CPU image had PT as well. It is a bit ironic, since even for the GPU image we install the CPU build of PT (we only use PT as a reference).

Is there a reason the CPU image cannot have PT? It is pretty bad if we have to add @tvm.testing.requires_gpu (even though nothing runs on the GPU) and do a local import in every test, only because the CPU image doesn't have PT.

@yongwww (Member) commented May 12, 2023

I also noticed this before. I don't know why the PT CPU version is used in the GPU container, and I don't know why the CPU image doesn't have PT (or ONNX) either.

@masahi (Member, Author) commented May 12, 2023

> I also noticed this before. I don't know why the PT CPU version is used in the GPU container.

Because their GPU build often causes trouble when we upgrade the PT version in our CI.

It seems our CPU image has tflite, mxnet, and caffe installed. I see no reason PT cannot be installed too. I'll work on that next week.

@yongwww (Member) commented May 12, 2023

> I also noticed this before. I don't know why the PT CPU version is used in the GPU container.
>
> Because their GPU build often causes trouble when we upgrade the PT version in our CI.
>
> It seems our CPU image has tflite, mxnet, and caffe installed. I see no reason PT cannot be installed too. I'll work on that next week.

I just tried to install torch 2.0 in my local CPU container (run from the same image as CI), and it works well.

@masahi (Member, Author) commented May 12, 2023

Okay, after #14842 lands and the new image is published, I'll update our CPU image.

@vinx13 (Member) commented May 13, 2023

Note that a float mask that is added to the score can be supported via the bias input of relax attention.
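A quick sketch of the equivalence behind this suggestion (illustrative only; the random mask here is a made-up example, not code from the PR): in PT's scaled_dot_product_attention, a float attn_mask is simply added to the pre-softmax scores, which is exactly what a score-level bias input to an attention op does.

```python
import math

import torch
import torch.nn.functional as F

batch, heads, seq, dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

# A float mask is broadcast and added to the attention scores.
mask = torch.randn(seq, seq)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Equivalent formulation with the mask acting as an additive bias:
scores = q @ k.transpose(-1, -2) / math.sqrt(dim) + mask
ref = torch.softmax(scores, dim=-1) @ v
```

So a float mask could be lowered to the bias argument of the relax attention op without any extra rewriting of the softmax.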

@masahi force-pushed the fx-sdp branch 2 times, most recently from a180871 to 987656c on May 14, 2023.
@yongwww (Member) left a comment

LGTM

@masahi (Member, Author) commented May 15, 2023

Cleaned up test cases quite a bit.

@vinx13 vinx13 merged commit e812a21 into apache:unity May 16, 2023