@@ -2154,50 +2154,13 @@ Call this whenever a new thread is created in order to propagate values from
21542154 return torch::should_allow_numbers_as_tensors(name);
21552155 });
21562156
2157- // FIXME(crcrpar): Better to have `at::ScalarType` get mapped to `torch.dtype`
2158- // Currently I see the second item of the key is displayed as
2159- // e.g. `torch._C._te.ScalarType at 0x7fcf318adab0`
2160- // I thought adding an appropriate type_caster of `at::ScalarType` to
2161- // `torch/csrc/pybind.h` would solve this but it caused segmentation fault in
2162- // my environment.
2163- using _DeviceDtypeKey = std::pair<at::Device, std::string>;
2164- // Custom hasher is necessary to make unordered_map compilable for Windows
2165- // debug targets. As `at::native::ParamsHash` only works on structs with
2166- // standard layout, but std::string isn't one in Visual C++ debug builds,
2167- // which one can easily verify by running something like:
2168- // #define _DEBUG
2169- // #include <type_traits>
2170- // #include <string>
2171- // static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
2172- // If above condition is not met, VC++ raises a very cryptic compilation
2173- // error. See
2174- // https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for
2175- // more detail
2176- struct _DeviceDtypeHasher {
2177- std::size_t operator()(const _DeviceDtypeKey& k) const noexcept {
2178- static at::native::ParamsHash<at::Device> device_hasher;
2179- static std::hash<std::string> string_hasher;
2180- return device_hasher(k.first) ^ string_hasher(k.second);
2181- }
2182- };
2183- using _FlatMap = std::unordered_map<
2184- _DeviceDtypeKey,
2185- at::native::TensorsAndIndicesT,
2186- _DeviceDtypeHasher>;
21872157 py_module.def(
21882158 "_group_tensors_by_device_and_dtype",
21892159 [](const std::vector<std::vector<std::optional<at::Tensor>>>&
21902160 nested_tensorlist,
21912161 const bool with_indices) {
2192- _FlatMap map;
2193- for (const auto& iter :
2194- at::native::_group_tensors_by_first_tensors_device_and_dtype(
2195- nested_tensorlist, with_indices)) {
2196- const auto scalar_type_name =
2197- torch::utils::getDtypeNames(iter.first.second).first;
2198- map.insert({{iter.first.first, scalar_type_name}, iter.second});
2199- }
2200- return map;
2162+ return at::native::_group_tensors_by_first_tensors_device_and_dtype(
2163+ nested_tensorlist, with_indices);
22012164 });
22022165
22032166 py_module.def(
0 commit comments