From ce0131ebafa71e4c4b8db5feca2babc67df660eb Mon Sep 17 00:00:00 2001 From: Linus Jungemann Date: Thu, 18 Jun 2026 12:37:30 +0200 Subject: [PATCH] Fix version converter --- .../version_converter/_version_converter.py | 18 +++++++++--------- .../_version_converter_test.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 99e30417d4..7b610834e7 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -86,10 +86,10 @@ def lookup_adapters( self, domain: str, opname: str, - original_version: int, + target_version: int, up_conversion: bool = True, ) -> AdapterFunction | None: - adapter_func = self.op_adapters.get((domain, opname, original_version, up_conversion)) + adapter_func = self.op_adapters.get((domain, opname, target_version, up_conversion)) if adapter_func is not None: return adapter_func return None @@ -97,7 +97,7 @@ def lookup_adapters( def register( self, opname: str, domain: str = "", node_version=None, up_conversion=True ) -> Callable[[AdapterFunction], AdapterFunction]: - """Register an adapter based on the domain, operator type, node version and whether to upgrade/downgrade node version""" + """Register an adapter based on the domain, operator type, target node version and whether the target version is a result of an up-conversion or down-conversion.""" def decorator(function: AdapterFunction) -> AdapterFunction: @functools.wraps(function) @@ -154,7 +154,7 @@ def _get_str_attribute(node: ir.Node, name: str, default: str | None = None) -> # Opset 19 -> 20 -@register("DFT", node_version=19, up_conversion=True) +@register("DFT", node_version=20, up_conversion=True) def dft_19_20(node: ir.Node, op): input = node.inputs[0] dft_length = node.inputs[1] if len(node.inputs) > 1 else None @@ -167,7 +167,7 @@ def dft_19_20(node: ir.Node, op): return None -@register("GridSample", node_version=19, up_conversion=True) +@register("GridSample", node_version=20, up_conversion=True) def gridsample_19_20(node: ir.Node, op): x = node.inputs[0] grid = node.inputs[1] @@ -188,7 +188,7 @@ def gridsample_19_20(node: ir.Node, op): # Opset 20 -> 21 -@register("GroupNormalization", node_version=20, up_conversion=True) +@register("GroupNormalization", node_version=21, up_conversion=True) def groupnormalization_20_21(node: ir.Node, op): x = _get_input(node, 0) scale = _get_input(node, 1) @@ -249,11 +249,11 @@ def __init__(self, target_version: int): ) def process_node( - self, node: ir.Node, from_version: int, up_conversion: bool = True + self, node: ir.Node, to_version: int, up_conversion: bool = True ) -> Replacement | None: assert node.domain == "" adapter = registry.lookup_adapters( - node.domain, node.op_type, from_version, up_conversion + node.domain, node.op_type, to_version, up_conversion ) if adapter is None: return None @@ -293,7 +293,7 @@ def visit_node( to_version = from_version + 1 else: to_version = from_version - 1 - replacement = self.process_node(node, from_version, up_conversion) + replacement = self.process_node(node, to_version, up_conversion) if replacement is None: # No change. Process attributes. for attr in node.attributes.values(): diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index 2635635557..193ae6db93 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -37,7 +37,7 @@ def test_upstream_coverage(self): domain, name, upgrade_version = ( adapter_info[0], adapter_info[1], - adapter_info[2] + 1, + adapter_info[2], ) self.assertEqual(domain, "") self.assertIn((name, upgrade_version), op_upgrades)