Commit 97e63b4

repush commit
1 parent ce17bb9 commit 97e63b4

File tree

2 files changed (+25, -37 lines)

test/test_torch.py

Lines changed: 1 addition & 1 deletion

@@ -358,7 +358,7 @@ def test_storage_meta_errors(self, device, dtype):
             s0.tolist()
 
         with tempfile.NamedTemporaryFile() as f:
-            with self.assertRaisesRegex(RuntimeError, r'Device not recognized'):
+            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
                 s0._write_file(f, True, True, s0.element_size())
 
         for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
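The expected exception changes because writing a non-CPU storage now funnels through ATen's generic copy path (see serialization.cpp below), and for a meta storage that path fails inside Tensor::to() with "Cannot copy out of meta tensor" rather than in a hand-written TORCH_CHECK. A minimal sketch of that mechanism, assuming the usual ATen C++ API (the helper name is ours, not part of the commit):

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Hypothetical snippet: why the expected error message changed. A meta
// tensor has no backing data, so attempting the D2H copy (as _write_file
// now does internally) throws from the dispatcher.
void show_meta_error() {
  at::Tensor meta = at::empty({4}, at::device(at::kMeta).dtype(at::kByte));
  try {
    meta.to(at::kCPU);  // D2H copy attempt, as the new write path performs
  } catch (const c10::Error& e) {
    // The message contains "Cannot copy out of meta tensor"; Python sees
    // it as NotImplementedError, which is what the updated test asserts.
    (void)e;
  }
}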

torch/csrc/serialization.cpp

Lines changed: 24 additions & 36 deletions
@@ -1,6 +1,7 @@
 #include <torch/csrc/python_headers.h>
 #include <system_error>
 
+#include <ATen/ops/from_blob.h>
 #include <c10/core/CPUAllocator.h>
 #include <torch/csrc/THP.h>
 #include <torch/csrc/serialization.h>
@@ -228,32 +229,22 @@ void THPStorage_writeFileRaw(
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   uint8_t* data;
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::unique_ptr<char[]> cpu_data;
+  at::Tensor cpu_tensor;
   int64_t size_bytes = self->nbytes();
   int64_t numel = size_bytes / element_size;
   if (self->device_type() == at::kCPU) {
     data = self->data<uint8_t>();
-#if defined(USE_CUDA) && defined(TORCH_HIP_VERSION) && \
-    (TORCH_HIP_VERSION >= 301)
-  } else if (self->device_type() == at::kCUDA) {
-    cpu_data = std::unique_ptr<char[]>(new char[size_bytes]);
-    data = (uint8_t*)cpu_data.get();
-    C10_CUDA_CHECK(hipMemcpyWithStream(
-        data,
-        self->data<uint8_t>(),
-        size_bytes,
-        cudaMemcpyDeviceToHost,
-        c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()));
-#elif defined(USE_CUDA)
-  } else if (self->device_type() == at::kCUDA) {
-    cpu_data = std::unique_ptr<char[]>(new char[size_bytes]);
-    data = (uint8_t*)cpu_data.get();
-    C10_CUDA_CHECK(cudaMemcpy(
-        data, self->data<uint8_t>(), size_bytes, cudaMemcpyDeviceToHost));
-#endif
   } else {
-    TORCH_CHECK(
-        false, "writeFileRaw: Device not recognized: ", self->device_type());
+    // Use tensor.to() to implement the D2H copy for all non-CPU devices.
+    auto device_tensor = at::from_blob(
+        self->data<void>(),
+        {size_bytes},
+        {1},
+        NULL,
+        at::device(self->device()).dtype(c10::kByte),
+        {self->device()});
+    cpu_tensor = device_tensor.to(at::kCPU);
+    data = (uint8_t*)cpu_tensor.data_ptr();
   }
   if (save_size) {
     if (torch::utils::THP_nativeByteOrder() ==
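Read in isolation, the new else-branch is a device-agnostic D2H copy: wrap the storage's raw bytes in a non-owning byte tensor and let the dispatcher pick the backend's copy kernel. A standalone sketch of the same pattern (d2h_bytes is a hypothetical name, not part of the commit):

#include <ATen/ATen.h>
#include <ATen/ops/from_blob.h>

// Wrap `size_bytes` of raw device memory as a non-owning byte tensor, then
// let Tensor::to() dispatch the device-to-host copy for whatever backend
// `device` belongs to. The returned CPU tensor owns the host copy.
at::Tensor d2h_bytes(void* device_ptr, int64_t size_bytes, at::Device device) {
  auto device_tensor = at::from_blob(
      device_ptr,
      {size_bytes},
      {1},        // contiguous stride over bytes
      nullptr,    // no deleter: the caller keeps ownership of device_ptr
      at::device(device).dtype(at::kByte),
      device);    // tells from_blob the blob lives on `device`
  return device_tensor.to(at::kCPU);
}

Compared with the removed #ifdef ladder, the copy now goes through the dispatcher, so any backend with copy kernels registered (CUDA, HIP, or an out-of-tree device) is covered without build-time branching.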
@@ -409,22 +400,19 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
     }
   }
 
-#if defined(USE_CUDA) && defined(TORCH_HIP_VERSION) && \
-    (TORCH_HIP_VERSION >= 301)
-  if (storage->device_type() == at::kCUDA) {
-    C10_CUDA_CHECK(hipMemcpyWithStream(
-        storage->data<uint8_t>(),
-        data,
-        nbytes,
-        cudaMemcpyHostToDevice,
-        c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()));
-  }
-#elif defined(USE_CUDA)
-  if (storage->device_type() == at::kCUDA) {
-    C10_CUDA_CHECK(cudaMemcpy(
-        storage->data<uint8_t>(), data, nbytes, cudaMemcpyHostToDevice));
+  if (storage->device_type() != at::kCPU) {
+    // Use tensor.copy_() to implement the H2D copy for all non-CPU devices.
+    auto cpu_tensor = at::from_blob(
+        (void*)data, {nbytes}, at::device(at::kCPU).dtype(c10::kByte));
+    auto device_tensor = at::from_blob(
+        storage->data<void>(),
+        {nbytes},
+        {1},
+        NULL,
+        at::device(storage->device()).dtype(c10::kByte),
+        {storage->device()});
+    device_tensor.copy_(cpu_tensor);
   }
-#endif
   return storage;
 }
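The read path mirrors the write path, using Tensor::copy_() for the H2D direction; a matching sketch (h2d_bytes is again our name):

#include <ATen/ATen.h>
#include <ATen/ops/from_blob.h>

// View the freshly-read host buffer and the destination device storage as
// byte tensors, then let Tensor::copy_() dispatch the host-to-device copy.
void h2d_bytes(void* device_ptr, const void* host_ptr, int64_t nbytes,
               at::Device device) {
  auto cpu_tensor = at::from_blob(
      const_cast<void*>(host_ptr),
      {nbytes},
      at::device(at::kCPU).dtype(at::kByte));
  auto device_tensor = at::from_blob(
      device_ptr,
      {nbytes},
      {1},        // contiguous stride over bytes
      nullptr,    // non-owning: the storage keeps ownership of device_ptr
      at::device(device).dtype(at::kByte),
      device);
  device_tensor.copy_(cpu_tensor);  // synchronous H2D copy via the dispatcher
}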
