Skip to content

Commit 1feb1a9

Browse files
ngimel authored and soumith committed
small fixes in fusion_compiler (#7776)
* small fixes in fusion_compiler
* address review comments
1 parent 7d0de4f commit 1feb1a9

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

test/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -369,7 +369,7 @@ def broadcast(a, b):
369369
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
370370
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
371371
def test_fuse_last_device(self):
372-
device = 'cuda:' + str(torch.cuda.device_count() - 1)
372+
device = 'cuda:' + str(1)
373373
x = torch.tensor([0.4], dtype=torch.float, device=device)
374374
y = torch.tensor([0.7], dtype=torch.float, device=device)
375375

torch/csrc/jit/fusion_compiler.cpp

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
#include "ATen/ATen.h"
1010
#ifdef WITH_CUDA
11+
#include "THC/THC.h"
1112
#include "torch/csrc/cuda/cuda_check.h"
1213
#include <nvrtc.h>
1314
#include <cuda.h>
@@ -550,13 +551,19 @@ struct CUDAFusionFunction : public CompiledFusionFunction {
550551
// it is possible that this is the first cuda call on this thread
551552
// so make sure we initialize the Driver API's context
552553
// cudaFree(0) accomplishes this.
553-
cudaFree(0);
554-
554+
CUcontext pctx = 0;
555+
TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
556+
if (!pctx) {
557+
std::unique_lock<std::mutex> cudaFreeMutexLock(
558+
*(THCCachingAllocator_getCudaFreeMutex()));
559+
cudaFree(0);
560+
}
561+
CUstream stream = at::globalContext().getCurrentCUDAStream();
555562
TORCH_CU_CHECK(cuLaunchKernel(
556563
function,
557564
numBlocks, 1, 1,
558565
blockSize, 1, 1,
559-
0, nullptr,
566+
0, stream,
560567
arguments,
561568
nullptr));
562569
}

0 commit comments

Comments
 (0)