Skip to content

Commit bc6bd62

Browse files
csarofeen authored and apaszke committed
Fix distributed dataloader so it pins memory to current GPU not GPU 0.
1 parent cb4f6c3 commit bc6bd62

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

torch/utils/data/dataloader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, wo
6060
data_queue.put((idx, samples))
6161

6262

63-
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory):
63+
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
64+
if pin_memory:
65+
torch.cuda.set_device(device_id)
66+
6467
while True:
6568
try:
6669
r = in_queue.get()
@@ -182,7 +185,7 @@ def __init__(self, loader):
182185
self.collate_fn = loader.collate_fn
183186
self.batch_sampler = loader.batch_sampler
184187
self.num_workers = loader.num_workers
185-
self.pin_memory = loader.pin_memory
188+
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
186189
self.timeout = loader.timeout
187190
self.done_event = threading.Event()
188191

@@ -211,7 +214,8 @@ def __init__(self, loader):
211214
self.data_queue = queue.Queue()
212215
self.worker_manager_thread = threading.Thread(
213216
target=_worker_manager_loop,
214-
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory))
217+
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
218+
torch.cuda.current_device()))
215219
self.worker_manager_thread.daemon = True
216220
self.worker_manager_thread.start()
217221
else:

0 commit comments

Comments (0)