File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -35,9 +35,9 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
3535 " first on devices list" );
3636 std::vector<Tensor> tensors;
3737 tensors.reserve (devices.size ());
38- tensors.push_back (tensor);
3938#ifdef USE_NCCL
4039 if (nccl::is_available ({tensor})) {
40+ tensors.push_back (tensor);
4141 for (auto device : devices.slice (1 )) {
4242 at::DeviceGuard _device_guard (device);
4343 tensors.push_back (type.tensor (tensor.sizes ()));
@@ -48,9 +48,9 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntList devices) {
4848 {
4949#endif
5050 auto & gpu_type = type.toBackend (type.is_sparse () ? at::kSparseCUDA : at::kCUDA );
51- for (auto device : devices. slice ( 1 ) ) {
52- at::DeviceGuard _device_guard (device);
53- tensors.push_back (gpu_type. copy (tensor , true ));
51+ for (auto device : devices) {
52+ AutoGPU _gpu_guard (device);
53+ tensors.push_back (tensor. toType (gpu_type , true ));
5454 }
5555 }
5656 return tensors;
You can’t perform that action at this time.
0 commit comments