Commit 115c8d5

revert dense_flat changes

1 parent: e0919aa
File tree

2 files changed: +12, -18 lines

torch/csrc/cuda/comm.cpp

Lines changed: 8 additions & 2 deletions
@@ -48,9 +48,15 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
 {
 #endif
   auto & gpu_type = type.toBackend(type.is_sparse() ? at::kSparseCUDA : at::kCUDA);
-  for (auto device : devices) {
+  if (type.is_cuda()) {
+    tensors.push_back(tensor);
+  } else {
+    AutoGPU _gpu_guard(devices[0]);
+    tensors.push_back(gpu_type.copy(tensor, true));
+  }
+  for (auto device : devices.slice(1)) {
     AutoGPU _gpu_guard(device);
-    tensors.push_back(tensor.toType(gpu_type, true));
+    tensors.push_back(gpu_type.copy(tensor, true));
   }
   }
   return tensors;
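
For orientation, the hunk above restores a copy scheme that reuses the input when it already lives on a GPU, copies it to the first device otherwise, and then fans copies out to the remaining devices. Below is a minimal standalone sketch of that logic written against the public ATen API (at::Tensor::to) instead of the internal AutoGPU and gpu_type.copy helpers; the function name broadcast_sketch and the plain device-index vector are illustrative assumptions, not PyTorch's actual interface.

#include <ATen/ATen.h>
#include <vector>

// Sketch of the restored copy scheme (assumed names, public ATen API).
std::vector<at::Tensor> broadcast_sketch(const at::Tensor& src,
                                         const std::vector<c10::DeviceIndex>& devices) {
  std::vector<at::Tensor> out;
  out.reserve(devices.size());
  if (src.is_cuda()) {
    // Already resident on a GPU: reuse it rather than re-copying,
    // mirroring the new `if (type.is_cuda())` branch in the hunk.
    out.push_back(src);
  } else {
    // CPU input: materialize one copy on the first target device.
    out.push_back(src.to(at::Device(at::kCUDA, devices[0]), /*non_blocking=*/true));
  }
  // Every remaining device gets its own copy, as in `devices.slice(1)`.
  for (size_t i = 1; i < devices.size(); ++i) {
    out.push_back(src.to(at::Device(at::kCUDA, devices[i]), /*non_blocking=*/true));
  }
  return out;
}

Relative to the reverted loop, which copied to every device unconditionally, the new branch saves one copy when the source already sits on the first device.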

torch/csrc/utils/tensor_flatten.h

Lines changed: 4 additions & 16 deletions
@@ -9,22 +9,10 @@
 namespace torch { namespace utils {
 
 inline at::Tensor flatten_dense_tensors(at::TensorList tensors) {
-  if (tensors.size() == 1) {
-    return tensors[0].reshape({-1});
-  } else {
-    int64_t total_numel = 0;
-    for (const auto & tensor : tensors) {
-      total_numel += tensor.numel();
-    }
-    auto flat = tensors[0].type().tensor({total_numel});
-    int64_t offset = 0;
-    for (const auto & tensor : tensors) {
-      auto numel = tensor.numel();
-      flat.narrow(0, offset, numel).view_as(tensor).copy_(tensor);
-      offset += numel;
-    }
-    return flat;
-  }
+  static auto flatten = [](const at::Tensor &t) { return t.contiguous().view({-1}); };
+  if (tensors.size() == 1)
+    return flatten(tensors[0]);
+  return at::cat(fmap(tensors, flatten));
 }
 
 inline std::vector<at::Tensor> unflatten_dense_tensors(const at::Tensor& flat, at::TensorList tensors) {
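
The restored flatten_dense_tensors builds the flat buffer with a single at::cat over contiguous 1-D views; the reverted version preallocated the output and copied each tensor into a narrow()ed slice. The sketch below contrasts the two equivalent strategies using only public ATen calls: fmap is an internal torch helper, so a plain loop stands in for it, and at::empty with options() stands in for the old-style type().tensor() allocation. Both function names are illustrative.

#include <ATen/ATen.h>
#include <vector>

// Strategy this commit restores: flatten each tensor, then concatenate once.
at::Tensor flatten_via_cat(at::TensorList tensors) {
  std::vector<at::Tensor> flat;
  flat.reserve(tensors.size());
  for (const auto& t : tensors) {
    flat.push_back(t.contiguous().view({-1}));  // 1-D view of each input
  }
  return at::cat(flat);
}

// Strategy this commit reverts: preallocate, then copy slice by slice.
at::Tensor flatten_via_copy(at::TensorList tensors) {
  int64_t total_numel = 0;
  for (const auto& t : tensors) {
    total_numel += t.numel();
  }
  auto flat = at::empty({total_numel}, tensors[0].options());
  int64_t offset = 0;
  for (const auto& t : tensors) {
    auto numel = t.numel();
    // Reshape the slice to the source's shape so copy_ can run elementwise.
    flat.narrow(0, offset, numel).view_as(t).copy_(t);
    offset += numel;
  }
  return flat;
}

For dense inputs the two produce the same 1-D result; the cat-based form is shorter, while the copy-based form controls the output allocation explicitly.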
