Skip to content

Commit 3eca907

Browse files
committed
address comments
1 parent 115c8d5 commit 3eca907

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

aten/src/ATen/DeviceGuard.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct DeviceGuard {
5656
}
5757
}
5858

59-
/// Sets the device to the given one if its index is not `nullopt`.
59+
/// Sets the device to the given one.
6060
void set_index(int32_t index) {
6161
if (index == -1) {
6262
return;

torch/csrc/cuda/comm.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,12 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
3535
"first on devices list");
3636
std::vector<Tensor> tensors;
3737
tensors.reserve(devices.size());
38+
at::DeviceGuard _device_guard;
3839
#ifdef USE_NCCL
3940
if (nccl::is_available({tensor})) {
4041
tensors.push_back(tensor);
4142
for (auto device : devices.slice(1)) {
42-
at::DeviceGuard _device_guard(device);
43+
_device_guard.set_index(device);
4344
tensors.push_back(type.tensor(tensor.sizes()));
4445
}
4546
nccl::broadcast(tensors);
@@ -50,12 +51,10 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
5051
auto & gpu_type = type.toBackend(type.is_sparse() ? at::kSparseCUDA : at::kCUDA);
5152
if (type.is_cuda()) {
5253
tensors.push_back(tensor);
53-
} else {
54-
AutoGPU _gpu_guard(devices[0]);
55-
tensors.push_back(gpu_type.copy(tensor, true));
5654
}
57-
for (auto device : devices.slice(1)) {
58-
AutoGPU _gpu_guard(device);
55+
IntList loop_devices = type.is_cuda() ? devices.slice(1) : devices;
56+
for (auto device : loop_devices) {
57+
_device_guard.set_index(device);
5958
tensors.push_back(gpu_type.copy(tensor, true));
6059
}
6160
}

0 commit comments

Comments
 (0)