Problem
```python
import torch
import torch.nn as nn

embed = nn.Embedding(101, 20, padding_idx=0, sparse=True)
param = list(embed.parameters())[0]

# First pass: every input index is the padding index 0,
# so the sparse gradient is all zeros.
x = torch.LongTensor([0, 0])
x_embed = embed(x)
y = torch.sum(x_embed)
y.backward()

# Second pass: non-padding indices produce a non-zero sparse gradient.
x = torch.ones(2).long()
x_embed = embed(x)
y = torch.sum(x_embed)
y.backward()
```
The first y.backward() produces a gradient (param.grad) of size (101, 20) that is all zeros. PyTorch represents this as a sparse tensor of size (101, 20) with empty indices, empty values, and (dimI, dimV) = (2, 0):
```
torch.sparse.FloatTensor of size (101,20) with indices:
tensor([], dtype=torch.int64)
and values:
tensor([])
```
The second y.backward() produces a gradient of size (101, 20) that is not all zeros. PyTorch represents this as a sparse tensor of size (101, 20) with (dimI, dimV) = (1, 1):
```
torch.sparse.FloatTensor of size (101,20) with indices:
tensor([[ 1,  1]])
and values:
tensor([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])
```
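Both layouts can be read off param.grad directly (running the two passes separately, since accumulating them is exactly what fails). A minimal sketch, assuming the public sparse-tensor accessors sparse_dim() and dense_dim() as the current names for dimI and dimV:

```python
# Inspect the layout of the sparse gradient after a backward() call.
g = param.grad
print(g.size())        # torch.Size([101, 20]) in both cases
print(g.sparse_dim())  # dimI: 2 after the first pass, 1 after the second
print(g.dense_dim())   # dimV: 0 after the first pass, 1 after the second
```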
Because both y.backward() calls accumulate into param.grad, the two gradients have to be summed, and the addition fails because the (dimI, dimV) of the two sparse tensors don't match.
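The mismatch can be reproduced without the embedding layer. A minimal sketch using torch.sparse_coo_tensor to build two sparse tensors that share a size() but disagree on (dimI, dimV):

```python
import torch

size = (101, 20)

# All-zero tensor with (dimI, dimV) = (2, 0): indices (2, 0), values (0,).
zero = torch.sparse_coo_tensor(
    torch.empty(2, 0, dtype=torch.int64), torch.empty(0), size
)

# Non-zero tensor with (dimI, dimV) = (1, 1): indices (1, 2), values (2, 20).
nonzero = torch.sparse_coo_tensor(
    torch.tensor([[1, 1]]), torch.ones(2, 20), size
)

zero + nonzero  # raises, despite both operands having size (101, 20)
```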
Possible solutions
Ignore (dimI, dimV) when adding a sparse tensor that is all zeros to another sparse tensor, as long as they share the same size(). This is the "each size has a unique zero-filled sparse tensor" solution.
Stick with the invariant that sparse tensors must have the same (dimI, dimV) to be added, and say that the gradient from the first y.backward() should have indices of size (1, 0) and values of size (0, 20), based on the shape of the embedding weight and its input. This is the "each (size, dimI, dimV) tuple defines a unique zero-filled sparse tensor" solution (a sketch follows below).
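Under the first solution, the zero + nonzero addition above would simply succeed. Under the second, the zero gradient would be constructed with the same (dimI, dimV) = (1, 1) layout as the non-zero one, making the accumulation well defined. A minimal sketch of that shape convention (illustrative only, not the actual fix inside PyTorch):

```python
import torch

size = (101, 20)

# Zero gradient with (dimI, dimV) = (1, 1): indices (1, 0), values (0, 20).
zero = torch.sparse_coo_tensor(
    torch.empty(1, 0, dtype=torch.int64), torch.empty(0, 20), size
)

# Gradient from the second backward(): indices (1, 2), values (2, 20).
nonzero = torch.sparse_coo_tensor(
    torch.tensor([[1, 1]]), torch.ones(2, 20), size
)

total = zero + nonzero  # well defined: the layouts now match
```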