
add tensor subclass printing support in fx/graph.py #164403

Closed
bobrenjc93 wants to merge 6 commits into gh/bobrenjc93/623/base from gh/bobrenjc93/623/head

Conversation

@bobrenjc93 (Contributor) commented Oct 1, 2025

Stack from ghstack (oldest at bottom):

it was previously quite misleading: the printed signatures make it look like the inputs to the dynamo graph are plain tensors when in reality they are tensor subclasses

before
```
class GraphModule(torch.nn.Module):
    def forward(self, L_input_batch_inputs_: "i64[2, 512][512, 1]cuda:0", L_self_parameters_weight_: "f32[202048, 256][256, 1]cuda:0"):
```

after
```
class GraphModule(torch.nn.Module):
    def forward(self, L_input_batch_inputs_: "DTensor(i64[2, 512][512, 1]cuda:0)", L_self_parameters_weight_: "DTensor(f32[202048, 256][256, 1]cuda:0)"):
```

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela
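
To see these annotated graph signatures locally, the generated graph source can be dumped through torch's logging artifacts; below is a minimal sketch (the toy function and input shapes are illustrative, not taken from this PR):
```python
import torch

def fn(x):
    return x * 2 + 1

# Enable the "graph_code" artifact, either via the environment
# (TORCH_LOGS="graph_code") or programmatically as below; Dynamo then
# prints the captured GraphModule source, including the
# dtype/shape/stride/device input annotations shown above. With this
# change, tensor-subclass inputs are additionally wrapped in their
# class name.
torch._logging.set_logs(graph_code=True)

compiled = torch.compile(fn)
compiled(torch.randn(4, 8))
```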

The first commit's message, written before the scope was broadened from DTensor-specific printing to all tensor subclasses (see the title change below), described a more detailed annotation format:

it was previously quite misleading since it looked like the inputs to the dynamo graph were plain tensors when in reality they are dtensors

before
```
class GraphModule(torch.nn.Module):
    def forward(self, L_input_batch_inputs_: "i64[2, 512][512, 1]cuda:0", L_self_parameters_weight_: "f32[202048, 256][256, 1]cuda:0"):
```

after
```
class GraphModule(torch.nn.Module):
    def forward(self, L_input_batch_inputs_: "DTensor(dtype=i64, shape=[2, 512], device=cuda:0, mesh=DeviceMesh((dim1=8, dim2=2, dim3=2), device: 'cuda', stride: (4, 2, 1)), placements=(Replicate(), Replicate(), Replicate()))", L_self_parameters_weight_: "DTensor(dtype=f32, shape=[202048, 256], device=cuda:0, mesh=DeviceMesh((dim1=8, dim2=2, dim3=2), device: 'cuda', stride: (4, 2, 1)), placements=(Replicate(), Replicate(), Replicate()))"):
```

[ghstack-poisoned]
@pytorch-bot bot commented Oct 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164403

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f6fc5cb with merge base dfda239:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the release notes: fx label Oct 1, 2025
bobrenjc93 requested review from aorenste and bdhirsh October 1, 2025 20:21
bobrenjc93 added the ciflow/trunk label Oct 1, 2025
bobrenjc93 requested a review from ezyang October 1, 2025 20:22
bobrenjc93 marked this pull request as ready for review October 1, 2025 20:22
bobrenjc93 closed this Oct 1, 2025
bobrenjc93 changed the title from "add dtensor printing support in fx/graph.py" to "add tensor subclass printing support in fx/graph.py" Oct 1, 2025
bobrenjc93 added a commit that referenced this pull request Oct 1, 2025
@bobrenjc93 (Contributor, Author)
Updated the implementation to be more generic: check whether the input is a tensor subclass and, if so, wrap the existing tensor descriptor in the subclass name. Maybe in the future I'll extend this to support custom printers, but this solves the high-level problem for now.
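
A minimal sketch of that generic check, to make the approach concrete (the helper name and the abbreviation table below are hypothetical stand-ins for the actual annotation code in fx/graph.py):
```python
import torch

# Hypothetical abbreviation table; torch.fx keeps a fuller mapping internally.
_DTYPE_ABBRS = {torch.float32: "f32", torch.int64: "i64"}

def _annotate(t: torch.Tensor) -> str:
    # Plain descriptor: dtype abbreviation, shape, stride, device,
    # e.g. 'f32[2, 512][512, 1]cuda:0'.
    base = (
        f"{_DTYPE_ABBRS.get(t.dtype, str(t.dtype))}"
        f"{list(t.shape)}{list(t.stride())}{t.device}"
    )
    # Generic subclass handling: anything that is not exactly torch.Tensor
    # gets wrapped in its class name, e.g. 'DTensor(f32[2, 512][512, 1]cuda:0)'.
    if type(t) is not torch.Tensor:
        return f"{type(t).__name__}({base})"
    return base
```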

bobrenjc93 reopened this Oct 1, 2025
@ezyang (Contributor) commented Oct 2, 2025

I'd like this to land but it needs more work

bobrenjc93 added a commit that referenced this pull request Oct 2, 2025
@bobrenjc93 (Contributor, Author)

Refactored the PR to be more DRY.

@ezyang (Contributor) commented Oct 2, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Pull Request resolved: pytorch#164403
Approved by: https://github.com/ezyang
github-actions bot deleted the gh/bobrenjc93/623/head branch November 2, 2025 02:21

Labels

ciflow/inductor, ciflow/trunk, fx, Merged, module: dynamo, release notes: fx

5 participants