-
Notifications
You must be signed in to change notification settings - Fork 26.3k
serialization for torch.device #7713
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
torch/csrc/Device.cpp
Outdated
|
|
||
| PyObject *THPDevice_reduce(THPDevice *self) | ||
| { | ||
| PyObject *ret, *mod, *obj; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
test/test_torch.py
Outdated
| net2 = torch.load(f) | ||
| self.assertEqual(type(net), type(net2)) | ||
| self.assertEqual(net.state_dict(), net2.state_dict()) | ||
| self.assertEqual(net.device, net2.device) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if (!ret) throw python_error(); | ||
|
|
||
| py::object torch_module = py::module::import("torch"); | ||
| py::object torch_device = torch_module.attr("device"); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/Device.cpp
Outdated
| PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr()); | ||
|
|
||
| if (self->device.is_default) { | ||
| PyTuple_SET_ITEM(ret.get(), 1, Py_BuildValue("(s)", deviceTypeString(self->device.type))); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
…e2_core_hip * 'caffe2_core_hip' of github.com:petrex/pytorch: (40 commits) [auto] Update onnx to 52f7528 - add more shape inference tests (onnx/onnx#971) onnx/onnx@52f7528 JIT cleanup (pytorch#7631) fix to build sleef when using cmake 3.11.1 (pytorch#7679) Fix typo in document (pytorch#7725) [auto] Update onnx to 6f4b1b1 - Tests for Gemm operator (onnx/onnx#885) onnx/onnx@6f4b1b1 [auto] Update onnx to c6c6aad - Enhance the 1-element broadcast case (onnx/onnx#902) onnx/onnx@c6c6aad serialization for torch.device (pytorch#7713) Fix compile flags for MSVC (pytorch#7703) Fix exporting Sum to onnx (pytorch#7685) Renanme ZFNet to ZFNet512 (pytorch#7723) Implement __reduce__ for torch.dtype (pytorch#7699) Remove unnecessary include in vec256_float.h (pytorch#7711) Update from facebook (pytorch#7696) fix for cuda 9.2 builds (pytorch#7709) make BatchSampler subclass of Sampler, and expose (pytorch#7707) Dont emit warning for ABI incompatibility when PyTorch was built from source (pytorch#7681) remove index from python bindings (fixes: pytorch#7639) (pytorch#7690) Update _torch_docs.py (pytorch#7700) Fix the wrong usage of environment variables detection in cmake Changes from D7881937 and D7963936 plus an edit (pytorch#7605) ...
This PR addresses #7545 .
Implement reduce for
torch.deviceto return a newtorch.deviceobject with the same device type & index.Note this
torch.deviceis slightly different in the sense thatcopy.deepcopy()returns a new object instead of the same one intorch.dtype.Also added a test in test_torch to show how module object can be serialized. Nit: it doesn't work if NetwithDevice class is put inside test_save_net_with_device() function. This is due to
netbeing an instance of<class '__main__.TestTorch.test_save_net_with_device.<locals>.NetwithDevice'>andTestTorchis not serializable(no getstate implemented). I can't think of a better way to change the class scope than putting it outside. Let me know if you have a better idea.