Fix model.to(xla_device) #21048
Conversation
Force-pushed from f75249a to 6cc1468
…hem in nn.Module _apply()
Force-pushed from 81ffc4b to 9f1ac9c
torch/csrc/autograd/variable.cpp
Outdated
can't we just make this a native function? Then you wouldn't need to make your own parsing either.
torch/csrc/autograd/variable.h
Outdated
this comment should refer to is_same_impl_type.
test/test_nn.py
Outdated
shouldn't this be cuda()?
test/test_nn.py
Outdated
please comment what this is attempting to test.
torch/nn/modules/module.py
Outdated
I don't think we should warn yet: there is no alternative users can switch to if they really want to hold on to a reference to another tensor. We need new APIs that we can direct people to first.
Also, is it possible to name the type of the module? It might not be obvious, because you can move an entire model and a parameter in some submodule changes.
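A hedged sketch of what such a warning could look like; the helper name and the message wording are hypothetical, not the PR's actual code:

```python
import warnings
import torch.nn as nn

def warn_param_rebound(module: nn.Module, param_name: str) -> None:
    # Hypothetical helper: include the owning submodule's class name so the
    # warning stays useful even when .to() was called on the whole model and
    # the rebound parameter lives deep inside a submodule.
    warnings.warn(
        f"{module.__class__.__name__}.{param_name} was rebound to a new "
        f"tensor by Module.to(); existing references to the old parameter "
        f"no longer alias the model's parameter.")
```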
torch/nn/modules/module.py
Outdated
can we please break up this PR? This is trying to do a bunch of different things:
- Fix moving to XLA
- Do proper version tracking of module parameters (sometimes)
- Warn about future breaking changes (without first introducing correct APIs).
torch/nn/modules/module.py
Outdated
why doesn't this follow the other pattern of using no_grad and setting requires_grad at the end?
```python
with torch.no_grad():
    # We use `.requires_grad_()` here to make sure the new `param` still
    # has the same `requires_grad` value as the old `param`.
    self._parameters[key] = param_applied.requires_grad_(param.requires_grad)
```
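For comparison, a sketch of the pattern the reviewer seems to mean, reusing the names from the quoted diff (illustrative only, not the PR's final code): do the assignment under no_grad, then restore requires_grad as a separate final step.

```python
with torch.no_grad():
    self._parameters[key] = param_applied
self._parameters[key].requires_grad_(param.requires_grad)
```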
I don't think this is correct; won't the parameter no longer be an nn.Parameter after this?
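A small self-contained illustration of the concern; param_applied here stands in for the result of fn(param), which is a plain tensor rather than an nn.Parameter:

```python
import torch
import torch.nn as nn

param = nn.Parameter(torch.randn(3))
param_applied = param.detach().clone()  # stand-in for fn(param)

# requires_grad_() returns the same object, which is still a plain Tensor:
assert not isinstance(
    param_applied.requires_grad_(param.requires_grad), nn.Parameter)

# Rewrapping preserves the Parameter type (an assumed fix, for illustration):
rewrapped = nn.Parameter(param_applied, requires_grad=param.requires_grad)
assert isinstance(rewrapped, nn.Parameter)
```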
#17072 breaks model.to(xla_device), because moving model to an XLA device involves changing its parameters' TensorImpl type, and the current implementation of nn.Module.to() doesn't support changing module parameters' TensorImpl type.

A hypothetical way to fix this is to do the following:
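The code the description alludes to is not quoted in this excerpt. Below is a minimal sketch of the hypothetical fix, assuming the rebinding approach described in the next paragraph; the function name apply_rebinding is illustrative, not the PR's actual diff:

```python
import torch
import torch.nn as nn

def apply_rebinding(module: nn.Module, fn):
    # Hypothetical variant of nn.Module._apply(): rebind each parameter to a
    # newly converted tensor instead of doing param.data = fn(param.data),
    # so the TensorImpl type is allowed to change (e.g. CPU -> XLA).
    for child in module.children():
        apply_rebinding(child, fn)
    for key, param in module._parameters.items():
        if param is None:
            continue
        with torch.no_grad():
            param_applied = fn(param)
        # Assigning a brand-new Parameter breaks any outside references to
        # the old one -- the API-consistency concern discussed below.
        module._parameters[key] = nn.Parameter(
            param_applied, requires_grad=param.requires_grad)
    for key, buf in module._buffers.items():
        if buf is not None:
            module._buffers[key] = fn(buf)
    return module

# Usage: apply_rebinding(model, lambda t: t.to(xla_device))
```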
However, the biggest problem with this approach is that it makes the model.to(device) API less predictable: if we are moving model from CPU to CUDA, all previous references to param are preserved, because we use param.data = fn(param.data); but if we are moving model from CPU to an XLA device, all previous references to param are broken, because we assign new tensors to param. To keep the model.to(device) API consistent, we will change the CPU-to-CUDA model-moving code to also break previous references to the model's parameters. This will happen in two stages: first as a deprecation notice when we detect previous references to those parameters, then as a hard error (in future releases).

cc. @ailzhang
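To make the reference-preservation difference concrete, a small illustration (the CUDA step needs a GPU build; the XLA case is described in comments only, since it requires torch_xla):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
w = model.weight                 # external reference to a parameter

# CPU -> CUDA today: .to() does param.data = fn(param.data), mutating the
# existing Parameter in place, so w still aliases model.weight.
if torch.cuda.is_available():
    model.to("cuda")
    assert w is model.weight     # reference preserved

# CPU -> XLA with the rebinding approach: a new tensor object is assigned
# into model._parameters, so w would no longer alias model.weight.
```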
TODO:
Add explanations of why the following cases are no longer supported: