
Commit 3bcc3cd

Shan19900305 authored and pytorchmergebot committed
Use ScalarType instead of string in _group_tensors_by_device_and_dtype (#127869)
Now that torch.dtype can pass through pybind11, change _group_tensors_by_device_and_dtype to use at::ScalarType directly, removing the torch.dtype/string conversions on both the Python and C++ sides. @ezyang @bdhirsh

Pull Request resolved: #127869
Approved by: https://github.com/ezyang
1 parent 0ff6023 commit 3bcc3cd
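
For illustration, a minimal sketch of the resulting behavior (assuming a build that includes this commit): the private helper now returns a dict keyed by (torch.device, torch.dtype) tuples rather than (torch.device, str).

import torch

# Two parallel tensor lists; grouping follows the device and dtype of
# the tensors in the first list.
params = [torch.zeros(2), torch.zeros(2, dtype=torch.float64)]
grads = [torch.ones(2), torch.ones(2, dtype=torch.float64)]

grouped = torch._C._group_tensors_by_device_and_dtype([params, grads])
for (device, dtype), (tensorlists, indices) in grouped.items():
    # The second element of each key is a real torch.dtype, not a string.
    assert isinstance(dtype, torch.dtype)
    print(device, dtype, [len(tl) for tl in tensorlists])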

File tree

3 files changed, +4 −46 lines changed


torch/_C/__init__.pyi.in

Lines changed: 1 addition & 1 deletion
@@ -1195,7 +1195,7 @@ def _conv_determine_backend_memory_format(
 def _has_storage(x: Tensor) -> _bool: ...
 def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ...
 def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
-def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, str], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ...
+def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ...
 def _check_tp_alloc_is_default(cls: Type) -> _bool: ...

 # NB: There is no Capsule type in typing, see

torch/csrc/Module.cpp

Lines changed: 2 additions & 39 deletions
@@ -2154,50 +2154,13 @@ Call this whenever a new thread is created in order to propagate values from
         return torch::should_allow_numbers_as_tensors(name);
       });

-  // FIXME(crcrpar): Better to have `at::ScalarType` get mapped to `torch.dtype`
-  // Currently I see the second item of the key is displayed as
-  // e.g. `torch._C._te.ScalarType at 0x7fcf318adab0`
-  // I thought adding an appropriate type_caster of `at::ScalarType` to
-  // `torch/csrc/pybind.h` would solve this but it caused segmentation fault in
-  // my environment.
-  using _DeviceDtypeKey = std::pair<at::Device, std::string>;
-  // Custom hasher is necessary to make unordered_map compilable for Windows
-  // debug targets. As `at::native::ParamsHash` only works on structs with
-  // standard layout, but std::string isn't one in Visual C++ debug builds,
-  // which one can easily verify by running something like:
-  //   #define _DEBUG
-  //   #include <type_traits>
-  //   #include <string>
-  //   static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
-  // If above condition is not met, VC++ raises a very cryptic compilation
-  // error. See
-  // https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for
-  // more detail
-  struct _DeviceDtypeHasher {
-    std::size_t operator()(const _DeviceDtypeKey& k) const noexcept {
-      static at::native::ParamsHash<at::Device> device_hasher;
-      static std::hash<std::string> string_hasher;
-      return device_hasher(k.first) ^ string_hasher(k.second);
-    }
-  };
-  using _FlatMap = std::unordered_map<
-      _DeviceDtypeKey,
-      at::native::TensorsAndIndicesT,
-      _DeviceDtypeHasher>;
   py_module.def(
       "_group_tensors_by_device_and_dtype",
       [](const std::vector<std::vector<std::optional<at::Tensor>>>&
             nested_tensorlist,
          const bool with_indices) {
-        _FlatMap map;
-        for (const auto& iter :
-             at::native::_group_tensors_by_first_tensors_device_and_dtype(
-                 nested_tensorlist, with_indices)) {
-          const auto scalar_type_name =
-              torch::utils::getDtypeNames(iter.first.second).first;
-          map.insert({{iter.first.first, scalar_type_name}, iter.second});
-        }
-        return map;
+        return at::native::_group_tensors_by_first_tensors_device_and_dtype(
+            nested_tensorlist, with_indices);
       });

   py_module.def(
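
The deletion above works because pybind11 now converts at::ScalarType to torch.dtype automatically, so the map returned by at::native::_group_tensors_by_first_tensors_device_and_dtype can be handed back as-is, with no string round-trip or custom hasher. A minimal check of the effect (a sketch, assuming a CPU build with this commit applied):

import torch

grouped = torch._C._group_tensors_by_device_and_dtype([[torch.randn(3)]])

# Keys now come back as (torch.device, torch.dtype) pairs, so lookups
# can use dtype objects directly instead of their string names.
assert (torch.device("cpu"), torch.float32) in grouped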

torch/utils/_foreach_utils.py

Lines changed: 1 addition & 6 deletions
@@ -34,12 +34,7 @@ def _group_tensors_by_device_and_dtype(
     tensorlistlist: TensorListList,
     with_indices: bool = False,
 ) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
-    return {
-        (device, getattr(torch, str_dtype)): value
-        for (device, str_dtype), value in
-        torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items()
-    }
-
+    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)

 def _device_has_foreach_support(device: torch.device) -> bool:
     return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
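
With the getattr-based string-to-dtype conversion gone, the Python wrapper is a thin pass-through. A short usage sketch of this private utility (subject to change between releases):

import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

# A single tensor list containing two dtypes produces two groups.
lists = [[torch.zeros(4, dtype=torch.float16), torch.zeros(4)]]
for (device, dtype), (tensorlists, _) in _group_tensors_by_device_and_dtype(lists).items():
    print(device, dtype, len(tensorlists[0]))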
