Skip to content

Fix missing rank/type propagation after torch.cond (ONNX If node)#626

Draft
Copilot wants to merge 2 commits into
mainfrom
copilot/fix-missing-rank-after-node
Draft

Fix missing rank/type propagation after torch.cond (ONNX If node)#626
Copilot wants to merge 2 commits into
mainfrom
copilot/fix-missing-rank-after-node

Conversation

Copilot AI commented Feb 25, 2026

Copy link
Copy Markdown
Contributor

After aten_cond creates an ONNX If node, the outputs were left without rank, type, or shape metadata in the graph builder — meaning _known_ranks, _known_types, and _known_shapes had no entries for those outputs. Any downstream optimization or graph-building pass that queries has_rank() / get_rank() on those results would get empty or wrong answers.

Changes

  • experimental_experiment/torch_interpreter/_aten_functions.py — After g.make_node("If", ...) in aten_cond, propagate rank/type/shape from the true_graph local function builder to each If output. The builder already has complete metadata since both branches were fully interpreted before this call. Mirrors the same pattern used for submodule functions in interpreter.py.
if g.has_local_function(true_graph, g.local_domain, builder=True):
    fct_builder = g.get_local_function(true_graph, g.local_domain, builder=True)
    for bout, out in zip(fct_builder.output_names, outputs):
        if fct_builder.has_type(bout):
            g.set_type(out, fct_builder.get_type(bout))
        if fct_builder.has_device(bout):
            g.set_device(out, fct_builder.get_device(bout))
        if fct_builder.has_shape(bout):
            g.set_shape(out, fct_builder.get_shape(bout))
        elif fct_builder.has_rank(bout):
            g.set_rank(out, fct_builder.get_rank(bout))

Works correctly for single-output, multi-output, and dynamic-shape cases (builders store dynamic dimension strings such as 's77').

  • _unittests/ut_torch_interpreter/test_onnx_export_control_flow.py — Added three targeted tests covering single-output cond, multi-output cond, and cond with dynamic shapes.
Original prompt

This section details on the original issue you should resolve

<issue_title>fix missing rank after a node if (torch.cond)</issue_title>
<issue_description></issue_description>

Comments on the Issue (you are @copilot in this section)


💬 We'd love your input! Share your thoughts on Copilot coding agent in our 2 minute survey.

Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
Copilot AI changed the title [WIP] Fix missing rank after a node in torch.cond Fix missing rank/type propagation after torch.cond (ONNX If node) Feb 25, 2026
Copilot AI requested a review from xadupre February 25, 2026 17:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

fix missing rank after a node if (torch.cond)

2 participants