[Torch] Experimental support for FX-quantized models#10091

Merged
jroesch merged 15 commits into apache:main from masahi:fx-quant
Feb 4, 2022
Conversation

@masahi (Member) commented Jan 28, 2022

This is the first step toward supporting models quantized with the FX-based workflow as described in https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html.

The required change was surprisingly small: simple graph surgery, done by inline_input_quant_params_for_fx(...) in qnn_torch.py, is enough. So far I have been able to quantize ImageNet models, DeepLab v3, SSD-VGG, and YOLOv5, either fully or semi-automatically. See the attached test cases.
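To illustrate the idea (this is a toy sketch, not the actual qnn_torch.py code): FX-quantized TorchScript graphs receive quantization parameters (scales, zero points) as extra graph inputs, and the surgery rewrites those inputs into constants so the rest of the importer can treat them like any other quantized model. All names below (inline_quant_params, conv1_scale, etc.) are hypothetical.

```python
# Toy model of "inlining" quant-param graph inputs as constants.
# A graph input whose name appears in param_values is replaced by a
# constant node carrying that value; ordinary inputs are left alone.
def inline_quant_params(graph_inputs, param_values):
    inlined = {}
    for name, node in graph_inputs.items():
        if name in param_values:
            # Graph surgery: swap the input placeholder for a constant.
            inlined[name] = {"op": "const", "value": param_values[name]}
        else:
            inlined[name] = node
    return inlined

graph_inputs = {
    "data": {"op": "input"},
    "conv1_scale": {"op": "input"},
    "conv1_zero_point": {"op": "input"},
}
params = {"conv1_scale": 0.0078, "conv1_zero_point": 128}
result = inline_quant_params(graph_inputs, params)
assert result["conv1_scale"] == {"op": "const", "value": 0.0078}
assert result["data"] == {"op": "input"}
```

The real implementation operates on the TorchScript graph produced by tracing the FX-converted model, but the shape of the transformation is the same.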

Since my current interest is collecting real-world quantized workloads for performance benchmarking, I have not looked into calibration.

Also added aten::clamp_min support for SSD-VGG.
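For reference, aten::clamp_min applies only a lower bound. A minimal numpy sketch of the semantics the converter needs to reproduce (not TVM or PyTorch code; the helper name is hypothetical):

```python
import numpy as np

def clamp_min(x, amin):
    # One-sided clamp: values below amin are raised to amin,
    # everything else passes through unchanged.
    return np.maximum(x, amin)

x = np.array([-2.0, 0.5, 3.0], dtype=np.float32)
out = clamp_min(x, 0.0)
assert out.tolist() == [0.0, 0.5, 3.0]
```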

@comaniac @lhutton1 @junrushao1994 @siju-samuel @t-vi @AndrewZhaoLuo

@t-vi (Contributor) left a comment

Looks good to me. Thank you @masahi
I have two minor suggestions for optional improvements/follow-ups (they should not keep you from merging, though).

Comment thread: python/tvm/relay/frontend/pytorch.py (Outdated)
    amin = get_v(inputs[1], np.finfo(np.float32).min)
    amax = get_v(inputs[2], np.finfo(np.float32).max)
    if min_only:
        amin = get_v(inputs[1], np.finfo(np.float32).min)
@t-vi (Contributor) commented:

It is not introduced in this PR, but is using float32's max prudent here? We might have all sorts of dtypes as inputs; also, -inf should probably not be clamped to -3.4028235e+38.
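The concern is easy to demonstrate with numpy (this reproduces the failure mode, not the TVM code): when a max-only clamp fills in np.finfo(np.float32).min as a fake lower bound, -inf is silently clamped to a finite value.

```python
import numpy as np

x = np.array([-np.inf, 0.0, 10.0], dtype=np.float32)

# Using float32's min as a stand-in lower bound for a max-only clamp:
fake_min = np.finfo(np.float32).min  # -3.4028235e+38
clamped = np.clip(x, fake_min, 5.0)
assert clamped[0] == fake_min        # -inf was (incorrectly) made finite

# A genuinely one-sided clamp preserves -inf:
one_sided = np.minimum(x, 5.0)
assert np.isneginf(one_sided[0])
```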

@masahi (Member, Author) commented:

Good catch, fixed. But I'm not sure what to do about clamping inf.

@t-vi (Contributor) commented Jan 28, 2022:

I think -inf should stay -inf if we clamp with only a max. PyTorch uses separate implementation kernels that perform only the required op (kind of the reverse of what we do here); I don't know whether that would be a good choice for TVM.

@masahi (Member, Author) commented Jan 28, 2022:

I see, that makes sense. Probably the best way for us is to allow a None value in the clip op for one-way clipping, and do only the max or min inside TOPI. I left a TODO item for that.

Probably other op conversions also have issues with inf handling...
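As a sketch of the proposed direction (numpy's own clip already behaves this way; this is an illustration of the desired semantics, not TVM code): a None bound means "no clamp on that side", so infinities pass through untouched.

```python
import numpy as np

x = np.array([-np.inf, -1.0, 7.0], dtype=np.float32)

# max-only clamp: lower bound is None, so -inf survives.
max_only = np.clip(x, None, 5.0)
assert np.isneginf(max_only[0])
assert max_only[2] == 5.0

# min-only clamp (i.e. clamp_min): upper bound is None.
min_only = np.clip(x, 0.0, None)
assert min_only[1] == 0.0
```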

@masahi (Member, Author) commented Feb 1, 2022

Can somebody merge this PR? I'm not aware of anyone who is familiar with PyTorch quantization, but I believe the change should be a no-brainer. This is a nice feature to land, and it was also requested in https://discuss.tvm.apache.org/t/pytorch-qnn-cannot-import-torchscript-produced-by-fx-graph-mode-quantization/11954

@jroesch jroesch merged commit 95aac92 into apache:main Feb 4, 2022
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
* works on resnet18 and deeplabv3

* yolo5 conversion worked

* fixed sigmoid

* [Torch] Support clamp_min, clamp_max

* fixed clamp_min

* fixed quantize for 1 dim input

* cleanup

* improve inline_qparam impl

* add clamp_min/max test

* add fx quant test

* cleanup

* skip build in testing

* black

* improve clamp conversion

* leave TODO on inf handling
3 participants