Skip to content

Commit 53f43a7

Browse files
committed
gumbel_softmax tweaks
1 parent dcd3d3a commit 53f43a7

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

test/test_nn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,8 +1016,6 @@ def test_embedding_functional(self):
10161016

10171017
def _test_gumbel_softmax_st(self, cuda):
10181018
th = torch.cuda if cuda else torch
1019-
old_rng_state = th.get_rng_state()
1020-
th.manual_seed(42)
10211019
"""
10221020
Things we might want to check:
10231021
- if we make various draws, do we get different one-hot values?
@@ -1037,6 +1035,7 @@ def _test_gumbel_softmax_st(self, cuda):
10371035
y_draws = y_draws.cuda()
10381036
preds = preds.cuda()
10391037

1038+
exceed_limits = 0
10401039
for draw in range(num_draws):
10411040
logits_var = Variable(logits, requires_grad=True)
10421041
y_draw = torch.nn.functional.gumbel_softmax(
@@ -1048,10 +1047,12 @@ def _test_gumbel_softmax_st(self, cuda):
10481047
err = y_draw - Variable(logits.new([[0, 0.5, 0.3]]))
10491048
loss = (err * err).sum()
10501049
loss.backward()
1051-
assert logits_var.grad.abs().min().data[0] > 0.001
1050+
if logits_var.grad.data.std() < 0.01 or logits_var.grad.data.std() > 1.0:
1051+
exceed_limits += 1
10521052
y_draws[draw] = y_draw.data
10531053
_, pred = y_draw.max(1)
10541054
preds[draw] = pred.data[0]
1055+
assert exceed_limits / num_draws < 0.05
10551056
# check it's approximately one-hot
10561057
num_ones = (y_draws == 1).int().sum()
10571058
num_zeros = (y_draws == 0).int().sum()
@@ -1061,7 +1062,6 @@ def _test_gumbel_softmax_st(self, cuda):
10611062
num_class_one = (preds == 1).int().sum()
10621063
assert num_class_one < num_draws
10631064
assert num_class_one > num_draws / 3
1064-
th.set_rng_state(old_rng_state)
10651065

10661066
def test_gumbel_softmax_st(self):
10671067
self._test_gumbel_softmax_st(False)

torch/nn/functional.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ def softmax(input, dim=None, _stacklevel=3):
835835
return torch._C._nn.softmax(input, dim)
836836

837837

838-
def sample_gumbel(shape, eps=1e-10, out=None):
838+
def _sample_gumbel(shape, eps=1e-10, out=None):
839839
"""
840840
Sample from Gumbel(0, 1)
841841
@@ -847,16 +847,16 @@ def sample_gumbel(shape, eps=1e-10, out=None):
847847
return - torch.log(eps - torch.log(U + eps))
848848

849849

850-
def gumbel_softmax_sample(logits, tau=1, eps=1e-10):
850+
def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
851851
"""
852852
Draw a sample from the Gumbel-Softmax distribution
853853
854854
based on
855855
https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
856856
(MIT license)
857857
"""
858-
dims = len(logits.size())
859-
gumbel_noise = sample_gumbel(logits.size(), eps=eps, out=logits.data.new())
858+
dims = logits.dim()
859+
gumbel_noise = _sample_gumbel(logits.size(), eps=eps, out=logits.data.new())
860860
y = logits + Variable(gumbel_noise)
861861
return softmax(y / tau, dims - 1)
862862

@@ -882,7 +882,7 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
882882
"""
883883
shape = logits.size()
884884
assert len(shape) == 2
885-
y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps)
885+
y_soft = _gumbel_softmax_sample(logits, tau=tau, eps=eps)
886886
if hard:
887887
_, k = y_soft.data.max(-1)
888888
# this bit is based on

0 commit comments

Comments
 (0)