Skip to content

Fix nn.PixelUnshuffle wrong channel ordering (replace SpaceToDepth with Reshape→Transpose→Reshape)#2892

Open
Copilot wants to merge 2 commits intomainfrom
copilot/fix-pixelunshuffle-export
Open

Fix nn.PixelUnshuffle wrong channel ordering (replace SpaceToDepth with Reshape→Transpose→Reshape)#2892
Copilot wants to merge 2 commits intomainfrom
copilot/fix-pixelunshuffle-export

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 17, 2026

nn.PixelUnshuffle was exported as SpaceToDepth, which only supports DCR channel ordering — but PixelUnshuffle is the inverse of PixelShuffle, which uses DepthToSpace with mode="CRD". This mismatch produced incorrect output channel order.

Changes

  • aten_pixel_unshuffle: Replace op.SpaceToDepth with an explicit Reshape→Transpose→Reshape sequence that correctly implements the inverse of CRD DepthToSpace:
    1. Reshape [..., C, H×r, W×r][batch, C, H, r, W, r]
    2. Transpose perm=[0, 1, 3, 5, 2, 4][batch, C, r, r, H, W] (inverse of CRD perm [0, 1, 4, 2, 5, 3])
    3. Reshape → [batch, C×r², H, W]
  • The ND (>4D) input path is unified: leading dims are always collapsed to a batch dim before the core 4D logic, then restored — removing the now-unnecessary if len(self.shape) == 4 special case.

This satisfies PyTorch's PixelUnshuffle semantics: output[n, c·r²+r_h·r+r_w, h, w] = input[n, c, h·r+r_h, w·r+r_w].

Warning

Firewall rules blocked me from connecting to one or more addresses (expand for details)

I tried to connect to the following addresses, but was blocked by firewall rules:

  • download-r2.pytorch.org
    • Triggering command: /usr/bin/pip pip install torch onnx onnxruntime --index-url REDACTED -q (dns block)

If you need me to access, download, or install something from one of these locations, you can either:

…paceToDepth

SpaceToDepth only supports DCR channel ordering, but PixelUnshuffle (inverse of
PixelShuffle) requires CRD ordering. Implement via:
1. Reshape [..., C, H*r, W*r] -> [batch, C, H, r, W, r]
2. Transpose with perm [0, 1, 3, 5, 2, 4] (inverse of CRD DepthToSpace perm)
3. Reshape to [batch, C*r*r, H, W]

Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/44a370bd-6b05-4d23-b35a-f6a6c1df1d23

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Copilot AI changed the title [WIP] Fix nn.PixelUnshuffle wrongfully exported as SpaceToDepth Fix nn.PixelUnshuffle wrong channel ordering (replace SpaceToDepth with Reshape→Transpose→Reshape) Apr 17, 2026
Copilot AI requested a review from justinchuby April 17, 2026 15:59
@justinchuby justinchuby marked this pull request as ready for review April 17, 2026 16:28
@codecov
Copy link
Copy Markdown

codecov bot commented Apr 17, 2026

❌ 2 Tests Failed:

Tests completed Failed Passed Skipped
11010 2 11008 936
View the top 2 failed test(s) by shortest run time
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__nn_functional_pixel_unshuffle_cpu_float32
Stack Traces | 2.37s run time
.../function_libs/torch_lib/ops_test_common.py:591: in _capture_graph_and_evaluate_torch_script_evaluator
    return _safe_ort_session_run(model_proto.SerializeToString(), ort_inputs)
.../function_libs/torch_lib/ops_test_common.py:389: in _safe_ort_session_run
    raise return_dict["error"]
E   onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'node_Reshape_4' Status Message: .../cpu/tensor/reshape_helper.h:39 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) size != 0 && (input_shape_size % size) == 0 was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,1,1,0}, requested shape:{-1,1,1,0}

The above exception was the direct cause of the following exception:
.../function_libs/torch_lib/ops_test.py:206: in run_test_output_match
    function_output = function_executor(test_name, reference_torch_outputs)(
.../function_libs/torch_lib/ops_test_common.py:607: in _capture_graph_and_evaluate_torch_script_evaluator
    raise RuntimeError(
E   RuntimeError: ONNX Runtime failed to evaluate:
E   Inputs:
E   {'input_0': array([], shape=(1, 1, 1, 0), dtype=float32)}
E   Model:
E   <
E      ir_version: 10,
E      opset_import: ["" : 18, "pkg.torch.onnx" : 1, "pkg.onnxscript.torch_lib.common" : 1, "pkg.onnxscript.torch_lib" : 1],
E      producer_name: "torch_test"
E   >
E   main_graph (float[1,1,1,0] input_0) => (float[1,1,1,0] val_22) 
E      <int64[1] val_0, int64[3] val_1, int64[1] val_2, int64[4] val_3, int64[1] val_5, int64[unk__0] val_6, int64[unk__1] val_7, int64[unk__2] val_8, int64[unk__1] val_9, int64[unk__2] val_10, int64[1] val_11, int64[unk__3] val_12, int64[1] val_15, int64[unk__0] val_16, int64[1] val_17, int64[unk__4] val_18, int64[unk__5] val_20, int64[unk__6] val_21>
E   {
E      val_0 = Shape <end: int = -3, start: int = 0> (input_0)
E      val_1 = Shape <start: int = -3> (input_0)
E      val_2 = Constant <value_ints: ints = [-1]> ()
E      val_3 = Concat <axis: int = 0> (val_2, val_1)
E      val_4 = Reshape <allowzero: int = 0> (input_0, val_3)
E      val_5 = Constant <value_ints: ints = [1]> ()
E      val_6 = Shape <end: int = 2, start: int = 1> (val_4)
E      val_7 = Shape <end: int = 3, start: int = 2> (val_4)
E      val_8 = Shape <end: int = 4, start: int = 3> (val_4)
E      val_9 = Div (val_7, val_5)
E      val_10 = Div (val_8, val_5)
E      val_11 = Constant <value_ints: ints = [-1]> ()
E      val_12 = Concat <axis: int = 0> (val_11, val_6, val_9, val_5, val_10, val_5)
E      val_13 = Reshape <allowzero: int = 0> (val_4, val_12)
E      val_14 = Transpose <perm: ints = [0, 1, 3, 5, 2, 4]> (val_13)
E      val_15 = Mul (val_5, val_5)
E      val_16 = Mul (val_6, val_15)
E      val_17 = Constant <value_ints: ints = [-1]> ()
E      val_18 = Concat <axis: int = 0> (val_17, val_16, val_9, val_10)
E      val_19 = Reshape <allowzero: int = 0> (val_14, val_18)
E      val_20 = Shape <start: int = 1> (val_19)
E      val_21 = Concat <axis: int = 0> (val_0, val_20)
E      val_22 = Reshape <allowzero: int = 1> (val_19, val_21)
E   }
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__nn_functional_pixel_unshuffle_cpu_float16
Stack Traces | 2.44s run time
.../function_libs/torch_lib/ops_test_common.py:591: in _capture_graph_and_evaluate_torch_script_evaluator
    return _safe_ort_session_run(model_proto.SerializeToString(), ort_inputs)
.../function_libs/torch_lib/ops_test_common.py:389: in _safe_ort_session_run
    raise return_dict["error"]
E   onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'node_Reshape_4' Status Message: .../cpu/tensor/reshape_helper.h:39 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) size != 0 && (input_shape_size % size) == 0 was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,1,1,0}, requested shape:{-1,1,1,0}

The above exception was the direct cause of the following exception:
.../function_libs/torch_lib/ops_test.py:206: in run_test_output_match
    function_output = function_executor(test_name, reference_torch_outputs)(
.../function_libs/torch_lib/ops_test_common.py:607: in _capture_graph_and_evaluate_torch_script_evaluator
    raise RuntimeError(
E   RuntimeError: ONNX Runtime failed to evaluate:
E   Inputs:
E   {'input_0': array([], shape=(1, 1, 1, 0), dtype=float16)}
E   Model:
E   <
E      ir_version: 10,
E      opset_import: ["" : 18, "pkg.torch.onnx" : 1, "pkg.onnxscript.torch_lib.common" : 1, "pkg.onnxscript.torch_lib" : 1],
E      producer_name: "torch_test"
E   >
E   main_graph (float16[1,1,1,0] input_0) => (float16[1,1,1,0] val_22) 
E      <int64[1] val_0, int64[3] val_1, int64[1] val_2, int64[4] val_3, int64[1] val_5, int64[unk__0] val_6, int64[unk__1] val_7, int64[unk__2] val_8, int64[unk__1] val_9, int64[unk__2] val_10, int64[1] val_11, int64[unk__3] val_12, int64[1] val_15, int64[unk__0] val_16, int64[1] val_17, int64[unk__4] val_18, int64[unk__5] val_20, int64[unk__6] val_21>
E   {
E      val_0 = Shape <end: int = -3, start: int = 0> (input_0)
E      val_1 = Shape <start: int = -3> (input_0)
E      val_2 = Constant <value_ints: ints = [-1]> ()
E      val_3 = Concat <axis: int = 0> (val_2, val_1)
E      val_4 = Reshape <allowzero: int = 0> (input_0, val_3)
E      val_5 = Constant <value_ints: ints = [1]> ()
E      val_6 = Shape <end: int = 2, start: int = 1> (val_4)
E      val_7 = Shape <end: int = 3, start: int = 2> (val_4)
E      val_8 = Shape <end: int = 4, start: int = 3> (val_4)
E      val_9 = Div (val_7, val_5)
E      val_10 = Div (val_8, val_5)
E      val_11 = Constant <value_ints: ints = [-1]> ()
E      val_12 = Concat <axis: int = 0> (val_11, val_6, val_9, val_5, val_10, val_5)
E      val_13 = Reshape <allowzero: int = 0> (val_4, val_12)
E      val_14 = Transpose <perm: ints = [0, 1, 3, 5, 2, 4]> (val_13)
E      val_15 = Mul (val_5, val_5)
E      val_16 = Mul (val_6, val_15)
E      val_17 = Constant <value_ints: ints = [-1]> ()
E      val_18 = Concat <axis: int = 0> (val_17, val_16, val_9, val_10)
E      val_19 = Reshape <allowzero: int = 0> (val_14, val_18)
E      val_20 = Shape <start: int = 1> (val_19)
E      val_21 = Concat <axis: int = 0> (val_0, val_20)
E      val_22 = Reshape <allowzero: int = 1> (val_19, val_21)
E   }
View the full list of 1 ❄️ flaky test(s)
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__logsumexp_cpu_float16

Flake rate in main: 14.59% (Passed 2324 times, Failed 397 times)

Stack Traces | 0.754s run time
.../function_libs/torch_lib/ops_test.py:243: in run_test_output_match
    torch.testing.assert_close(
E   AssertionError: Tensor-likes are not close!
E   
E   Mismatched elements: 1 / 5 (20.0%)
E   Greatest absolute difference: 2.288818359375e-05 at index (1,) (up to 1e-05 allowed)
E   Greatest relative difference: 0.0022869110107421875 at index (1,) (up to 0.001 allowed)

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Development

Successfully merging this pull request may close these issues.

nn.PixelUnshuffle is wrongfully exported as SpaceToDepth

2 participants