Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions aten/src/THC/generic/THCTensorCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src)

{
THTensor *selfc = THTensor_(newContiguous)(self);
int tensorDevice = THCTensor_(getDevice)(state, src);

This comment was marked as off-topic.

int currentDevice;
THCudaCheck(cudaGetDevice(&currentDevice));

if (currentDevice != tensorDevice) {
THCudaCheck(cudaSetDevice(tensorDevice));
}
src = THCTensor_(newContiguous)(state, src);

cudaStream_t stream = THCState_getCurrentStream(state);
Expand All @@ -68,6 +75,10 @@ void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src)
stream));
THCudaCheck(cudaStreamSynchronize(stream));

if (currentDevice != tensorDevice) {
THCudaCheck(cudaSetDevice(currentDevice));
}

THCTensor_(free)(state, src);
THTensor_(freeCopyTo)(selfc, self);
}
Expand Down
1 change: 1 addition & 0 deletions test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
'cpp_extensions',
'c10d',
'cuda',
'cuda_primary_ctx',
'dataloader',
'distributed',
'distributions',
Expand Down
58 changes: 58 additions & 0 deletions test/test_cuda_primary_ctx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import ctypes
import torch
from common import TestCase, run_tests, skipIfRocm
import unittest

# NOTE: this needs to be run in a brand new process

# We cannot import TEST_CUDA and TEST_MULTIGPU from common_cuda here,
# because if we do that, the TEST_CUDNN line from common_cuda will be executed
# multiple times as well during the execution of this test suite, and it will
# cause CUDA OOM error on Windows.
TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2

if not TEST_CUDA:
print('CUDA not available, skipping tests')
TestCase = object # noqa: F811


def get_is_primary_context_created(device):
flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
result = torch.cuda.cudart().cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
assert result == 0, 'cuDevicePrimaryCtxGetState failed'
return bool(active[0])


class TestCudaPrimaryCtx(TestCase):
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@skipIfRocm
def test_cuda_primary_ctx(self):
# Ensure context has not been created beforehand
self.assertFalse(get_is_primary_context_created(0))
self.assertFalse(get_is_primary_context_created(1))

x = torch.randn(1, device='cuda:1')

# We should have only created context on 'cuda:1'
self.assertFalse(get_is_primary_context_created(0))
self.assertTrue(get_is_primary_context_created(1))

print(x)

# We should still have only created context on 'cuda:1'
self.assertFalse(get_is_primary_context_created(0))
self.assertTrue(get_is_primary_context_created(1))

y = torch.randn(1, device='cpu')
y.copy_(x)

# We should still have only created context on 'cuda:1'
self.assertFalse(get_is_primary_context_created(0))
self.assertTrue(get_is_primary_context_created(1))

# DO NOT ADD ANY OTHER TESTS HERE! ABOVE TEST REQUIRES FRESH PROCESS

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.


if __name__ == '__main__':
run_tests()