Skip to content

Commit a902c0c

Browse files
committed
Fix distributed dataloader so it pins memory to current GPU not GPU 0.
1 parent 6d72c82 commit a902c0c

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

torch/utils/data/dataloader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ def _worker_loop(dataset, index_queue, data_queue, collate_fn):
5454
data_queue.put((idx, samples))
5555

5656

57-
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory):
57+
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
58+
if pin_memory:
59+
torch.cuda.set_device(device_id)
60+
5861
while True:
5962
try:
6063
r = in_queue.get()
@@ -176,7 +179,7 @@ def __init__(self, loader):
176179
self.collate_fn = loader.collate_fn
177180
self.batch_sampler = loader.batch_sampler
178181
self.num_workers = loader.num_workers
179-
self.pin_memory = loader.pin_memory
182+
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
180183
self.timeout = loader.timeout
181184
self.done_event = threading.Event()
182185

@@ -206,7 +209,8 @@ def __init__(self, loader):
206209
self.data_queue = queue.Queue()
207210
self.worker_manager_thread = threading.Thread(
208211
target=_worker_manager_loop,
209-
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory))
212+
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
213+
torch.cuda.current_device()))
210214
self.worker_manager_thread.daemon = True
211215
self.worker_manager_thread.start()
212216
else:

0 commit comments

Comments (0)