Skip to content

Commit 81973ea

Browse files
committed
Better fix
1 parent 0bf5126 commit 81973ea

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

torch/nn/parallel/scatter_gather.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,22 @@ def scatter(inputs, target_gpus, dim=0):
1010
references to objects that are not variables. Does not
1111
support Tensors.
1212
"""
13-
def scatter_var(obj):
14-
return Scatter.apply(target_gpus, None, dim, obj)
15-
1613
def scatter_map(obj):
1714
if isinstance(obj, Variable):
18-
return scatter_var(obj)
15+
return Scatter.apply(target_gpus, None, dim, obj)
1916
assert not torch.is_tensor(obj), "Tensors not supported in scatter."
2017
if isinstance(obj, tuple) and len(obj) > 0:
21-
return list(zip(*map(scatter_var, obj)))
18+
return list(zip(*map(scatter_map, obj)))
2219
if isinstance(obj, list) and len(obj) > 0:
23-
return list(map(list, zip(*map(scatter_var, obj))))
20+
return list(map(list, zip(*map(scatter_map, obj))))
2421
if isinstance(obj, dict) and len(obj) > 0:
25-
return list(map(type(obj), zip(*map(scatter_var, obj.items()))))
22+
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
2623
return [obj for targets in target_gpus]
2724

28-
return scatter_map(inputs)
25+
try:
26+
return scatter_map(inputs)
27+
finally:
28+
scatter_map = None
2929

3030

3131
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
@@ -46,17 +46,15 @@ def gather(outputs, target_device, dim=0):
4646
Gathers variables from different GPUs on a specified device
4747
(-1 means the CPU).
4848
"""
49-
def gather_vars(outputs):
50-
return Gather.apply(target_device, dim, *outputs)
51-
5249
def gather_map(outputs):
5350
out = outputs[0]
5451
if isinstance(out, Variable):
55-
return gather_vars(outputs)
52+
return Gather.apply(target_device, dim, *outputs)
5653
if out is None:
5754
return None
58-
# Assuming outputs is a iterable of iterables and not
59-
# an iterable of iterables of iterables (or more)
60-
return type(out)(map(gather_vars, zip(*outputs)))
55+
return type(out)(map(gather_map, zip(*outputs)))
6156

62-
return gather_map(outputs)
57+
try:
58+
return gather_map(outputs)
59+
finally:
60+
gather_map = None

0 commit comments

Comments
 (0)