1 | 1 | #include <torch/csrc/python_headers.h> |
2 | 2 | #include <system_error> |
3 | 3 |
| 4 | +#include <ATen/ops/from_blob.h> |
4 | 5 | #include <c10/core/CPUAllocator.h> |
5 | 6 | #include <torch/csrc/THP.h> |
6 | 7 | #include <torch/csrc/serialization.h> |
@@ -228,32 +229,22 @@ void THPStorage_writeFileRaw( |
228 | 229 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
229 | 230 | uint8_t* data; |
230 | 231 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
231 | | - std::unique_ptr<char[]> cpu_data; |
| 232 | + at::Tensor cpu_tensor; |
232 | 233 | int64_t size_bytes = self->nbytes(); |
233 | 234 | int64_t numel = size_bytes / element_size; |
234 | 235 | if (self->device_type() == at::kCPU) { |
235 | 236 | data = self->data<uint8_t>(); |
236 | | -#if defined(USE_CUDA) && defined(TORCH_HIP_VERSION) && \ |
237 | | - (TORCH_HIP_VERSION >= 301) |
238 | | - } else if (self->device_type() == at::kCUDA) { |
239 | | - cpu_data = std::unique_ptr<char[]>(new char[size_bytes]); |
240 | | - data = (uint8_t*)cpu_data.get(); |
241 | | - C10_CUDA_CHECK(hipMemcpyWithStream( |
242 | | - data, |
243 | | - self->data<uint8_t>(), |
244 | | - size_bytes, |
245 | | - cudaMemcpyDeviceToHost, |
246 | | - c10::hip::getCurrentHIPStreamMasqueradingAsCUDA())); |
247 | | -#elif defined(USE_CUDA) |
248 | | - } else if (self->device_type() == at::kCUDA) { |
249 | | - cpu_data = std::unique_ptr<char[]>(new char[size_bytes]); |
250 | | - data = (uint8_t*)cpu_data.get(); |
251 | | - C10_CUDA_CHECK(cudaMemcpy( |
252 | | - data, self->data<uint8_t>(), size_bytes, cudaMemcpyDeviceToHost)); |
253 | | -#endif |
254 | 237 | } else { |
255 | | - TORCH_CHECK( |
256 | | - false, "writeFileRaw: Device not recognized: ", self->device_type()); |
| 238 | + // Here we use tensor.to() to implement the D2H copy for all non-CPU devices. |
| 239 | + auto device_tensor = at::from_blob( |
| 240 | + self->data<void>(), |
| 241 | + {size_bytes}, |
| 242 | + {1}, |
| 243 | + nullptr, |
| 244 | + at::device(self->device()).dtype(c10::kByte), |
| 245 | + {self->device()}); |
| 246 | + cpu_tensor = device_tensor.to(at::kCPU); |
| 247 | + data = (uint8_t*)cpu_tensor.data_ptr(); |
257 | 248 | } |
258 | 249 | if (save_size) { |
259 | 250 | if (torch::utils::THP_nativeByteOrder() == |
@@ -409,22 +400,19 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw( |
409 | 400 | } |
410 | 401 | } |
411 | 402 |
412 | | -#if defined(USE_CUDA) && defined(TORCH_HIP_VERSION) && \ |
413 | | - (TORCH_HIP_VERSION >= 301) |
414 | | - if (storage->device_type() == at::kCUDA) { |
415 | | - C10_CUDA_CHECK(hipMemcpyWithStream( |
416 | | - storage->data<uint8_t>(), |
417 | | - data, |
418 | | - nbytes, |
419 | | - cudaMemcpyHostToDevice, |
420 | | - c10::hip::getCurrentHIPStreamMasqueradingAsCUDA())); |
421 | | - } |
422 | | -#elif defined(USE_CUDA) |
423 | | - if (storage->device_type() == at::kCUDA) { |
424 | | - C10_CUDA_CHECK(cudaMemcpy( |
425 | | - storage->data<uint8_t>(), data, nbytes, cudaMemcpyHostToDevice)); |
| 403 | + if (storage->device_type() != at::kCPU) { |
| 404 | + // Here we use tensor.copy_() to implement the H2D copy for all non-CPU devices. |
| 405 | + auto cpu_tensor = at::from_blob( |
| 406 | + (void*)data, {nbytes}, at::device(at::kCPU).dtype(c10::kByte)); |
| 407 | + auto device_tensor = at::from_blob( |
| 408 | + storage->data<void>(), |
| 409 | + {nbytes}, |
| 410 | + {1}, |
| 411 | + nullptr, |
| 412 | + at::device(storage->device()).dtype(c10::kByte), |
| 413 | + {storage->device()}); |
| 414 | + device_tensor.copy_(cpu_tensor); |
426 | 415 | } |
427 | | -#endif |
428 | 416 | return storage; |
429 | 417 | } |
430 | 418 |
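
The read-path hunk mirrors this in the host-to-device direction: the bytes just read from the file are viewed as a CPU byte tensor, the destination storage is viewed as a byte tensor on its own device, and Tensor::copy_() carries out the transfer. A hedged sketch of that direction, again with placeholder names `host_ptr`, `dst`, `nbytes`, and `dev`:

    // Sketch: copy `nbytes` raw bytes from a host buffer into device memory at `dst`.
    // `host_ptr`, `dst`, `nbytes`, and `dev` are illustrative placeholders.
    void h2d_copy(const void* host_ptr, void* dst, int64_t nbytes, c10::Device dev) {
      // Non-owning CPU view over the source bytes (from_blob takes a non-const pointer).
      auto host_view = at::from_blob(
          const_cast<void*>(host_ptr),
          {nbytes},
          at::device(at::kCPU).dtype(at::kByte));
      // Non-owning byte view over the destination memory on the target device.
      auto device_view = at::from_blob(
          dst,
          {nbytes},
          {1},
          nullptr,
          at::device(dev).dtype(at::kByte),
          {dev});
      // copy_() dispatches the host-to-device transfer to the device's backend.
      device_view.copy_(host_view);
    }
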