add tensor subclass printing support in fx/graph.py #164403
bobrenjc93 wants to merge 6 commits into gh/bobrenjc93/623/base
Conversation
It was previously quite misleading: the graph printout made it look like the inputs to the dynamo graph were plain tensors when in reality they are DTensors.
Before:
```
class GraphModule(torch.nn.Module):
def forward(self, L_input_batch_inputs_: "i64[2, 512][512, 1]cuda:0", L_self_parameters_weight_: "f32[202048, 256][256, 1]cuda:0"):
```
After:
```
class GraphModule(torch.nn.Module):
def forward(self, L_input_batch_inputs_: "DTensor(dtype=i64, shape=[2, 512], device=cuda:0, mesh=DeviceMesh((dim1=8, dim2=2, dim3=2), device: 'cuda', stride: (4, 2, 1)), placements=(Replicate(), Replicate(), Replicate()))", L_self_parameters_weight_: "DTensor(dtype=f32, shape=[202048, 256], device=cuda:0, mesh=DeviceMesh((dim1=8, dim2=2, dim3=2), device: 'cuda', stride: (4, 2, 1)), placements=(Replicate(), Replicate(), Replicate()))"):
```
[ghstack-poisoned]
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/164403. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit f6fc5cb with merge base dfda239. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
It was previously quite misleading: the graph printout made it look like the inputs to the dynamo graph were plain tensors when in reality they are tensor subclasses.
Before:
```
class GraphModule(torch.nn.Module):
def forward(self, L_input_batch_inputs_: "i64[2, 512][512, 1]cuda:0", L_self_parameters_weight_: "f32[202048, 256][256, 1]cuda:0"):
```
After:
```
class GraphModule(torch.nn.Module):
def forward(self, L_input_batch_inputs_: "DTensor(i64[2, 512][512, 1]cuda:0)", L_self_parameters_weight_: "DTensor(f32[202048, 256][256, 1]cuda:0)"):
```
cc ezyang EikanWang jgong5 wenzhe-nrv
[ghstack-poisoned]
ghstack-source-id: 4054d27
Pull Request resolved: #164403
Updated the implementation to be more generic: it now checks whether a value is a tensor subclass and, if so, wraps the existing tensor descriptor in the subclass name. Maybe in the future I'll extend this to support custom printers, but this solves the high-level problem for now.
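The wrapping described above can be sketched as follows. This is a minimal, torch-free sketch of the idea only; `annotate`, `Tensor`, and `DTensor` here are hypothetical stand-ins, not the actual fx/graph.py code:

```python
# Stand-ins for torch.Tensor and torch.distributed.tensor.DTensor,
# used only to illustrate the subclass check.
class Tensor:
    pass

class DTensor(Tensor):
    pass

def annotate(t: Tensor, base_descriptor: str) -> str:
    """The printer already produces a plain descriptor string such as
    "i64[2, 512][512, 1]cuda:0". When `t` is a tensor subclass, wrap that
    descriptor in the subclass name; plain tensors are left unchanged."""
    if type(t) is not Tensor:
        return f"{type(t).__name__}({base_descriptor})"
    return base_descriptor

print(annotate(Tensor(), "i64[2, 512][512, 1]cuda:0"))
# i64[2, 512][512, 1]cuda:0
print(annotate(DTensor(), "i64[2, 512][512, 1]cuda:0"))
# DTensor(i64[2, 512][512, 1]cuda:0)
```

Because only the subclass *name* is used, this stays generic across any tensor subclass without needing per-subclass printing logic.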
ghstack-source-id: 0810646
Pull Request resolved: #164403
ghstack-source-id: 02236d5
Pull Request resolved: #164403
ghstack-source-id: 0d1cfaf
Pull Request resolved: #164403
I'd like this to land, but it needs more work.
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela
ghstack-source-id: a6fcc2d
Pull Request resolved: #164403
Refactored the PR to be more DRY.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#164403
Approved by: https://github.com/ezyang