[Unity][FX] Add support for PT2.0 scaled_dot_product_attention #14841
vinx13 merged 7 commits into apache:unity
Conversation
@masahi torch is installed in the CI GPU container, but not in the CI CPU container.

Oh, I thought the CPU image had PT as well. It is a bit ironic, since even for the GPU image we install the CPU build of PT (we only use PT for reference). Is there a reason the CPU image cannot have PT? It is pretty bad if we have to add
I also noticed this before. I don't know why the PT CPU version was used in the GPU container, and I don't know why the CPU image doesn't have PT (and ONNX) either.

Because their GPU build often causes trouble when we upgrade the PT version in our CI. It seems our CPU image has TFLite, MXNet, and Caffe installed; I see no reason PT cannot be installed. I'll work on that next week.

I just tried to install torch 2.0 in my local CPU container (run from the same image as CI), and it works well.
Okay, after #14842 lands and the new image is published, I'll update our CPU image.
Note that a float mask that's added to the score can be supported via
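For reference, here is a minimal sketch of the semantics of such an additive float mask in plain PyTorch (illustrative shapes, not the converter code): the mask is broadcast against the score tensor and added to the scaled scores before the softmax.

```python
import math
import torch
import torch.nn.functional as F

def sdpa_with_float_mask(q, k, v, mask):
    # q, k, v: (batch, num_heads, seq_len, head_dim); the float mask
    # broadcasts against the (batch, num_heads, seq_q, seq_k) scores.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores + mask, dim=-1) @ v

q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))
mask = torch.randn(2, 8, 16, 16)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
torch.testing.assert_close(sdpa_with_float_mask(q, k, v, mask), ref,
                           rtol=1e-4, atol=1e-4)
```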
Force-pushed from a180871 to 987656c.
Cleaned up the test cases quite a bit.
diffusers started to use `scaled_dot_product_attention` as of v0.16 in the SD VAE. See the doc: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html. We cannot support the attention mask, dropout, or the causal-mask optimization for now.

The PT attention op requires a different input format than our attention op, so we need to transpose the inputs. Luckily, those transposes can be cancelled in practice, since diffusers transposes q/k/v before calling `scaled_dot_product_attention` anyway (a design mistake?): https://github.com/huggingface/diffusers/blob/909742dbd6873052995dc6cd5f4150ff238015d2/src/diffusers/models/attention_processor.py#L906-L908 (see the layout sketch below).

I'm also cleaning up the FX test cases. We shouldn't need `@tvm.testing.requires_gpu`, since the tests don't execute anything on GPU. I'll also try removing the local `import torch` if possible (CI should have PT installed, right?)

@MasterJH5574 @jinhongyii @cyx-6
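To make the cancellation concrete, here is a small sketch. Treating our attention op as taking `(batch, seq_len, num_heads, head_dim)` is an assumption for illustration, based on the description above rather than the converter code:

```python
import torch

# diffusers holds q/k/v in (batch, seq_len, num_heads, head_dim) and
# transposes to PT's SDPA layout (batch, num_heads, seq_len, head_dim)
# right before the call; a converter targeting an op that expects the
# original layout would transpose straight back, so the two transposes
# compose to the identity and both can be eliminated.
q = torch.randn(2, 16, 8, 64)   # (batch, seq_len, num_heads, head_dim)
pt_in = q.transpose(1, 2)       # what diffusers feeds to SDPA
our_in = pt_in.transpose(1, 2)  # what the converter would feed our op
assert torch.equal(our_in, q)   # the pair of transposes cancels
```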