[dtensor][8/N] switch DeviceMesh to use numpy array for devices #92931
Conversation
Switching from `torch.Tensor` to a numpy array to avoid possible interactions with tracing subsystems.
aazzolini left a comment:
Thanks! One suggestion for unit testing would be to create a DeviceMesh in FakeTensorMode to reproduce the issue I hit, or to create a DeviceMesh inside a simple function and run make_fx on it.
Thanks @aazzolini! I just added a test with FakeTensorMode, and it works as expected. However, if we want to fully avoid this issue, we might want to stop supporting a tensor as input for the DeviceMesh `mesh` field, and use either a numpy array or an n-d list instead. Let me know if you want to go in that direction :)
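One possible shape for that direction, sketched with a hypothetical `normalize_mesh` helper (the name and signature are assumptions, not the actual PR change): accept list or tensor-like input but immediately convert it to a numpy array, so no `torch.Tensor` is retained inside DeviceMesh.

```python
import numpy as np

def normalize_mesh(mesh):
    """Hypothetical helper (not the real DeviceMesh API): accept an n-d
    list, a numpy array, or a tensor-like object exposing .numpy(), and
    always return plain numpy data so no torch.Tensor survives inside
    DeviceMesh."""
    if hasattr(mesh, "numpy"):  # e.g. a torch.Tensor passed by the caller
        mesh = mesh.numpy()
    return np.asarray(mesh, dtype=np.int64)

# An n-d list describing a 2x2 mesh of ranks becomes a numpy array.
mesh = normalize_mesh([[0, 1], [2, 3]])
```

Duck-typing on `.numpy()` keeps the helper importable without torch; the real implementation would presumably check `isinstance(mesh, torch.Tensor)` instead.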
This adds the numpy typing plugin to the mypy config so that we can use it for DeviceMesh typing annotations. Please see #92931 for why we need this. For example, we currently save the DeviceMesh `mesh` field as a `torch.Tensor`, so code like:

```python
with FakeTensorMode():
    device_mesh = DeviceMesh("cuda", torch.arange(4))
```

throws an error, because FakeTensorMode (like any TorchDispatchMode) tracks every tensor creation and interaction. DeviceMesh only wants to save an n-d array recording the mesh topology and should avoid interacting with subsystems like FakeTensor, so we want to support saving `mesh` as a numpy array instead.
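A minimal sketch of the numpy-backed storage described above (the class and attribute names are illustrative, not the real `torch.distributed` DeviceMesh): because the topology is held as a plain numpy array, constructing it performs no torch tensor operations, so a TorchDispatchMode such as FakeTensorMode has nothing to intercept.

```python
import numpy as np

class MeshSketch:
    """Illustrative stand-in for DeviceMesh: the device topology is stored
    as a plain numpy array, so constructing it never goes through torch's
    dispatcher and cannot be observed by a TorchDispatchMode."""

    def __init__(self, device_type, mesh):
        self.device_type = device_type
        self.mesh = np.asarray(mesh)  # plain data, not a torch.Tensor

    @property
    def ndim(self):
        return self.mesh.ndim

# A 1-d mesh over 4 devices, analogous to DeviceMesh("cuda", torch.arange(4)).
m = MeshSketch("cuda", np.arange(4))
```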
Would allowing passing a Tensor as an argument, but converting it to a numpy array or n-d list in …
Pull Request resolved: #92930 Approved by: https://github.com/ezyang, https://github.com/malfet
@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This removes the typing hack, part of #92931.
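For context, the stack earlier added the numpy typing plugin to the mypy config. As a hedged sketch of that general mechanism (not necessarily the exact lines in PyTorch's mypy.ini, nor the hack being removed here), enabling the plugin looks like:

```ini
[mypy]
plugins = numpy.typing.mypy_plugin
```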
Pull Request resolved: #94526 Approved by: https://github.com/XilunWu
Stack from ghstack (oldest at bottom):
Switching from torch.Tensor to numpy array to avoid possible interactions with tracing subsystems
Differential Revision: D42876247