Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions aten/src/ATen/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cstddef>
#include <iosfwd>
#include <string>
#include <functional>

namespace at {
/// Represents a a compute device on which a tensor is located. A device is
Expand Down Expand Up @@ -112,3 +113,16 @@ struct Device {

std::ostream& operator<<(std::ostream& stream, at::Device::Type type);
std::ostream& operator<<(std::ostream& stream, const at::Device& device);

namespace std {
template<> struct hash<at::Device>
{
size_t operator()(const at::Device& device) const noexcept {
size_t hash_val = static_cast<size_t>(device.index() + 1);
if (device.is_cuda()) {
hash_val += 2;
}
return hash_val;
}
};
} // namespace std
6 changes: 6 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,6 +1819,12 @@ def test_device(self):
self.assertRaises(TypeError, lambda: torch.device('other'))
self.assertRaises(TypeError, lambda: torch.device('other:0'))

device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
device_hash_set = set()
for device in list(device_set):
device_hash_set.add(hash(torch.device(device)))
self.assertEqual(len(device_set), len(device_hash_set))

def test_tensor_device(self):
def assertEqual(device_str, fn):
self.assertEqual(torch.device(device_str), fn().device)
Expand Down
11 changes: 10 additions & 1 deletion torch/csrc/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/utils/python_numbers.h"
#include "torch/csrc/utils/pybind.h"

#include <ATen/Device.h>
#include <ATen/Error.h>

#include <cstring>
#include <limits>
#include <structmember.h>
#include <sstream>

Expand Down Expand Up @@ -95,6 +97,13 @@ PyObject *THPDevice_index(THPDevice *self)
END_HANDLE_TH_ERRORS
}

static Py_ssize_t THPDevice_hash(THPDevice *self)
{
HANDLE_TH_ERRORS
return static_cast<Py_ssize_t>(std::hash<at::Device>{}(self->device) % std::numeric_limits<Py_ssize_t>::max());
END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
HANDLE_TH_ERRORS
if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
Expand Down Expand Up @@ -181,7 +190,7 @@ PyTypeObject THPDeviceType = {
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
(hashfunc)THPDevice_hash, /* tp_hash */
0, /* tp_call */
(reprfunc)THPDevice_str, /* tp_str */
0, /* tp_getattro */
Expand Down