Skip to content

Commit 1c3580b

Browse files
vishwakftwfacebook-github-bot
authored andcommitted
Added hash for device (#9246)
Summary: If this is good, I could write some tests to ensure collision doesn't occur within a given range. Closes #7228 Pull Request resolved: #9246 Differential Revision: D8872608 Pulled By: ezyang fbshipit-source-id: 0ed29a73188f4167b42756f59a5c9a3d5cb37326
1 parent 5c695e3 commit 1c3580b

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

aten/src/ATen/Device.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <cstddef>
77
#include <iosfwd>
88
#include <string>
9+
#include <functional>
910

1011
namespace at {
1112
/// Represents a a compute device on which a tensor is located. A device is
@@ -112,3 +113,16 @@ struct Device {
112113

113114
std::ostream& operator<<(std::ostream& stream, at::Device::Type type);
114115
std::ostream& operator<<(std::ostream& stream, const at::Device& device);
116+
117+
namespace std {
118+
template<> struct hash<at::Device>
119+
{
120+
size_t operator()(const at::Device& device) const noexcept {
121+
size_t hash_val = static_cast<size_t>(device.index() + 1);
122+
if (device.is_cuda()) {
123+
hash_val += 2;
124+
}
125+
return hash_val;
126+
}
127+
};
128+
} // namespace std

test/test_torch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1909,6 +1909,12 @@ def test_device(self):
19091909
self.assertRaises(TypeError, lambda: torch.device('other'))
19101910
self.assertRaises(TypeError, lambda: torch.device('other:0'))
19111911

1912+
device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
1913+
device_hash_set = set()
1914+
for device in list(device_set):
1915+
device_hash_set.add(hash(torch.device(device)))
1916+
self.assertEqual(len(device_set), len(device_hash_set))
1917+
19121918
def test_tensor_device(self):
19131919
def assertEqual(device_str, fn):
19141920
self.assertEqual(torch.device(device_str), fn().device)

torch/csrc/Device.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
#include "torch/csrc/utils/object_ptr.h"
55
#include "torch/csrc/utils/python_arg_parser.h"
66
#include "torch/csrc/utils/python_strings.h"
7+
#include "torch/csrc/utils/python_numbers.h"
78
#include "torch/csrc/utils/pybind.h"
89

910
#include <ATen/Device.h>
1011
#include <ATen/Error.h>
1112

1213
#include <cstring>
14+
#include <limits>
1315
#include <structmember.h>
1416
#include <sstream>
1517

@@ -95,6 +97,13 @@ PyObject *THPDevice_index(THPDevice *self)
9597
END_HANDLE_TH_ERRORS
9698
}
9799

100+
static Py_ssize_t THPDevice_hash(THPDevice *self)
101+
{
102+
HANDLE_TH_ERRORS
103+
return static_cast<Py_ssize_t>(std::hash<at::Device>{}(self->device) % std::numeric_limits<Py_ssize_t>::max());
104+
END_HANDLE_TH_ERRORS_RET(-1)
105+
}
106+
98107
PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
99108
HANDLE_TH_ERRORS
100109
if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
@@ -181,7 +190,7 @@ PyTypeObject THPDeviceType = {
181190
0, /* tp_as_number */
182191
0, /* tp_as_sequence */
183192
0, /* tp_as_mapping */
184-
0, /* tp_hash */
193+
(hashfunc)THPDevice_hash, /* tp_hash */
185194
0, /* tp_call */
186195
(reprfunc)THPDevice_str, /* tp_str */
187196
0, /* tp_getattro */

0 commit comments

Comments
 (0)