2 files changed, 11 insertions(+), 4 deletions(-)

File 1 of 2:

@@ -369,7 +369,7 @@ def broadcast(a, b):
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
     def test_fuse_last_device(self):
-        device = 'cuda:' + str(torch.cuda.device_count() - 1)
+        device = 'cuda:' + str(1)
         x = torch.tensor([0.4], dtype=torch.float, device=device)
         y = torch.tensor([0.7], dtype=torch.float, device=device)
File 2 of 2:

@@ -8,6 +8,7 @@

 #include "ATen/ATen.h"
 #ifdef WITH_CUDA
+#include "THC/THC.h"
 #include "torch/csrc/cuda/cuda_check.h"
 #include <nvrtc.h>
 #include <cuda.h>
@@ -550,13 +551,19 @@ struct CUDAFusionFunction : public CompiledFusionFunction {
     // it is possible that this is the first cuda call on this thread
     // so make sure we initialize the Driver API's context
     // cudaFree(0) accomplishes this.
-    cudaFree(0);
-
+    CUcontext pctx = 0;
+    TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
+    if (!pctx) {
+      std::unique_lock<std::mutex> cudaFreeMutexLock(
+          *(THCCachingAllocator_getCudaFreeMutex()));
+      cudaFree(0);
+    }
+    CUstream stream = at::globalContext().getCurrentCUDAStream();
     TORCH_CU_CHECK(cuLaunchKernel(
         function,
         numBlocks, 1, 1,
         blockSize, 1, 1,
-        0, nullptr,
+        0, stream,
         arguments,
         nullptr));
 }
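For readers following the second hunk: it replaces an unconditional cudaFree(0) with a check for an existing driver context before launching, and passes the current CUDA stream to cuLaunchKernel instead of the legacy default stream (nullptr). Below is a minimal, self-contained sketch of that launch pattern using only the CUDA driver and runtime APIs; the checkCu helper and launchFused wrapper are illustrative names, not part of PyTorch, and the sketch omits the THCCachingAllocator cudaFree mutex that the actual patch takes around cudaFree(0).

// Sketch: lazily ensure a driver context exists, then launch on a
// caller-supplied stream. Assumes `function` was loaded via cuModuleGetFunction.
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Illustrative error check for driver-API calls (stand-in for TORCH_CU_CHECK).
static void checkCu(CUresult res) {
  if (res != CUDA_SUCCESS) {
    const char* msg = nullptr;
    cuGetErrorString(res, &msg);
    std::fprintf(stderr, "CUDA driver error: %s\n", msg ? msg : "unknown");
    std::abort();
  }
}

void launchFused(CUfunction function, void** arguments,
                 unsigned numBlocks, unsigned blockSize,
                 cudaStream_t stream) {
  // The driver API requires a context bound to the calling thread. If this
  // is the thread's first CUDA call there may be none yet; a no-op runtime
  // call such as cudaFree(0) creates and binds the primary context.
  CUcontext pctx = nullptr;
  checkCu(cuCtxGetCurrent(&pctx));
  if (!pctx) {
    cudaFree(0);  // return value intentionally ignored, as in the patch
  }

  // Launch on the caller's stream rather than the legacy default stream
  // (nullptr), so the kernel is ordered with other work queued on it.
  checkCu(cuLaunchKernel(
      function,
      numBlocks, 1, 1,   // grid dims
      blockSize, 1, 1,   // block dims
      0,                 // dynamic shared memory (bytes)
      stream,            // cudaStream_t converts to CUstream
      arguments,         // kernel parameter pointers
      nullptr));         // no "extra" launch options
}

Taking the caching allocator's cudaFree mutex around cudaFree(0), as the patch does, presumably prevents this raw free from interleaving with the allocator's own cudaFree calls on other threads.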