@@ -54,7 +54,10 @@ def _worker_loop(dataset, index_queue, data_queue, collate_fn):
5454 data_queue .put ((idx , samples ))
5555
5656
57- def _worker_manager_loop (in_queue , out_queue , done_event , pin_memory ):
57+ def _worker_manager_loop (in_queue , out_queue , done_event , pin_memory , device_id ):
58+ if pin_memory :
59+ torch .cuda .set_device (device_id )
60+
5861 while True :
5962 try :
6063 r = in_queue .get ()
@@ -176,7 +179,7 @@ def __init__(self, loader):
176179 self .collate_fn = loader .collate_fn
177180 self .batch_sampler = loader .batch_sampler
178181 self .num_workers = loader .num_workers
179- self .pin_memory = loader .pin_memory
182+ self .pin_memory = loader .pin_memory and torch . cuda . is_available ()
180183 self .timeout = loader .timeout
181184 self .done_event = threading .Event ()
182185
@@ -206,7 +209,8 @@ def __init__(self, loader):
206209 self .data_queue = queue .Queue ()
207210 self .worker_manager_thread = threading .Thread (
208211 target = _worker_manager_loop ,
209- args = (self .worker_result_queue , self .data_queue , self .done_event , self .pin_memory ))
212+ args = (self .worker_result_queue , self .data_queue , self .done_event , self .pin_memory ,
213+ torch .cuda .current_device ()))
210214 self .worker_manager_thread .daemon = True
211215 self .worker_manager_thread .start ()
212216 else :
0 commit comments