Description
import torch
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm

class LinearWn(nn.Module):
    def __init__(self):
        super(LinearWn, self).__init__()
        # Linear layer wrapped with weight_norm using dim=None
        layers = [weight_norm(nn.Linear(100, 10), dim=None), nn.ReLU()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)

torch.manual_seed(1)
model = LinearWn()
model = model.cuda()
model = nn.DataParallel(model)  # replicating this module across GPUs triggers the error below
data = torch.rand(2000, 100).cuda()
res = model(data)
res.sum().backward()
Error message:
Traceback (most recent call last):
  File "/private/home/tinayujiang/VQA/vqa_suite/toy_weight_norm.py", line 24, in <module>
    res = model(data)
  File "/private/home/tinayujiang/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/tinayujiang/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 113, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/private/home/tinayujiang/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 118, in replicate
    return replicate(module, device_ids)
  File "/private/home/tinayujiang/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/replicate.py", line 12, in replicate
    param_copies = Broadcast.apply(devices, *params)
RuntimeError: slice() cannot be applied to a 0-dim tensor.
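The failure appears to come from the weight_g parameter that weight_norm creates when dim=None: the magnitude is stored as a 0-dim tensor, and DataParallel's replicate path (Broadcast.apply in the traceback) cannot slice a 0-dim tensor. A minimal check, as a sketch separate from the repro above (no GPU needed):

import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm

# With dim=None the magnitude parameter weight_g is a scalar (0-dim) tensor,
# which is what Broadcast.apply trips over during DataParallel replication.
layer = weight_norm(nn.Linear(100, 10), dim=None)
print(layer.weight_g.dim())   # 0
print(layer.weight_g.shape)   # torch.Size([])

Passing an explicit dim (for example dim=0) keeps weight_g at least one-dimensional and may avoid the crash, though it changes the normalization and is only a workaround, not a fix for the underlying issue.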