Description
🐛 Bug
When spawning multiple processes (e.g., to train models on different GPUs during a hyperparameter search), each with its own DataLoader using pin_memory=True and a model that calls torch.cat(), multiple CUDA contexts are created on the first GPU (GPU 0).
To Reproduce
Run the following code sample and toggle the use_cat and pin_memory variables. If either one is set to False, no additional context is created on GPU 0.
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, Dataset

ctx = mp.get_context("spawn")

class RandomSet(Dataset):
    def __getitem__(self, item):
        return torch.randn(256, 256, 256)

    def __len__(self):
        return 1000

class CatModel(nn.Module):
    def __init__(self, use_cat=True):
        super(CatModel, self).__init__()
        self.use_cat = use_cat
        if use_cat:
            self.conv = nn.Conv2d(512, 512, 3)
        else:
            self.conv = nn.Conv2d(256, 512, 3)

    def forward(self, x):
        if self.use_cat:
            x = torch.cat([x, x], dim=1)
        return self.conv(x)

# Set either flag to False and no additional CUDA context is created on GPU 0
use_cat = True
pin_memory = True

class Trainer(ctx.Process):
    def __init__(self, device):
        super(Trainer, self).__init__()
        self.device = device
        self.dataset = RandomSet()
        self.dataloader = DataLoader(
            self.dataset,
            pin_memory=pin_memory,
            batch_size=8,
            num_workers=8
        )

    def run(self):
        self.model = CatModel(use_cat).to(self.device)
        # iterate over batches and run the forward pass on this process's GPU
        for x in self.dataloader:
            x = x.to(self.device)
            y = self.model(x)

if __name__ == "__main__":
    trainer1 = Trainer(0)
    trainer2 = Trainer(1)
    trainer1.start()
    trainer2.start()
    trainer1.join()
    trainer2.join()
Expected behavior
No additional CUDA context should be created on GPU 0 when torch.cat() is used together with pin_memory=True in a multi-process setup. Each process should only hold a context on the GPU it was assigned.
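One way to check this is to list the compute processes attached to each GPU while the reproduction script is running. The helper below is a minimal sketch with some assumptions. It assumes output from `nvidia-smi --query-compute-apps=gpu_uuid,pid --format=csv,noheader`, which prints one `uuid, pid` pair per line. It also assumes that `nvidia-smi` is on PATH. The function names `contexts_per_gpu` and `query_nvidia_smi` are hypothetical. In this script, a trainer that should only use GPU 1 showing up under GPU 0's UUID would point to the extra context.

```python
import subprocess
from collections import defaultdict


def contexts_per_gpu(csv_text):
    """Group PIDs by GPU UUID from 'gpu_uuid, pid' CSV lines (no header row)."""
    result = defaultdict(set)
    for line in csv_text.strip().splitlines():
        if not line.strip():
            continue  # skip blank lines
        gpu_uuid, pid = [field.strip() for field in line.split(",")]
        result[gpu_uuid].add(int(pid))
    return dict(result)


def query_nvidia_smi():
    """Hypothetical helper: run nvidia-smi (assumed on PATH) and return its CSV output."""
    return subprocess.run(
        ["nvidia-smi", "--query-compute-apps=gpu_uuid,pid", "--format=csv,noheader"],
        capture_output=True,  # requires Python 3.7+
        text=True,
        check=True,
    ).stdout
```

For example, call `contexts_per_gpu(query_nvidia_smi())` from a separate shell while both trainers are running, then compare the PIDs listed for each GPU.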
Environment
PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: 10.1.243
OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti
GPU 4: GeForce RTX 2080 Ti
GPU 5: GeForce RTX 2080 Ti
GPU 6: GeForce RTX 2080 Ti
GPU 7: GeForce RTX 2080 Ti
Nvidia driver version: 440.33.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
Versions of relevant libraries:
[pip3] numpy==1.18.1
[pip3] torch==1.2.0
[pip3] torchvision==0.4.0
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl-service 2.0.2 py37h7b6447c_0
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.3.1 py3.7_cuda10.1.243_cudnn7.6.3_0 pytorch
[conda] torchvision 0.4.2 py37_cu101 pytorch
Additional context
I also tried the latest PyTorch release (1.4), and the issue still occurs.