[Relax][PyTorch] Enhance scale_factor handling in interpolation#18550
[Relax][PyTorch] Enhance scale_factor handling in interpolation#18550mshr-h merged 1 commit intoapache:mainfrom
Conversation
Summary of ChangesHello @guan404ming, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical bug in the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request correctly fixes a bug in the PyTorch frontend where interpolation with a list scale_factor was handled incorrectly, preventing different scaling for height and width. The change removes the faulty indexing and passes the scale_factor directly to the underlying implementation, which is the correct approach. The addition of numerical tests for nearest, bilinear, and bicubic modes with both scalar and list scale factors is a great way to ensure correctness. I've suggested a refactoring for the new test to reduce code duplication and improve maintainability using pytest.mark.parametrize.
| def test_interpolate_scale_factor_list(): | ||
| class InterpolateNearestScalar(Module): | ||
| """Nearest interpolation with scalar scale_factor.""" | ||
|
|
||
| def forward(self, input): | ||
| return torch.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest") | ||
|
|
||
| class InterpolateNearestList(Module): | ||
| """Nearest interpolation with list scale_factor (different H and W scaling).""" | ||
|
|
||
| def forward(self, input): | ||
| return torch.nn.functional.interpolate(input, scale_factor=[2.0, 3.0], mode="nearest") | ||
|
|
||
| class InterpolateBilinearScalar(Module): | ||
| """Bilinear interpolation with scalar scale_factor.""" | ||
|
|
||
| def forward(self, input): | ||
| return torch.nn.functional.interpolate( | ||
| input, scale_factor=2.0, mode="bilinear", align_corners=False | ||
| ) | ||
|
|
||
| class InterpolateBilinearList(Module): | ||
| """Bilinear interpolation with list scale_factor (different H and W scaling).""" | ||
|
|
||
| def forward(self, input): | ||
| return torch.nn.functional.interpolate( | ||
| input, scale_factor=[2.0, 3.0], mode="bilinear", align_corners=False | ||
| ) | ||
|
|
||
| class InterpolateBicubicScalar(Module): | ||
| """Bicubic interpolation with scalar scale_factor.""" | ||
|
|
||
| def forward(self, input): | ||
| return torch.nn.functional.interpolate( | ||
| input, scale_factor=2.0, mode="bicubic", align_corners=False | ||
| ) | ||
|
|
||
| class InterpolateBicubicList(Module): | ||
| """Bicubic interpolation with list scale_factor (different H and W scaling).""" | ||
|
|
||
| def forward(self, input): | ||
| return torch.nn.functional.interpolate( | ||
| input, scale_factor=[2.0, 3.0], mode="bicubic", align_corners=False | ||
| ) | ||
|
|
||
| # Test with 32x32 input | ||
| example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),) | ||
|
|
||
| # Test nearest mode with scalar and list scale_factor | ||
| verify_model_numerically(InterpolateNearestScalar(), example_args, rtol=1e-5, atol=1e-5) | ||
| verify_model_numerically(InterpolateNearestList(), example_args, rtol=1e-5, atol=1e-5) | ||
|
|
||
| # Test bilinear mode with scalar and list scale_factor | ||
| verify_model_numerically(InterpolateBilinearScalar(), example_args, rtol=1e-5, atol=1e-5) | ||
| verify_model_numerically(InterpolateBilinearList(), example_args, rtol=1e-5, atol=1e-5) | ||
|
|
||
| # Test bicubic mode with scalar and list scale_factor | ||
| verify_model_numerically(InterpolateBicubicScalar(), example_args, rtol=1e-5, atol=1e-5) | ||
| verify_model_numerically(InterpolateBicubicList(), example_args, rtol=1e-5, atol=1e-5) | ||
|
|
There was a problem hiding this comment.
This test is great for covering the different interpolation modes with both scalar and list scale_factors. However, there's a lot of repeated code between the different test cases. You can significantly simplify this by using pytest.mark.parametrize to iterate through the different modes and scale factors. This will make the test more concise and easier to maintain or extend in the future.
@pytest.mark.parametrize(
"mode, scale_factor",
[
("nearest", 2.0),
("nearest", [2.0, 3.0]),
("bilinear", 2.0),
("bilinear", [2.0, 3.0]),
("bicubic", 2.0),
("bicubic", [2.0, 3.0]),
],
)
def test_interpolate_scale_factor_list(mode, scale_factor):
"""Test interpolation with various modes and scale factors."""
class InterpolateModel(Module):
def forward(self, input_tensor):
kwargs = {"scale_factor": scale_factor, "mode": mode}
if mode != "nearest":
kwargs["align_corners"] = False
return torch.nn.functional.interpolate(input_tensor, **kwargs)
# Test with 32x32 input
example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)
verify_model_numerically(InterpolateModel(), example_args, rtol=1e-5, atol=1e-5)e54e3a0 to
9eb8caf
Compare
9eb8caf to
b922a52
Compare
|
Thanks! |
Why
Fixes interpolation to support different scaling factors for height and width (e.g., scale_factor=[2.0, 3.0])
How