@@ -35,11 +35,12 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
3535 " first on devices list" );
3636 std::vector<Tensor> tensors;
3737 tensors.reserve (devices.size ());
38+ at::DeviceGuard _device_guard;
3839#ifdef USE_NCCL
3940 if (nccl::is_available ({tensor})) {
4041 tensors.push_back (tensor);
4142 for (auto device : devices.slice (1 )) {
42- at::DeviceGuard _device_guard (device);
43+ _device_guard. set_index (device);
4344 tensors.push_back (type.tensor (tensor.sizes ()));
4445 }
4546 nccl::broadcast (tensors);
@@ -50,12 +51,10 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
5051 auto & gpu_type = type.toBackend (type.is_sparse () ? at::kSparseCUDA : at::kCUDA );
5152 if (type.is_cuda ()) {
5253 tensors.push_back (tensor);
53- } else {
54- AutoGPU _gpu_guard (devices[0 ]);
55- tensors.push_back (gpu_type.copy (tensor, true ));
5654 }
57- for (auto device : devices.slice (1 )) {
58- AutoGPU _gpu_guard (device);
55+ IntList loop_devices = type.is_cuda () ? devices.slice (1 ) : devices;
56+ for (auto device : loop_devices) {
57+ _device_guard.set_index (device);
5958 tensors.push_back (gpu_type.copy (tensor, true ));
6059 }
6160 }
0 commit comments