@@ -43,7 +43,7 @@ PyObject *THPDevice_repr(THPDevice *self)
4343 return THPUtils_packString (oss.str ().c_str ());
4444}
4545
46- PyObject *THPDevice_str (THPDevice*self)
46+ PyObject *THPDevice_str (THPDevice *self)
4747{
4848 std::ostringstream oss;
4949 if (!self->device .is_default ) {
@@ -137,6 +137,29 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
137137 END_HANDLE_TH_ERRORS
138138}
139139
140+ PyObject *THPDevice_reduce (THPDevice *self)
141+ {
142+ HANDLE_TH_ERRORS
143+ auto ret = THPObjectPtr{PyTuple_New (2 )};
144+ if (!ret) throw python_error ();
145+
146+ py::object torch_module = py::module::import (" torch" );
147+ py::object torch_device = torch_module.attr (" device" );
148+ PyTuple_SET_ITEM (ret.get (), 0 , torch_device.release ().ptr ());
149+
150+ THPObjectPtr args;
151+ if (self->device .is_default ) {
152+ args = THPObjectPtr{Py_BuildValue (" (s)" , deviceTypeString (self->device .type ))};
153+ } else {
154+ args = THPObjectPtr{Py_BuildValue (" (si)" , deviceTypeString (self->device .type ), self->device .index )};
155+ }
156+ if (!args) throw python_error ();
157+ PyTuple_SET_ITEM (ret.get (), 1 , args.release ());
158+
159+ return ret.release ();
160+ END_HANDLE_TH_ERRORS
161+ }
162+
140163typedef PyObject *(*getter)(PyObject *, void *);
141164
142165static struct PyGetSetDef THPDevice_properties[] = {
@@ -145,6 +168,11 @@ static struct PyGetSetDef THPDevice_properties[] = {
145168 {nullptr }
146169};
147170
171+ static PyMethodDef THPDevice_methods[] = {
172+ {" __reduce__" , (PyCFunction)THPDevice_reduce, METH_NOARGS, nullptr },
173+ {NULL } /* Sentinel */
174+ };
175+
148176PyTypeObject THPDeviceType = {
149177 PyVarObject_HEAD_INIT (nullptr , 0 )
150178 " torch.device" , /* tp_name */
@@ -173,7 +201,7 @@ PyTypeObject THPDeviceType = {
173201 0 , /* tp_weaklistoffset */
174202 0 , /* tp_iter */
175203 0 , /* tp_iternext */
176- 0 , /* tp_methods */
204+ THPDevice_methods, /* tp_methods */
177205 0 , /* tp_members */
178206 THPDevice_properties, /* tp_getset */
179207 0 , /* tp_base */
0 commit comments