add gumbel_softmax, based on Eric Jang's implementation #3341
Conversation
(test output: )
Force-pushed efd7b2c to 80e53d7.
Thanks for this Hugh, but we may want to hold off on merging this until the distributions API is more settled - pinging @apaszke to see what the status is. If we do take this PR, it'll need to be written with the new API, so …
Note to whoever picks this up (might be me, but probably not until current master is available in versioned form): we need to do something like … (or whatever is the standard approach for CUDA-izing things). Edit: PS, it will be very nice when current master is available in versioned form, so we can use the new distributions API :)
Probably something vaguely similar to:

```diff
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 81654b2..55b272f 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -764,7 +764,7 @@ def softmax(input, dim=None, _stacklevel=3):
     return torch._C._nn.softmax(input, dim)
 
 
-def sample_gumbel(shape, eps=1e-10):
+def sample_gumbel(shape, eps=1e-10, out=None):
     """
     Sample from Gumbel(0, 1)
@@ -772,7 +772,7 @@ def sample_gumbel(shape, eps=1e-10):
     https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
     (MIT license)
     """
-    U = torch.rand(shape).float()
+    U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape)
     return - torch.log(eps - torch.log(U + eps))
@@ -785,7 +785,7 @@ def gumbel_softmax_sample(logits, tau=1, eps=1e-10):
     (MIT license)
     """
     dims = len(logits.size())
-    gumbel_noise = sample_gumbel(logits.size(), eps=eps)
+    gumbel_noise = sample_gumbel(logits.size(), eps=eps, out=logits.data.new())
     y = logits + Variable(gumbel_noise)
     return softmax(y / tau, dims - 1)
@@ -816,7 +816,7 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
     _, k = y_soft.data.max(-1)
     # this bit is based on
     # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
-    y_hard = torch.FloatTensor(*shape).zero_().scatter_(-1, k.view(-1, 1), 1.0)
+    y_hard = logits.data.new(*shape).zero_().scatter_(-1, k.view(-1, 1), 1.0)
     # this cool bit of code achieves two things:
     # - makes the output value exactly one-hot (since we add then
     #   subtract y_soft value)
```
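For reference, the effect of the `out=logits.data.new()` change is that the Gumbel noise is allocated with the same type and device as `logits`, so the function also works on CUDA tensors. A minimal usage sketch, assuming a CUDA build and the Variable-era API this diff targets (illustrative only):

```python
import torch
from torch.autograd import Variable
import torch.nn.functional as F

# With the patch above, sample_gumbel allocates its noise via logits.data.new(),
# so the noise tensor inherits the type and device of the logits
# (e.g. torch.cuda.FloatTensor) instead of always being a CPU FloatTensor.
logits = Variable(torch.randn(8, 10).cuda(), requires_grad=True)
y = F.gumbel_softmax(logits, tau=1.0, hard=True)  # noise is sampled on the GPU
```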
Cool. Do you want to PR that onto my branch? Or ...?
GH doesn't seem to like that you renamed pytorch to pytorch-pytorch. Here's a patch file with tests, if you're up for …
Haha :) Seems like there is already a repo at … Added in your commit.
Build finished.
This assertion is flaky; I actually had to change the RNG seed to get it to work locally. Perhaps …
We want to show there is actually a gradient being backpropagated. If the random number generator is not invariant across hardware, then I guess we may as well remove the seed. I'd rather not put … What are your thoughts on how we can demonstrate that there is a gradient being back-propagated? Hmmm. Perhaps we should replace … (Edit: and/or …) (Edit2: I shall change to …)
Ah, I'd meant …
Ah, well, I don't really trust eps to show non-zeroness. Floats are only reliable-ish to 6-7 significant figures, and eps is 1e-8, way smaller than that. Checking the number of entries greater than, say, 0.001 would be OK.
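A minimal sketch of a check along those lines (illustrative only, not the test that was merged; it uses the current `F.gumbel_softmax` signature and a hypothetical test name):

```python
import torch
import torch.nn.functional as F

def test_gumbel_softmax_backprops_gradient(threshold=1e-3):
    # Rather than comparing gradients against eps (1e-8), count entries that
    # are clearly non-zero; this is robust to RNG differences across hardware.
    logits = torch.randn(4, 5, requires_grad=True)
    y = F.gumbel_softmax(logits, tau=1.0, hard=False)
    # Use a non-trivial loss: y.sum() alone has an (analytically) zero
    # gradient, because each softmax row sums to 1.
    (y ** 2).sum().backward()
    num_clearly_nonzero = (logits.grad.abs() > threshold).sum().item()
    assert num_clearly_nonzero > 0
```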
Force-pushed 945c3f6 to 3ad8929.
soumith left a comment:
sample_gumbel and gumbel_softmax_sample need to be prefixed with _, to make them private functions (e.g. _sample_gumbel). After that, this PR is good to merge.
Force-pushed 3ad8929 to 53f43a7.
Addressed the …
Thanks a lot, @hughperkins!
Summary: Fixes #12643, amends #3341.

- Allow multidimensional input ~~(but apply softmax over `dim=-1`)~~ with a `dim` argument
- Cleaner: fewer lines of code
- Faster (1.32x speedup vs. the original, 2x speedup vs. using `torch.Distributions`)
- Small fixes in the docstring
- Remove some references in the docstring. Was the linked (excellent) ipynb the first to do the straight-through trick? Instead, I propose changing the reference to the two papers most known for it.
- Add a DeprecationWarning for `eps`. It's not needed anymore.
- The initial commit keeps some code alternatives commented out to exploit CI.
- As of the discussion when `gumbel_softmax` was added (#3341), this was merged into `torch.nn.functional` before all the work with `Distributions` and `Pyro`, and there will probably be multiple other best practices for this in the future. I've tested building it on the `Distributions` API, but it was too slow; see below. I therefore propose not using `Distributions`, to keep it fast and simple, but adding a comment in the docstring that `gumbel_softmax` may be deprecated in the future.

The `Distributions`-based alternative:

```
dist = torch.distributions.RelaxedOneHotCategorical(temperature=tau, logits=logits, validate_args=False)
y_soft = dist.rsample()
```

Pros:
* Built using tricks like `logsumexp`, etc.
* Explicitly uses `torch.distributions.utils._finfo` to avoid overflow (the old implementation had an `eps` flag)
* Maintained for this exact purpose.

Cons:
* Very slow. Construction of the distribution adds overhead; see timings below. May be solved in the future with speedups of `TransformedDistribution` and `Distribution`.
* Assumes which `dim` to apply softmax over.

The proposed implementation:

```
y_soft = logits.new(logits.shape)
y_soft = (logits - y_soft.exponential_().log()) / tau  # Gumbel noise
y_soft = y_soft.softmax(dim)  # Gumbel softmax noise
```

Pros:
* Faster

Timing:

```
import time
start = time.time()
num_draws = 1000000
logits = torch.randn(1, 3)
counts = torch.zeros_like(logits)  # missing in the original snippet
for draw in range(num_draws):
    y_draw = gumbel_softmax(logits, hard=True)
    counts = counts + y_draw
end = time.time()  # missing in the original snippet
print(end - start)

>> 12.995795965194702
>> 7.658372640609741
>> 20.3382670879364
```

Decide on which path to choose. I'll commit changes to the unit tests in a while to show that it passes both the old tests and the new tests. I'll also remove the commented code about `RelaxedOneHotCategorical`.

Pull Request resolved: #13339
Differential Revision: D13092434
Pulled By: ezyang
fbshipit-source-id: 4c21788df336f4e9c2ac289022e395b261227b4b
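The snippet above shows the proposed soft sample; the `hard=True` path used in the timing loop relies on the straight-through trick, which returns an exact one-hot vector in the forward pass while routing gradients through the soft sample. A minimal sketch of how the two pieces fit together (the structure and names are illustrative, not the exact code that was merged):

```python
import torch

def gumbel_softmax_sketch(logits, tau=1.0, hard=False, dim=-1):
    # Gumbel(0, 1) noise via -log(Exponential(1)), matching the proposed
    # fast path above (no explicit eps needed).
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = ((logits + gumbels) / tau).softmax(dim)
    if not hard:
        return y_soft
    # Straight-through estimator: exact one-hot in the forward pass,
    # gradient of the soft sample in the backward pass.
    index = y_soft.argmax(dim, keepdim=True)
    y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
    return (y_hard - y_soft).detach() + y_soft
```

The `(y_hard - y_soft).detach() + y_soft` pattern is what makes the output exactly one-hot while the backward pass only sees `y_soft`.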
Based on: …

Gumbel softmax lets you use the reparameterization trick for discrete variables.
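As a small illustration of that claim (a usage sketch with present-day PyTorch tensors, not code from this PR): draw hard one-hot samples and backpropagate through them.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, requires_grad=True)

# Forward pass: hard=True gives exact one-hot "discrete" samples.
sample = F.gumbel_softmax(logits, tau=0.5, hard=True)

# Backward pass: gradients still reach the logits, which is the
# reparameterization / straight-through property being described.
loss = (sample * torch.randn_like(sample)).sum()
loss.backward()
print(sample)       # one-hot rows
print(logits.grad)  # generally non-zero
```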