@@ -10,22 +10,22 @@ def scatter(inputs, target_gpus, dim=0):
1010 references to objects that are not variables. Does not
1111 support Tensors.
1212 """
13- def scatter_var (obj ):
14- return Scatter .apply (target_gpus , None , dim , obj )
15-
1613 def scatter_map (obj ):
1714 if isinstance (obj , Variable ):
18- return scatter_var ( obj )
15+ return Scatter . apply ( target_gpus , None , dim , obj )
1916 assert not torch .is_tensor (obj ), "Tensors not supported in scatter."
2017 if isinstance (obj , tuple ) and len (obj ) > 0 :
21- return list (zip (* map (scatter_var , obj )))
18+ return list (zip (* map (scatter_map , obj )))
2219 if isinstance (obj , list ) and len (obj ) > 0 :
23- return list (map (list , zip (* map (scatter_var , obj ))))
20+ return list (map (list , zip (* map (scatter_map , obj ))))
2421 if isinstance (obj , dict ) and len (obj ) > 0 :
25- return list (map (type (obj ), zip (* map (scatter_var , obj .items ()))))
22+ return list (map (type (obj ), zip (* map (scatter_map , obj .items ()))))
2623 return [obj for targets in target_gpus ]
2724
28- return scatter_map (inputs )
25+ try :
26+ return scatter_map (inputs )
27+ finally :
28+ scatter_map = None
2929
3030
3131def scatter_kwargs (inputs , kwargs , target_gpus , dim = 0 ):
@@ -46,17 +46,15 @@ def gather(outputs, target_device, dim=0):
4646 Gathers variables from different GPUs on a specified device
4747 (-1 means the CPU).
4848 """
49- def gather_vars (outputs ):
50- return Gather .apply (target_device , dim , * outputs )
51-
5249 def gather_map (outputs ):
5350 out = outputs [0 ]
5451 if isinstance (out , Variable ):
55- return gather_vars ( outputs )
52+ return Gather . apply ( target_device , dim , * outputs )
5653 if out is None :
5754 return None
58- # Assuming outputs is a iterable of iterables and not
59- # an iterable of iterables of iterables (or more)
60- return type (out )(map (gather_vars , zip (* outputs )))
55+ return type (out )(map (gather_map , zip (* outputs )))
6156
62- return gather_map (outputs )
57+ try :
58+ return gather_map (outputs )
59+ finally :
60+ gather_map = None
0 commit comments