2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
@@ -1195,7 +1195,7 @@ def _conv_determine_backend_memory_format(
 def _has_storage(x: Tensor) -> _bool: ...
 def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ...
 def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
-def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, str], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ...
+def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ...
 def _check_tp_alloc_is_default(cls: Type) -> _bool: ...
 
 # NB: There is no Capsule type in typing, see
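For callers, the visible effect of this stub change is that the grouping dict is keyed by real `torch.dtype` objects rather than dtype-name strings. A minimal sketch of what that looks like (illustrative only; the tensor values are placeholders):

```python
import torch

# One tensorlist with mixed dtypes on the same device.
tensors = [torch.zeros(2), torch.zeros(2, dtype=torch.float16)]
grouped = torch._C._group_tensors_by_device_and_dtype([tensors], False)

for (device, dtype), (tensorlists, indices) in grouped.items():
    # With this PR the second item of the key is a torch.dtype
    # (e.g. torch.float32), not a string like "float32".
    assert isinstance(dtype, torch.dtype)
```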
41 changes: 2 additions & 39 deletions torch/csrc/Module.cpp
@@ -2154,50 +2154,13 @@ Call this whenever a new thread is created in order to propagate values from
         return torch::should_allow_numbers_as_tensors(name);
       });
 
-  // FIXME(crcrpar): Better to have `at::ScalarType` get mapped to `torch.dtype`
-  // Currently I see the second item of the key is displayed as
-  // e.g. `torch._C._te.ScalarType at 0x7fcf318adab0`
-  // I thought adding an appropriate type_caster of `at::ScalarType` to
-  // torch/csrc/pybind.h` would solve this but it caused segmentation fault in
-  // my environment.
-  using _DeviceDtypeKey = std::pair<at::Device, std::string>;
-  // Custom hasher is necessary to make unordered_map compilable for Windows
-  // debug targets. As `at::native::ParamsHash` only works on structs with
-  // standard layout, but std::string isn't one in Visual C++ debug builds,
-  // which one can easily verify by running something like:
-  //   #define _DEBUG
-  //   #include <type_traits>
-  //   #include <string>
-  //   static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
-  // If above condition is not met, VC++ raises a very cryptic compilation
-  // error. See
-  // https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for
-  // more detail
-  struct _DeviceDtypeHasher {
-    std::size_t operator()(const _DeviceDtypeKey& k) const noexcept {
-      static at::native::ParamsHash<at::Device> device_hasher;
-      static std::hash<std::string> string_hasher;
-      return device_hasher(k.first) ^ string_hasher(k.second);
-    }
-  };
-  using _FlatMap = std::unordered_map<
-      _DeviceDtypeKey,
-      at::native::TensorsAndIndicesT,
-      _DeviceDtypeHasher>;
   py_module.def(
       "_group_tensors_by_device_and_dtype",
       [](const std::vector<std::vector<std::optional<at::Tensor>>>&
              nested_tensorlist,
          const bool with_indices) {
-        _FlatMap map;
-        for (const auto& iter :
-             at::native::_group_tensors_by_first_tensors_device_and_dtype(
-                 nested_tensorlist, with_indices)) {
-          const auto scalar_type_name =
-              torch::utils::getDtypeNames(iter.first.second).first;
-          map.insert({{iter.first.first, scalar_type_name}, iter.second});
-        }
-        return map;
+        return at::native::_group_tensors_by_first_tensors_device_and_dtype(
+            nested_tensorlist, with_indices);
       });
 
   py_module.def(
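The deleted block existed only because `at::ScalarType` keys previously had no working pybind conversion, so the binding re-keyed the native result with dtype-name strings, which in turn forced the custom `(Device, string)` hasher for Windows debug builds. With the native map now returned directly, the `at::ScalarType`-to-`torch.dtype` conversion presumably happens via a type caster not shown in this excerpt. For reference, here is a rough pure-Python sketch of the grouping semantics the binding exposes, grouping the i-th entry of every tensorlist by the device/dtype of the first list's i-th tensor. This is my paraphrase of `_group_tensors_by_first_tensors_device_and_dtype`, not the actual ATen code, and `None` handling is elided:

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch

def group_by_first_tensors_device_and_dtype(
    nested_tensorlist: List[List[Optional[torch.Tensor]]],
    with_indices: bool = False,
) -> Dict[Tuple[torch.device, torch.dtype],
          Tuple[List[List[Optional[torch.Tensor]]], List[int]]]:
    # Each value holds one sub-list per input tensorlist, plus the
    # original column indices when with_indices is set.
    grouped: Dict = defaultdict(lambda: ([[] for _ in nested_tensorlist], []))
    for i, first in enumerate(nested_tensorlist[0]):
        tensorlists, indices = grouped[(first.device, first.dtype)]
        for j, tensorlist in enumerate(nested_tensorlist):
            tensorlists[j].append(tensorlist[i])
        if with_indices:
            indices.append(i)
    return dict(grouped)
```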
7 changes: 1 addition & 6 deletions torch/utils/_foreach_utils.py
@@ -34,12 +34,7 @@ def _group_tensors_by_device_and_dtype(
     tensorlistlist: TensorListList,
     with_indices: bool = False,
 ) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
-    return {
-        (device, getattr(torch, str_dtype)): value
-        for (device, str_dtype), value in
-        torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items()
-    }
-
+    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
 
 def _device_has_foreach_support(device: torch.device) -> bool:
     return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
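With the string round-trip gone, the Python wrapper is a plain pass-through. A hedged usage sketch of how foreach-style optimizer code typically consumes it (the import path is the private module shown above and subject to change):

```python
import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

params = [torch.randn(3), torch.randn(3, dtype=torch.float64)]
grads = [torch.randn(3), torch.randn(3, dtype=torch.float64)]

grouped = _group_tensors_by_device_and_dtype([params, grads])
for (device, dtype), ((device_params, device_grads), _) in grouped.items():
    # dtype is already a torch.dtype; no getattr(torch, str_dtype) needed.
    torch._foreach_add_(device_params, device_grads)
```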