Skip to content

Commit e0919aa

Browse files
committed
use toType
1 parent 922402b commit e0919aa

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torch/csrc/cuda/comm.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
3535
"first on devices list");
3636
std::vector<Tensor> tensors;
3737
tensors.reserve(devices.size());
38-
tensors.push_back(tensor);
3938
#ifdef USE_NCCL
4039
if (nccl::is_available({tensor})) {
40+
tensors.push_back(tensor);
4141
for (auto device : devices.slice(1)) {
4242
at::DeviceGuard _device_guard(device);
4343
tensors.push_back(type.tensor(tensor.sizes()));
@@ -48,9 +48,9 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
4848
{
4949
#endif
5050
auto & gpu_type = type.toBackend(type.is_sparse() ? at::kSparseCUDA : at::kCUDA);
51-
for (auto device : devices.slice(1)) {
52-
at::DeviceGuard _device_guard(device);
53-
tensors.push_back(gpu_type.copy(tensor, true));
51+
for (auto device : devices) {
52+
AutoGPU _gpu_guard(device);
53+
tensors.push_back(tensor.toType(gpu_type, true));
5454
}
5555
}
5656
return tensors;

0 commit comments

Comments
 (0)