Description
I'm running a network across multiple GPUs and passing the input data as a dictionary. The Variables stored as items of the input dictionary are properly scattered along the batch dimension, and the forward pass completes correctly.
However, when the output is returned as a dictionary, I get the following runtime error:
torch/nn/parallel/scatter_gather.pyc in gather_map(outputs)
47 if out is None:
48 return None
---> 49 return type(out)(map(gather_map, zip(*outputs)))
50 return gather_map(outputs)
torch/nn/parallel/scatter_gather.pyc in gather_map(outputs)
43 def gather_map(outputs):
44 out = outputs[0]
---> 45 if isinstance(out, Variable):
46 return Gather(target_device, dim=dim)(*outputs)
47 if out is None:
RuntimeError: maximum recursion depth exceeded in __instancecheck__
The reason is that the function gather_map in scatter_gather.py only supports Variables (or iterables of Variables) as its input. A dict iterates over its keys, so gather_map recurses into the key strings; zipping single-character strings yields single-character strings again, so the recursion never bottoms out and Python raises the maximum-recursion-depth error shown above.
However, scatter_map in scatter_gather.py does support dictionaries.
Is there a reason for this discrepancy? Would it be useful if I made a pull request adding dict support to gather_map? A sketch of what the change could look like is below.
I am implementing a network with multiple sub-networks whose outputs may or may not be computed (depending on a config file), and it would be useful to be able to pass all of them out of the data parallel wrapper in a compact way.
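For concreteness, here is a minimal sketch of what a dict-aware gather_map could look like (gather_map is the inner function of gather in scatter_gather.py; the dict branch is my proposal, not existing code, and it assumes every replica returns a dict with the same keys):

def gather_map(outputs):
    out = outputs[0]
    if isinstance(out, Variable):
        return Gather(target_device, dim=dim)(*outputs)
    if out is None:
        return None
    # proposed branch: gather each value under its key across the
    # per-device output dicts
    if isinstance(out, dict):
        return {key: gather_map([d[key] for d in outputs]) for key in out}
    return type(out)(map(gather_map, zip(*outputs)))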
Snippet replicating the error:
import torch
import torch.nn as nn
from torch.autograd import Variable


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return x


class MyModelDictInput(nn.Module):
    def __init__(self):
        super(MyModelDictInput, self).__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)

    def forward(self, d):
        x = d['an_input']
        x = self.block1(x)
        x = self.block2(x)
        return x


class MyModelDictOutput(nn.Module):
    def __init__(self):
        super(MyModelDictOutput, self).__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        d = dict()
        d['an_output'] = x
        return d


# create random input
i = Variable(torch.rand((4, 10)))
d = {'an_input': i}

# example 1: works
print('input is a Variable, output is a Variable')
net = nn.DataParallel(MyModel()).cuda()
o = net(i)
print(o)

# example 2: works
print('input is a dict, output is a Variable')
net = nn.DataParallel(MyModelDictInput()).cuda()
o = net(d)
print(o)

# example 3: raises the RuntimeError above
print('input is a Variable, output is a dict')
net = nn.DataParallel(MyModelDictOutput()).cuda()
o = net(i)
print(o)
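In the meantime, a possible workaround (a sketch under the same setup as above, with a hypothetical MyModelTupleOutput module) is to return the sub-network outputs as a tuple, which gather_map does handle, and rebuild the dict outside the DataParallel wrapper:

# workaround sketch: tuples of Variables gather correctly, so return a
# tuple from forward and reassemble the dict on the host side
class MyModelTupleOutput(nn.Module):
    def __init__(self):
        super(MyModelTupleOutput, self).__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return (x,)  # tuple instead of dict

net = nn.DataParallel(MyModelTupleOutput()).cuda()
o = net(i)
out_dict = {'an_output': o[0]}  # rebuild the dict after gathering
print(out_dict)

This keeps the compactness of a single return value, but it does force the caller to know the key order, which is exactly what dict support in gather_map would avoid.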