Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6251,6 +6251,13 @@ def test_half_tensor(self):
xh2 = torch.load(f)
self.assertEqual(xh.float(), xh2.float())

def test_serialize_device(self):
device_str = ['cpu', 'cpu:0', 'cuda', 'cuda:0']
device_obj = [torch.device(d) for d in device_str]
for device in device_obj:
device_copied = copy.deepcopy(device)
self.assertEqual(device, device_copied)

@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
def test_half_tensor_cuda(self):
x = torch.randn(5, 5).half()
Expand Down
32 changes: 30 additions & 2 deletions torch/csrc/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ PyObject *THPDevice_repr(THPDevice *self)
return THPUtils_packString(oss.str().c_str());
}

PyObject *THPDevice_str(THPDevice*self)
PyObject *THPDevice_str(THPDevice *self)
{
std::ostringstream oss;
if (!self->device.is_default) {
Expand Down Expand Up @@ -137,6 +137,29 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
END_HANDLE_TH_ERRORS
}

PyObject *THPDevice_reduce(THPDevice *self)
{
HANDLE_TH_ERRORS
auto ret = THPObjectPtr{PyTuple_New(2)};
if (!ret) throw python_error();

py::object torch_module = py::module::import("torch");
py::object torch_device = torch_module.attr("device");

This comment was marked as off-topic.

PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr());

THPObjectPtr args;
if (self->device.is_default) {
args = THPObjectPtr{Py_BuildValue("(s)", deviceTypeString(self->device.type))};
} else {
args = THPObjectPtr{Py_BuildValue("(si)", deviceTypeString(self->device.type), self->device.index)};
}
if (!args) throw python_error();
PyTuple_SET_ITEM(ret.get(), 1, args.release());

return ret.release();
END_HANDLE_TH_ERRORS
}

typedef PyObject *(*getter)(PyObject *, void *);

static struct PyGetSetDef THPDevice_properties[] = {
Expand All @@ -145,6 +168,11 @@ static struct PyGetSetDef THPDevice_properties[] = {
{nullptr}
};

static PyMethodDef THPDevice_methods[] = {
{"__reduce__", (PyCFunction)THPDevice_reduce, METH_NOARGS, nullptr},
{NULL} /* Sentinel */
};

PyTypeObject THPDeviceType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch.device", /* tp_name */
Expand Down Expand Up @@ -173,7 +201,7 @@ PyTypeObject THPDeviceType = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
THPDevice_methods, /* tp_methods */
0, /* tp_members */
THPDevice_properties, /* tp_getset */
0, /* tp_base */
Expand Down