@wanchaol wanchaol commented Jan 24, 2023

Switch the mesh that DeviceMesh stores from a torch.Tensor to a numpy array
to avoid possible interactions with tracing subsystems.

pytorch-bot bot commented Jan 24, 2023

@aazzolini aazzolini left a comment

Thanks! One suggestion for unit testing would be to create a DeviceMesh in FakeMode to reproduce the issue that I had! Or maybe create a DeviceMesh inside of a simple function and make_fx on it?

@wanchaol wanchaol changed the title from "[dtensor][7/N] switch DeviceMesh to use numpy array for devices" to "[dtensor][8/N] switch DeviceMesh to use numpy array for devices" Jan 26, 2023
@wanchaol

> Thanks! One suggestion for unit testing would be to create a DeviceMesh in FakeMode to reproduce the issue that I had! Or maybe create a DeviceMesh inside of a simple function and make_fx on it?

Thanks @aazzolini! I just added a test with FakeTensorMode, and it works as expected. However, if we want to fully avoid this issue, we might want to stop supporting a tensor as input for DeviceMesh's mesh field and accept only a numpy array or an n-d list instead. Let me know if you want to go in that direction :)

wanchaol added a commit that referenced this pull request Jan 26, 2023
This adds the numpy typing plugin to the mypy config so that we can
use it for DeviceMesh typing annotations.

Please see #92931 for why we need this. For example, we currently save DeviceMesh's mesh field as a torch.Tensor, so when we do something like:
```python
with FakeTensorMode():
    device_mesh = DeviceMesh("cuda", torch.arange(4))
```
It throws an error, because FakeTensorMode (or any TorchDispatchMode) tracks every tensor creation and interaction. DeviceMesh only wants to save an n-d array that records the mesh topology, and it should avoid interacting with subsystems like FakeTensor, so we want to support saving `mesh` as a numpy array instead.


XilunWu commented Jan 26, 2023

> > Thanks! One suggestion for unit testing would be to create a DeviceMesh in FakeMode to reproduce the issue that I had! Or maybe create a DeviceMesh inside of a simple function and make_fx on it?
>
> Thanks @aazzolini! Just added a test with the FakeTensorMode, it works as expected. However if we want to fully avoid this issue, we might want to stop supporting DeviceMesh taking tensor as input for the mesh field, and use either numpy array or n-d list instead. Let me know if you want to go with that direction :)

Would accepting a Tensor as the argument, but converting it to a numpy array or an n-d list in `DeviceMesh.__init__`, avoid the tracing issue?
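
For illustration only, the "convert on construction" idea raised here could look roughly like the sketch below. `MeshSketch` and `normalize_mesh` are hypothetical names, not the real DeviceMesh code. The sketch relies only on the fact that both `torch.Tensor` and `numpy.ndarray` expose a `.tolist()` method. It does not settle whether this avoids the tracing issue when the tensor itself is created under a FakeTensorMode, since such a tensor may not carry real data to convert.

```python
from typing import Any, List, Union

# Nested list of ints describing the mesh topology.
NestedInts = Union[int, List["NestedInts"]]


def normalize_mesh(mesh: Any) -> NestedInts:
    """Return `mesh` as a plain nested list of ints (hypothetical helper)."""
    # torch.Tensor and numpy.ndarray both expose .tolist(); use it when
    # present so the stored value is no longer a tensor or array object.
    if hasattr(mesh, "tolist"):
        mesh = mesh.tolist()
    if isinstance(mesh, (list, tuple)):
        return [normalize_mesh(item) for item in mesh]
    return int(mesh)


class MeshSketch:
    """Hypothetical stand-in for DeviceMesh; not the real class."""

    def __init__(self, device_type: str, mesh: Any) -> None:
        self.device_type = device_type
        # Convert once at construction so later code only sees Python ints.
        self.mesh = normalize_mesh(mesh)


m = MeshSketch("cuda", [[0, 1], (2, 3)])
print(m.mesh)  # [[0, 1], [2, 3]]
```

Storing plain Python data keeps the mesh topology out of reach of any dispatch mode, at the cost of losing array conveniences such as reshaping, which is one reason a numpy array may still be the more practical storage type.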

wanchaol added a commit that referenced this pull request Jan 30, 2023
pytorchmergebot pushed a commit that referenced this pull request Jan 31, 2023

Pull Request resolved: #92930
Approved by: https://github.com/ezyang, https://github.com/malfet

@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Differential Revision: [D42876247](https://our.internmc.facebook.com/intern/diff/D42876247)

wanchaol added a commit that referenced this pull request Feb 9, 2023
This removes the typing hack, part of #92931

ghstack-source-id: fdb52c7
Pull Request resolved: #94526
wanchaol added a commit that referenced this pull request Mar 28, 2023
This removes the typing hack, part of #92931

ghstack-source-id: e5890ae
Pull Request resolved: #94526
@wanchaol wanchaol closed this Mar 28, 2023
pytorchmergebot pushed a commit that referenced this pull request Mar 29, 2023
This removes the typing hack, part of #92931
Pull Request resolved: #94526
Approved by: https://github.com/XilunWu
@facebook-github-bot facebook-github-bot deleted the gh/wanchaol/251/head branch June 8, 2023 19:08