@@ -60,7 +60,10 @@ def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, wo
6060 data_queue .put ((idx , samples ))
6161
6262
63- def _worker_manager_loop (in_queue , out_queue , done_event , pin_memory ):
63+ def _worker_manager_loop (in_queue , out_queue , done_event , pin_memory , device_id ):
64+ if pin_memory :
65+ torch .cuda .set_device (device_id )
66+
6467 while True :
6568 try :
6669 r = in_queue .get ()
@@ -182,7 +185,7 @@ def __init__(self, loader):
182185 self .collate_fn = loader .collate_fn
183186 self .batch_sampler = loader .batch_sampler
184187 self .num_workers = loader .num_workers
185- self .pin_memory = loader .pin_memory
188+ self .pin_memory = loader .pin_memory and torch . cuda . is_available ()
186189 self .timeout = loader .timeout
187190 self .done_event = threading .Event ()
188191
@@ -211,7 +214,8 @@ def __init__(self, loader):
211214 self .data_queue = queue .Queue ()
212215 self .worker_manager_thread = threading .Thread (
213216 target = _worker_manager_loop ,
214- args = (self .worker_result_queue , self .data_queue , self .done_event , self .pin_memory ))
217+ args = (self .worker_result_queue , self .data_queue , self .done_event , self .pin_memory ,
218+ torch .cuda .current_device ()))
215219 self .worker_manager_thread .daemon = True
216220 self .worker_manager_thread .start ()
217221 else :
0 commit comments