Skip to content

Commit b4ae80d

Browse files
ailzhangapaszke
authored andcommitted
serialization for torch.device (#7713)
1 parent ee6e3fe commit b4ae80d

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

test/test_torch.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6257,6 +6257,13 @@ def test_half_tensor(self):
62576257
xh2 = torch.load(f)
62586258
self.assertEqual(xh.float(), xh2.float())
62596259

6260+
def test_serialize_device(self):
6261+
device_str = ['cpu', 'cpu:0', 'cuda', 'cuda:0']
6262+
device_obj = [torch.device(d) for d in device_str]
6263+
for device in device_obj:
6264+
device_copied = copy.deepcopy(device)
6265+
self.assertEqual(device, device_copied)
6266+
62606267
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
62616268
def test_half_tensor_cuda(self):
62626269
x = torch.randn(5, 5).half()

torch/csrc/Device.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ PyObject *THPDevice_repr(THPDevice *self)
4343
return THPUtils_packString(oss.str().c_str());
4444
}
4545

46-
PyObject *THPDevice_str(THPDevice*self)
46+
PyObject *THPDevice_str(THPDevice *self)
4747
{
4848
std::ostringstream oss;
4949
if (!self->device.is_default) {
@@ -137,6 +137,29 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
137137
END_HANDLE_TH_ERRORS
138138
}
139139

140+
PyObject *THPDevice_reduce(THPDevice *self)
141+
{
142+
HANDLE_TH_ERRORS
143+
auto ret = THPObjectPtr{PyTuple_New(2)};
144+
if (!ret) throw python_error();
145+
146+
py::object torch_module = py::module::import("torch");
147+
py::object torch_device = torch_module.attr("device");
148+
PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr());
149+
150+
THPObjectPtr args;
151+
if (self->device.is_default) {
152+
args = THPObjectPtr{Py_BuildValue("(s)", deviceTypeString(self->device.type))};
153+
} else {
154+
args = THPObjectPtr{Py_BuildValue("(si)", deviceTypeString(self->device.type), self->device.index)};
155+
}
156+
if (!args) throw python_error();
157+
PyTuple_SET_ITEM(ret.get(), 1, args.release());
158+
159+
return ret.release();
160+
END_HANDLE_TH_ERRORS
161+
}
162+
140163
typedef PyObject *(*getter)(PyObject *, void *);
141164

142165
static struct PyGetSetDef THPDevice_properties[] = {
@@ -145,6 +168,11 @@ static struct PyGetSetDef THPDevice_properties[] = {
145168
{nullptr}
146169
};
147170

171+
static PyMethodDef THPDevice_methods[] = {
172+
{"__reduce__", (PyCFunction)THPDevice_reduce, METH_NOARGS, nullptr},
173+
{NULL} /* Sentinel */
174+
};
175+
148176
PyTypeObject THPDeviceType = {
149177
PyVarObject_HEAD_INIT(nullptr, 0)
150178
"torch.device", /* tp_name */
@@ -173,7 +201,7 @@ PyTypeObject THPDeviceType = {
173201
0, /* tp_weaklistoffset */
174202
0, /* tp_iter */
175203
0, /* tp_iternext */
176-
0, /* tp_methods */
204+
THPDevice_methods, /* tp_methods */
177205
0, /* tp_members */
178206
THPDevice_properties, /* tp_getset */
179207
0, /* tp_base */

0 commit comments

Comments
 (0)