Skip to content

Commit 4574301

Browse files
mengluy0125pytorchmergebot
authored andcommitted
[PT2][Optimus] Skip meta update on symblic shape (#134975)
Summary: We noticed that there will be runtime error to do the dim broadcast when the meta example value has symbolic shape, thus we skip it. Test Plan: ``` buck2 run mode/opt //caffe2/benchmarks/dynamo/fb:torchbench_run_ads_dhen_5x_training -- -m ads_dhen_5x -t training ``` P1559019921 Differential Revision: D62115015 Pull Request resolved: #134975 Approved by: https://github.com/xuzhao9
1 parent 9ffcca7 commit 4574301

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

torch/_inductor/fx_passes/group_batch_fusion.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -707,15 +707,24 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
707707
torch.baddbmm,
708708
args=(unsqueeze_biases, stack_inputs, transpose_weight),
709709
)
710-
bmm.meta["example_value"] = torch.baddbmm(
711-
unsqueeze_biases.meta["example_value"],
712-
stack_inputs.meta["example_value"],
713-
transpose_weight.meta["example_value"],
714-
)
715-
bmm_meta = bmm.meta["example_value"]
710+
try:
711+
# it will have runtime error to broadcast when it has dynamic shape included
712+
# in the meta data, so we need to skip the update meta data
713+
bmm.meta["example_value"] = torch.baddbmm(
714+
unsqueeze_biases.meta["example_value"],
715+
stack_inputs.meta["example_value"],
716+
transpose_weight.meta["example_value"],
717+
)
718+
bmm_meta = bmm.meta["example_value"]
719+
except Exception as e:
720+
log.debug(
721+
f" exception when update bmm meta data with stack error tracekey {e}" # noqa: G004
722+
)
723+
bmm_meta = None
716724

717725
bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0})
718-
bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0)
726+
if bmm_meta is not None:
727+
bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0)
719728
for i, linear in enumerate(batch_nodes):
720729
with graph.inserting_after(bmm):
721730
getitem = graph.call_function(operator.getitem, args=(bmm, i))

0 commit comments

Comments
 (0)