Merged
56 changes: 56 additions & 0 deletions test/test_nn.py
@@ -1014,6 +1014,62 @@ def test_embedding_functional(self):
res_F = F.embedding(a, embeddings)
self.assertEqual(res_old, res_F)

def _test_gumbel_softmax_st(self, cuda):
"""
Things we might want to check:
- if we make various draws, do we get different one-hot values?
- is the proportion approximately in line with the softmax values?
- with hard, is it one-hot?
- with hard, is there still a gradient?
"""
th = torch.cuda if cuda else torch
num_draws = 100
K = 3
logits = torch.FloatTensor([[0.2, 0.8, 0.1]])
logits_softmax = torch.nn.functional.softmax(Variable(logits), 1)
y_draws = torch.zeros(num_draws, K)
preds = torch.zeros(num_draws)

if cuda:
logits = logits.cuda()
y_draws = y_draws.cuda()
preds = preds.cuda()

exceed_limits = 0
for draw in range(num_draws):
logits_var = Variable(logits, requires_grad=True)
y_draw = torch.nn.functional.gumbel_softmax(
logits_var,
hard=True)
assert y_draw.size() == logits.size()
# check we have a gradient
assert y_draw.requires_grad
err = y_draw - Variable(logits.new([[0, 0.5, 0.3]]))
loss = (err * err).sum()
loss.backward()
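# count draws whose gradient scale falls outside a loose sanity band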
if logits_var.grad.data.std() < 0.01 or logits_var.grad.data.std() > 1.0:
exceed_limits += 1
y_draws[draw] = y_draw.data
_, pred = y_draw.max(1)
preds[draw] = pred.data[0]
assert exceed_limits / num_draws < 0.05
# check it's approximately one-hot
num_ones = (y_draws == 1).int().sum()
num_zeros = (y_draws == 0).int().sum()
assert num_ones + num_zeros == num_draws * K
assert num_ones == num_draws
# check output classes approx in line with logits
num_class_one = (preds == 1).int().sum()
assert num_class_one < num_draws
assert num_class_one > num_draws / 3

def test_gumbel_softmax_st(self):
self._test_gumbel_softmax_st(False)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_gumbel_softmax_st_cuda(self):
self._test_gumbel_softmax_st(True)

def _test_EmbeddingBag(self, cuda, mode):
# check a known test example
es = nn.EmbeddingBag(5, 2, mode=mode)
64 changes: 64 additions & 0 deletions torch/nn/functional.py
@@ -835,6 +835,70 @@ def softmax(input, dim=None, _stacklevel=3):
return torch._C._nn.softmax(input, dim)


def _sample_gumbel(shape, eps=1e-10, out=None):
"""
Sample from Gumbel(0, 1)

based on
https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
(MIT license)
"""
U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape)
return - torch.log(eps - torch.log(U + eps))
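
The transform above is the usual inverse-CDF construction: if U is Uniform(0, 1), then -log(-log(U)) is distributed as Gumbel(0, 1), with eps only guarding against log(0). As a rough standalone sanity check (a minimal sketch, assuming nothing beyond torch; 0.5772 is the Euler-Mascheroni constant, the mean of Gumbel(0, 1), and the tolerance is an arbitrary choice):

import torch

u = torch.rand(100000)
g = -torch.log(1e-10 - torch.log(u + 1e-10))  # same transform as _sample_gumbel
assert abs(g.mean() - 0.5772) < 0.05  # empirical mean should sit near the Gumbel(0, 1) mean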


def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
"""
Draw a sample from the Gumbel-Softmax distribution

based on
https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
(MIT license)
"""
dims = logits.dim()
gumbel_noise = _sample_gumbel(logits.size(), eps=eps, out=logits.data.new())
y = logits + Variable(gumbel_noise)
return softmax(y / tau, dims - 1)
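
The temperature tau controls how close each draw is to a hard one-hot vector: dividing by a small tau before the softmax sharpens the distribution, while a large tau flattens it towards uniform. A minimal usage sketch (illustrative only; it reaches into the module-private _gumbel_softmax_sample, and the printed values are random by construction):

import torch
from torch.autograd import Variable
from torch.nn import functional as F

logits = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]))
for tau in (5.0, 1.0, 0.1):
    y = F._gumbel_softmax_sample(logits, tau=tau)
    print(tau, y.data)  # each row sums to 1; smaller tau gives sharper rows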


def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
"""
Sample from the Gumbel-Softmax distribution and optionally discretize.
Args:
logits: [batch_size, n_class] unnormalized log-probs
tau: non-negative scalar temperature
hard: if True, take argmax, but differentiate w.r.t. soft sample y
Returns:
[batch_size, n_class] sample from the Gumbel-Softmax distribution.
If hard=True, the returned sample will be one-hot; otherwise it will
be a probability distribution that sums to 1 across classes.

Constraints:
- this implementation only works on a 2-D batch_size x num_features tensor for now

based on
https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
(MIT license)
"""
shape = logits.size()
assert len(shape) == 2
y_soft = _gumbel_softmax_sample(logits, tau=tau, eps=eps)
if hard:
_, k = y_soft.data.max(-1)
# this bit is based on
# https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
y_hard = logits.data.new(*shape).zero_().scatter_(-1, k.view(-1, 1), 1.0)
# this cool bit of code achieves two things:
# - makes the output value exactly one-hot (since we add then
# subtract y_soft value)
# - makes the gradient equal to y_soft gradient (since we strip
# all other gradients)
y = Variable(y_hard - y_soft.data) + y_soft
else:
y = y_soft
return y
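
A short end-to-end sketch of the straight-through behaviour described in the comments above (the loss here is arbitrary and purely illustrative): with hard=True the forward value of each row is exactly one-hot, yet a gradient still reaches logits through the soft sample.

import torch
from torch.autograd import Variable
from torch.nn import functional as F

logits = Variable(torch.randn(4, 10), requires_grad=True)
y = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(y.data.sum(1))                 # each row sums to 1 and holds a single 1
loss = (y * y).sum()                 # any scalar loss; only used to trigger backward
loss.backward()
print(logits.grad.data.abs().sum())  # non-zero: gradient flowed via y_soft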


def log_softmax(input, dim=None, _stacklevel=3):
r"""Applies a softmax followed by a logarithm.
