Skip to content

Handling initializers in GraphBuilder#2889

Merged
gramalingam merged 11 commits intomainfrom
rama/root-initializers
Apr 16, 2026
Merged

Handling initializers in GraphBuilder#2889
gramalingam merged 11 commits intomainfrom
rama/root-initializers

Conversation

@gramalingam
Copy link
Copy Markdown
Collaborator

@gramalingam gramalingam commented Apr 15, 2026

Motivation: Use of some standard graph-building techniques in Mobius fail when used to create ir.Function, because ONNX Functions do not allow initializers. This PR helps address this. It also reduces some of the boiler-plate code in creating functions via an utility function.

(a) Update the logic for creating initializers to ensure that they are always added to the root (main) Graph.
(b) Add a utility to convert initializers into Constant nodes (which is necessary for a Function in ONNX).
(c) Add a utility to build a function

Note: An alternative considered was adding an option to the GraphBuilder so that we automatically construct Constant nodes instead of initializers when the options says so. While that avoids the extra pass in the end, it has some minor implications to what the graph would look like (in terms of whether we want all Constant nodes upfront in one place, and what kind of node-names (based on node numbering) we generate etc. Hence, leaving it in the current form. But it can be changed as above if desirable.

Usage: With the utility function, we should be able to create ir.Functions as below (example from Mobius).

def causal_conv_nd_with_state(*, kernel_size, channels, ndim, activation):
    def body(op, input_val, weight_val, bias_val, conv_state_val):
        # ... body (unchanged) ...
        return output, present_state

    return builder.build_function(
        body,
        [
            ir.Value(name="input"),
            ir.Value(name="weight"),
            ir.Value(name="bias"),
            ir.Value(name="conv_state"),
        ],
        domain=DOMAIN,
        name="CausalConvWithState",
        attributes=[ir.Attr("activation", ir.AttributeType.STRING, activation)],
        opset_imports={"": OPSET_VERSION},
    )

gramalingam and others added 3 commits April 15, 2026 01:12
…zers_to_constants

Two related changes to fix how Python literals work in subgraph and
function body contexts:

(a) GraphBuilder now delegates all constant initializer creation to the
    root builder. The _constant_cache lives only on the root builder,
    and _get_or_create_constant() always registers initializers in the
    root graph. This is correct because ONNX allows inner scopes
    (subgraphs) to reference outer-scope initializers. CastLike nodes
    are still created in the local builder's graph (correct scope).

(b) New utility function lift_initializers_to_constants(graph) converts
    all initializers in a graph into Constant nodes. This is needed for
    ir.Function bodies, which cannot have initializers. The function
    preserves ir.Value identity by reusing existing Value objects as
    Constant node outputs. Graph inputs that are also initializers
    (default-value pattern) are skipped.

Together these changes mean:
- Subgraph code can freely use Python literals — initializers go to
  root and are visible from inner scopes
- Function body code can use Python literals — callers apply
  lift_initializers_to_constants() before wrapping in ir.Function

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
Use f'initializer_{value.name}' as the node name for Constant nodes
created by lift_initializers_to_constants. This follows the same
likely-unique heuristic as the builder's own naming scheme
(f'{op_type}_node_{count}') and makes lifted constants easily
identifiable in the graph.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
Extend the root-graph invariant to the initializer() method itself,
not just the cached literal path in _get_or_create_constant(). Now
every call to builder.initializer() registers the value in
self._root._graph, ensuring that direct initializer creation from
sub-builders also goes to the root graph.

This maintains a clean invariant: all initializers live in the root
graph, regardless of which builder created them.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Updates GraphBuilder so constant promotion (and explicit initializer() calls) registers initializers on the root graph for subgraph visibility, and adds a post-pass utility to convert initializers to Constant nodes for ONNX ir.Function bodies.

Changes:

  • Add lift_initializers_to_constants(graph) to replace graph initializers with Constant nodes (skipping input-default initializers).
  • Move constant caching to the root builder so sibling subgraphs share promoted constants.
  • Add unit tests validating root initializer placement and the lifting behavior.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

File Description
onnxscript/_internal/builder.py Adds initializer→Constant lifting utility; routes constant creation/initializer registration through the root builder for subgraph visibility.
onnxscript/_internal/builder_test.py Adds tests covering root initializer storage, cache sharing across subgraphs, and lifting initializers into Constant nodes for function bodies.

Comment thread onnxscript/_internal/builder_test.py Outdated
Comment thread onnxscript/_internal/builder_test.py
Comment thread onnxscript/_internal/builder.py Outdated
Comment thread onnxscript/_internal/builder.py Outdated
Comment thread onnxscript/_internal/builder.py Outdated
Comment thread onnxscript/_internal/builder_test.py Outdated
@codecov
Copy link
Copy Markdown

codecov bot commented Apr 15, 2026

Codecov Report

❌ Patch coverage is 96.55172% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.48%. Comparing base (63ffecf) to head (09098be).
⚠️ Report is 1 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines Patch % Lines
onnxscript/_internal/builder.py 89.79% 5 Missing and 5 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2889      +/-   ##
==========================================
+ Coverage   72.30%   72.48%   +0.17%     
==========================================
  Files         241      241              
  Lines       29678    29915     +237     
  Branches     2916     2935      +19     
==========================================
+ Hits        21459    21684     +225     
- Misses       7227     7233       +6     
- Partials      992      998       +6     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

- Remove duplicate if __name__ == '__main__' block before RootInitializerTest
- Replace assert with ValueError in lift_initializers_to_constants
- Fix initializer() docstring to note input(const_value=) exception
- Use self.initializer() for non-cached path to preserve scope naming
- Remove unused captured_values in test

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
Comment thread onnxscript/_internal/builder_test.py Fixed
Comment thread onnxscript/_internal/builder_test.py Fixed
Comment thread onnxscript/_internal/builder.py Fixed
- Use set comprehension instead of set(generator) (RUFF C401)
- Use assertGreater instead of assertTrue(len > 0) (CodeQL)
- Reformat multiline list comprehension (ruff)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
@gramalingam gramalingam enabled auto-merge (squash) April 15, 2026 20:30
gramalingam and others added 5 commits April 15, 2026 21:23
- Add build_function() for creating ir.Function with automatic
  initializer lifting, optional input support, and strict output
  semantics (return XOR append, not both)
- Add make_input() convenience helper for creating ir.Value inputs
- Rewrite build_graph() to accept Sequence[ir.Value | None] for
  inputs and Sequence[ir.Value] for outputs (breaking change)
- Add _split_optional_inputs() with ownership validation
- Update subgraph() to match new build_graph signature
- Remove deprecated InputOutputSpec and _normalize_io_spec
- Update all test call sites in builder_test.py and _module_test.py
- Add comprehensive BuildFunctionTest and MakeInputTest classes

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
Aligns naming with ir.val() — both create ir.Value objects and are not
limited to graph inputs. Updated docstring to reference ir.val() and
clarify the TypeSpec difference.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
When _split_optional_inputs encounters a None entry, it now creates an
untyped ir.Value placeholder (named input_N) and adds it to graph_inputs.
This ensures the function signature declares all formal parameters,
including absent optional ones, while the trace function still receives
None so it can branch with 'if x is None:'.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
Remove the default of {"":23} — callers must explicitly specify
opset_imports. This avoids silent version assumptions.
subgraph() is unchanged since it inherits from the parent graph.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

Comment thread onnxscript/_internal/builder.py
Comment thread onnxscript/_internal/builder.py Outdated
Comment thread onnxscript/_internal/builder.py Outdated
- Add v.graph check in _split_optional_inputs to reject values already
  attached to a graph (not just those with a producer)
- Fix build_graph docstring: None entries create placeholder formal
  inputs, not excluded from graph inputs
- Fix build_function docstring: same correction
- Add test for input-attached-to-graph validation

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: G Ramalingam <grama@microsoft.com>
@gramalingam gramalingam merged commit c6e8ec6 into main Apr 16, 2026
30 of 33 checks passed
@gramalingam gramalingam deleted the rama/root-initializers branch April 16, 2026 16:52
justinchuby pushed a commit that referenced this pull request Apr 17, 2026
(a) Update the logic for creating initializers to ensure that they are
always added to the root (main) Graph.
(b) Add a utility to convert initializers into Constant nodes (which is
necessary for a Function in ONNX).

Note: An alternative considered was adding an option to the GraphBuilder
so that we automatically construct Constant nodes instead of
initializers when the options says so. While that avoids the extra pass
in the end, it has some minor implications to what the graph would look
like (in terms of whether we want all Constant nodes upfront in one
place, and what kind of node-names (based on node numbering) we generate
etc. Hence, leaving it in the current form. But it can be changed as
above if desirable.

---------

Signed-off-by: G Ramalingam <grama@microsoft.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Development

Successfully merging this pull request may close these issues.

4 participants