Conversation

@hughperkins (Contributor Author)

Test output:

$ python test/test_nn.py -v -b TestNN.test_gumbel_softmax_st
test_gumbel_softmax_st (__main__.TestNN) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.020s

OK


@Kaixhin (Contributor) commented Oct 29, 2017

Thanks for this Hugh, but we may want to hold off on merging this until the distributions API is more settled - pinging @apaszke to see what the status is.

If we do take this PR, it'll need to be written with the new API, so torch.distributions.Gumbel etc.
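
For illustration only, a minimal sketch of what drawing the noise might look like with such an API (assuming a `torch.distributions.Gumbel` class with `loc`/`scale` arguments; the API was still settling at the time):

```
import torch

# Sample Gumbel(0, 1) noise shaped like the logits it will perturb.
gumbel = torch.distributions.Gumbel(loc=0.0, scale=1.0)
noise = gumbel.sample((4, 10))
```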

@hughperkins (Contributor Author) commented Nov 12, 2017

Note to whoever picks this up (might be me, but probably not until current master is available in versioned form): we need to do something like:

type_constr = torch.cuda if logits.is_cuda else torch
y_hard = type_constr.FloatTensor(*shape).zero_().scatter_(-1, k.view(-1, 1), 1.0)

(or whatever is the standard approach for cuda-izing things).

Edit: PS it will be very nice when current master is available in versioned form, so we can use the new distributions API :)

@nhynes (Contributor) commented Nov 13, 2017

> (or whatever is the standard approach for cuda-izing things)

Probably something vaguely similar to

diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 81654b2..55b272f 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -764,7 +764,7 @@ def softmax(input, dim=None, _stacklevel=3):
     return torch._C._nn.softmax(input, dim)


-def sample_gumbel(shape, eps=1e-10):
+def sample_gumbel(shape, eps=1e-10, out=None):
     """
     Sample from Gumbel(0, 1)

@@ -772,7 +772,7 @@ def sample_gumbel(shape, eps=1e-10):
     https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
     (MIT license)
     """
-    U = torch.rand(shape).float()
+    U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape)
     return - torch.log(eps - torch.log(U + eps))


@@ -785,7 +785,7 @@ def gumbel_softmax_sample(logits, tau=1, eps=1e-10):
     (MIT license)
     """
     dims = len(logits.size())
-    gumbel_noise = sample_gumbel(logits.size(), eps=eps)
+    gumbel_noise = sample_gumbel(logits.size(), eps=eps, out=logits.data.new())
     y = logits + Variable(gumbel_noise)
     return softmax(y / tau, dims - 1)

@@ -816,7 +816,7 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
         _, k = y_soft.data.max(-1)
         # this bit is based on
         # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
-        y_hard = torch.FloatTensor(*shape).zero_().scatter_(-1, k.view(-1, 1), 1.0)
+        y_hard = logits.data.new(*shape).zero_().scatter_(-1, k.view(-1, 1), 1.0)
         # this cool bit of code achieves two things:
         # - makes the output value exactly one-hot (since we add then
         #   subtract y_soft value)
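
The idea behind the patch, as a hedged standalone sketch (shapes here are illustrative): `tensor.new(...)` allocates a tensor with the same type and device as its source, so the explicit CPU/CUDA branch from the earlier comment disappears.

```
import torch

logits = torch.randn(2, 3)         # behaves the same if logits lives on CUDA
y_hard = logits.new(2, 3).zero_()  # allocated with logits' dtype and device
```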

@hughperkins (Contributor Author)

Cool. Do you want to PR that onto my branch? Or ...?

@nhynes (Contributor) commented Nov 15, 2017

GH doesn't seem to like that you renamed pytorch to pytorch-pytorch. Here's a patch file with tests, if you're up for `git am`-ing it manually :)

@hughperkins (Contributor Author)

> GH doesn't seem to like that you renamed pytorch to pytorch-pytorch

Haha :) Seems like there is already a repo at hughperkins/pytorch though, so :P

Added in your commit.

@pytorchbot (Collaborator)

Build finished.


@nhynes (Contributor) commented Nov 15, 2017

FAIL: test_gumbel_softmax_st (__main__.TestNN)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_nn.py", line 1021, in test_gumbel_softmax_st
    self._test_gumbel_softmax_st(False)
  File "test_nn.py", line 1005, in _test_gumbel_softmax_st
    assert logits_var.grad.abs().min().data[0] > 0.001
AssertionError

This assertion is flaky; I actually had to change the RNG seed to get it to work locally. Perhaps `manual_seed` has some machine-specific dependency? Is there a reason for 0.001 specifically, or does `> eps` also work?

@hughperkins (Contributor Author) commented Nov 16, 2017

We want to show that there is actually gradient being backpropagated. If the random number generator is not invariant across hardware, then I guess we may as well remove the seed. I'd rather not use eps, since eps is used for showing things are effectively the same, whereas we want to check that the gradient is effectively not the same, i.e. not zero.

What are your thoughts on how we can demonstrate that there is a gradient being back-propagated?

Hmmm. Perhaps we should replace .min() with .mean() or .max()? .min() does seem ... fragile.

(Edit: and/or .var(), .std() perhaps?)

(Edit2: I shall change to .std() I think)
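
For concreteness, a sketch of what that change might look like (reusing `logits_var` from the failing test above):

```
# Assert on the spread of the backpropagated gradient rather than its
# minimum, so a single near-zero entry doesn't fail the test.
assert logits_var.grad.std().data[0] > 0.001
```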

@nhynes (Contributor) commented Nov 16, 2017

Ah, I'd meant eps in terms of non-zero. std seems to work; another approach might be to just check the fraction of entries that have magnitude greater than eps = 1e-8.

@hughperkins (Contributor Author)

Ah, well, I don't really trust eps to show non-zeroness. Floats are only reliable-ish to 6-7 significant figures, and eps is 1e-8, way smaller than that.

Checking the number of entries greater than say 0.001 would be ok.
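
A sketch of that count-based check (the threshold and required fraction are illustrative):

```
# Require that most gradient entries are meaningfully non-zero.
frac_nonzero = (logits_var.grad.abs() > 0.001).float().mean()
assert frac_nonzero.data[0] > 0.5
```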

@soumith (Contributor) left a review comment


`sample_gumbel` and `gumbel_softmax_sample` need to be prefixed with `_` to make them private functions (e.g. `_sample_gumbel`). After that, this PR is good to merge.


@hughperkins (Contributor Author)

Addressed the `_` naming prefix and the `logits.dim()` change.

@soumith merged commit fc0d940 into pytorch:master on Jan 4, 2018
@soumith (Contributor) commented Jan 4, 2018

Thanks a lot @hughperkins!

@ragulpr mentioned this pull request on Oct 30, 2018
facebook-github-bot pushed a commit that referenced this pull request on Jan 17, 2019
Summary:
Fixes #12643, amends #3341.

- Allow multidimensional input ~~(but apply softmax over `dim=-1`)~~ with a `dim` argument
- Cleaner: fewer lines of code
- Faster (1.32x speedup vs. the original, 2x speedup vs. using `torch.distributions`)
- Small fixes in the docstring
- Remove some references in the docstring. Was the linked (excellent) ipynb the first to do the straight-through trick? Instead, I propose changing the reference to the two papers best known for it.
- Add a DeprecationWarning for `eps`; it's not needed anymore.
- The initial commit keeps some code alternatives commented out, to exercise them in CI

- As discussed when `gumbel_softmax` was added (#3341), this was merged into `torch.nn.functional` before all the work on `Distributions` and `Pyro`, and there will probably be multiple other best practices for this in the future. I've tested building it with the `Distributions` API, but it was too slow; see below.

I therefore propose not using `Distributions`, to keep this fast and simple, but adding a note in the docstring that `gumbel_softmax` may be deprecated in the future.

```
dist = torch.distributions.RelaxedOneHotCategorical(temperature=tau, logits=logits, validate_args=False)
y_soft = dist.rsample()
```
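
For reference, a self-contained sketch of that draw (the `tau` value and logits shape are illustrative; `rsample` is what lets gradients flow back to `logits`):

```
import torch

tau = 1.0
logits = torch.randn(1, 3, requires_grad=True)
dist = torch.distributions.RelaxedOneHotCategorical(
    temperature=torch.tensor(tau), logits=logits, validate_args=False)
y_soft = dist.rsample()  # reparameterized sample, differentiable w.r.t. logits
y_soft.sum().backward()  # gradients reach logits through the rsample path
```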

Pros:
* Built using tricks like `logsumexp`, etc.
* Explicitly uses `torch.distributions.utils._finfo` to avoid overflow (old implementation had an `eps` flag)
* Maintained for this exact purpose.

Cons:
* Very slow. Constructing the distribution adds overhead; see timings below. May be resolved in the future by speedups to `TransformedDistribution` and `Distribution`.
* Assumes which `dim` to apply softmax over.

```
    y_soft = logits.new(logits.shape)
    y_soft = (logits - y_soft.exponential_().log()) / tau  # Gumbel noise
    y_soft = y_soft.softmax(dim)  # Gumbel softmax noise
```
Pros:
* Faster
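
For completeness, a hedged sketch of how the `hard=True` straight-through step (the trick quoted in the earlier diff) could combine with this sampler; `_st_hard` is a hypothetical helper, not necessarily the merged implementation:

```
import torch

def _st_hard(y_soft, dim=-1):
    # Forward pass: exactly one-hot. Backward pass: gradient of y_soft,
    # since (y_hard - y_soft.detach()) contributes no gradient.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft
```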

```
    import time
    start = time.time()
    num_draws = 1000000
    logits = torch.randn(1, 3)
    counts = torch.zeros_like(logits)  # accumulates the one-hot draws

    for draw in range(num_draws):
        y_draw = gumbel_softmax(logits, hard=True)
        counts = counts + y_draw
    end = time.time()
    print(end - start)

>> 12.995795965194702

>> 7.658372640609741

>> 20.3382670879364
```

Decide on which path to choose. I'll commit changes to the unit tests in a while, to show that it passes both the old and the new tests. I'll also remove the commented code about `RelaxedOneHotCategorical`.
Pull Request resolved: #13339

Differential Revision: D13092434

Pulled By: ezyang

fbshipit-source-id: 4c21788df336f4e9c2ac289022e395b261227b4b