
Conversation

@t-vi (Collaborator) commented Jul 27, 2018

In the shortcut for n_sample=1, when category 0 has 0 weight,
we should not map the (uniform) sample 0 to category 0.
The conversion uniform->multinomial was apparently written to work on
a (0,1] range (like curand uses), but PyTorch uses a [0,1) range.

Fixes: #4858. Thank you, Roy Fejgin for reporting.
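To illustrate the mismatch, here is a minimal Python sketch (illustrative only -- the helper and the CDF values are made up, and this is not the kernel code): a bin search written for curand-style (0,1] samples, with a sample==0 shortcut, sends an exact 0 drawn from PyTorch's [0,1) generator into category 0 even when that category has zero weight.

# Illustrative sketch only, not the CUDA kernel: a bin search written for
# curand-style (0,1] samples, including a sample==0 shortcut of the kind the
# old code had. PyTorch's generator draws from [0,1), so sample == 0 can
# actually occur and lands in category 0 even when that category has zero weight.
def pick_category_right_closed(sample, cdf):
    if sample == 0.0:             # harmless for (0,1] samples, wrong for [0,1)
        return 0
    prev = 0.0
    for i, cur in enumerate(cdf):
        if prev < sample <= cur:  # right-closed bins (cdf[i-1], cdf[i]]
            return i
        prev = cur
    return len(cdf) - 1

cdf = [0.0, 1.0]  # inclusive prefix sum of the weights [0.0, 1.0]
print(pick_category_right_closed(0.0, cdf))  # -> 0, a zero-weight category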

@rfejgin commented Jul 30, 2018

@t-vi: Thanks for the fix! I've tried it and can confirm that with the fix I am no longer able to reproduce the problem. This is using my own test case which was previously able to cause the problem reliably (regardless of the particular random seed).

While I'm not able to fully review the fix from an algorithmic point of view, I do suggest adding a comment in the code explaining why we are substituting the uniform samples with (1.0 - uniform_samples), since -- to me -- it's rather non-obvious: it depends on what curand does in a different part of the codebase and on implicit assumptions in the code below it.

@t-vi (Collaborator, Author) commented Jul 31, 2018

Fair enough. I've changed the fix to substitute 1 for a sample of 0 (mirroring the move from curand -> PyTorch uniform) and added a comment. Quite likely the function could be rewritten to work with [0,1) directly when it is migrated to ATen native.
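For illustration (a sketch only; the real change sits in the CUDA sampling path, not in Python code), the substitution amounts to mapping PyTorch's [0,1) uniforms onto the (0,1] range the binning code was written for:

# Illustrative sketch only: replace exact zeros with 1 so the [0,1) samples
# land in the (0,1] range that the original binning code (and curand) assumes.
import torch

u = torch.rand(1024)
u = torch.where(u == 0, torch.ones_like(u), u)  # now effectively in (0,1]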

@rfejgin commented Jul 31, 2018

Thanks, @t-vi. Do you think the second test in this line still makes sense after this fix?
https://github.com/t-vi/pytorch/blob/f6eac1b3cfffba97ce5d3ce2d429fbf22a678151/aten/src/THC/THCTensorRandom.cuh#L215

It seems like that second condition can now never be true. Also, if the code was originally written to work with samples in the interval (0,1], I wonder why it's testing for sample==0?

@t-vi (Collaborator, Author) commented Jul 31, 2018

Ha. That was exactly the test I had missed. Thanks, Roy!
Before I moved to fixing it the way I did now (for which I checked the history of the code), I had tried to adjust the code below to handle 0 correctly. I had missed this part: the continue here causes the category to be 0, which is why my attempted adjustment to the binning intervals didn't work out.
I'll revisit that plan for a clean fix. Thanks for pointing it out!

Make the intervals for the binning half-open: [p, p_next).
Also remove the check for sample==0 that would short-circuit the
bin search and give a wrong sample.

Thank you, Roy Fejgin for the report and the discussion!
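For illustration, a plain-Python sketch of the half-open binning this commit describes (not the kernel code; the weights are made up): a sample s is assigned to category i when cdf[i-1] <= s < cdf[i], so a zero-width bin can never be hit.

# Plain-Python sketch of half-open binning [prev, cur); not the kernel code.
def pick_category_half_open(sample, cdf):
    prev = 0.0
    for i, cur in enumerate(cdf):
        if prev <= sample < cur:  # half-open bin: zero-width bins never match
            return i
        prev = cur
    return len(cdf) - 1  # not reached for samples in [0,1) when cdf[-1] == 1

cdf = [0.0, 0.3, 1.0]  # inclusive prefix sum of the weights [0.0, 0.3, 0.7]
print(pick_category_half_open(0.0, cdf))  # -> 1: the zero-weight category 0 is skipped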
@li-roy added the ready for review label Jul 31, 2018
@rfejgin commented Jul 31, 2018

Hmm, unfortunately my script is still able to catch some zero-probability events being sampled, but now they are not at the 0th index anymore. Maybe we are peeling an onion here...
I'll try to make my script a bit more minimal and will post it.

@t-vi (Collaborator, Author) commented Jul 31, 2018

Thanks! You could just capture and save the RNG state to a temporary variable (torch.cuda.get_rng_state) and output it after detecting failure.
Or if you just have the probabilities, that would be cool, too.
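For example, something along these lines (a sketch only; the tensor shapes and the file name are made up):

# Sketch of the suggested capture: stash the CUDA RNG state before each draw
# so a failing iteration can be replayed later with torch.cuda.set_rng_state.
import torch

weights = torch.rand(1024, 2048, device='cuda')  # stand-in for the real weights
weights[:, :1024] = 0                            # some zero-probability bins

state = torch.cuda.get_rng_state()               # capture *before* the draw
s = weights.multinomial(1).squeeze(1)
if (weights[torch.arange(weights.shape[0], device='cuda'), s] == 0).any():
    torch.save(state, 'rng_state_before_failure.pt')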

@rfejgin commented Jul 31, 2018

Here are the probabilities. Will post a script shortly.

weights.zip

@rfejgin commented Jul 31, 2018

Attaching RNG state
rng_state.zip

@rfejgin commented Jul 31, 2018

Here is my script. I have not had a chance to create an absolutely minimal test case, but if you run this you should see the problem when it gets to iteration 61678, which takes about 52 seconds on my GPU (1080ti).
debug_zero_prob.zip

@rfejgin commented Jul 31, 2018

Attaching a new version of the test case and output. This version has slightly clearer printouts.

debug_zero_prob.zip

@rfejgin commented Aug 1, 2018

The strange thing is that now, the zero-probability bins getting selected are not bin zero, nor are they near any bin that has non-zero probability. For example, the PDF can look like this:
[0.0, 0.0, 0.0, ... 0.0, 0.3, 0.4, 0.1, 0.0, 0.0, 0.0, ... 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
...and one of those far-away zero bins gets selected. How is that bin different from any of its neighboring bins? The CDF should have the same value for this bin and its neighbors, so it seems to me that it can't be a subtlety in the bucket boundary checks.
I'm starting to suspect something more fundamental is wrong with the CUDA implementation... but I'm just speculating.

@rfejgin commented Aug 1, 2018

I added some tracing in the CUDA code. In the case I caught, the uniform sample equals 0.99999994. Then many buckets (to the right of the last non-zero-probability bin) ended up meeting the condition

(cat < categories) && (!THCNumerics<T>::ge(sample, curBucket)) && (THCNumerics<T>::ge(sample, prevBucket));

All of them had curBucket == 1.0 and prevBucket == 0.99999994

Then many threads end up writing out the result in this line
dest[curDist] = cat + TH_INDEX_BASE;

and I think that maybe the last of the threads to make that update wins.

Now I guess the question is why curBucket does not equal prevBucket even though all the nearby bins have zero probability.

@rfejgin commented Aug 1, 2018

Oops, I just realized you must have wanted the RNG state from before the bug occurs, so that you could regenerate the uniform samples that trigger the problem. What I posted was the RNG state after the bug occurred -- sorry, I misunderstood. I'm not at my desk now to re-capture that data, but hopefully the test case I posted and the uniform sample value are enough to analyze / reproduce this.

@t-vi (Collaborator, Author) commented Aug 1, 2018 via email

@t-vi (Collaborator, Author) commented Aug 1, 2018

So I took the liberty of condensing your example to the following:

import torch
print (torch.__version__)
batch_size = 1024
dist_size = 2048
k = 16095
torch.cuda.manual_seed(k)
torch.manual_seed(k)
weights = torch.empty(batch_size, dist_size, device='cuda')
weights.uniform_(0.4, 0.6)
weights[:, :int(dist_size/2)] = 0
weights[:, int(dist_size/2)+4:] = 0
s = weights.multinomial(1).squeeze(1)
selected_probs = weights[torch.arange(batch_size), s]
assert not (selected_probs==0).any()

(and it really fails with e5c5ae3; I had just realized I had another branch checked out when I first posted this).

@t-vi (Collaborator, Author) commented Aug 1, 2018

So interestingly, when I use the identical kernel from a PyTorch extension, with a caller that calls torch.rand and then the kernel, the error does not seem to occur. I'll see about extracting the random part from the actual call.
A random thing I've noticed and been wondering about: is there a typo in the shared memory requirements? The product sizeof(real)*sizeof(accreal) looks pretty suspicious and probably should be a sum (possibly with coefficients). But that is likely unrelated to our present bug.

@rfejgin commented Aug 1, 2018

It's starting to look like a floating point accuracy issue that the code isn't resilient to. I'm seeing cases where curBucket should be equal to prevBucket (because the bucket itself has zero probability), but they are slightly different, with curBucket being 1.0 and prevBucket being a little less than 1.0. If the sample falls between those values, the bucket gets selected.
I think curBucket and prevBucket should be equal mathematically but are diverging because of floating point arithmetic -- they are the result of different sums calculated by different threads.
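For illustration (a standalone sketch, unrelated to the kernel's actual prefix sum), the same float32 values summed in two different association orders routinely disagree in the last bits:

# Standalone sketch: summing identical float32 values in different orders, as
# different threads of a parallel prefix sum effectively do, often gives
# results that are not bit-identical.
import torch

vals = torch.rand(2048)                            # float32 values in [0, 1)
sequential = float(vals.cumsum(0)[-1])             # left-to-right association
pairwise = float(vals.view(-1, 2).sum(1).sum(0))   # a different association
print(sequential, pairwise, sequential == pairwise)  # frequently not equal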

@rfejgin commented Aug 1, 2018

@t-vi: interesting finding about the memory allocation, btw. Yeah, that product does look strange, though it probably allocates too much memory, which is better than the opposite... :)

@rfejgin commented Aug 1, 2018

Here's a log showing an example where bin 1054 was selected. All bins from 1028 onwards had zero probability. Hmm, at least with this example, I think the problem would not have happened if the condition were 'gt' as it originally was. Let's think about whether that check fundamentally makes more sense or just happens to work better in this case.
curDist=329,thread=1023,cat=1023,prevBucket=0,curBucket=0,sample=0.999999940395355
curDist=329,thread=0,cat=1024,prevBucket=0,curBucket=0.227620840072632,sample=0.999999940395355
curDist=329,thread=1,cat=1025,prevBucket=0.227620840072632,curBucket=0.502622127532959,sample=0.999999940395355
curDist=329,thread=2,cat=1026,prevBucket=0.502622127532959,curBucket=0.769631028175354,sample=0.999999940395355
curDist=329,thread=3,cat=1027,prevBucket=0.769631028175354,curBucket=0.999999940395355,sample=0.999999940395355
curDist=329,thread=4,cat=1028,prevBucket=0.999999940395355,curBucket=1,sample=0.999999940395355
curDist=329,thread=5,cat=1029,prevBucket=1,curBucket=0.999999940395355,sample=0.999999940395355
curDist=329,thread=6,cat=1030,prevBucket=0.999999940395355,curBucket=1,sample=0.999999940395355

@rfejgin commented Aug 1, 2018

Maybe, if we want to ensure that zero-probability events are not selected in a way that is robust to floating point issues, the inBucket check could be forced to false when the difference between prevBucket and curBucket is less than an epsilon. That epsilon might need to depend on the underlying data type (T) being used.
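A plain-Python sketch of that idea (illustrative only; the helper name and epsilon choice are made up, and the real check would live in the kernel):

# Illustrative sketch of the epsilon guard, not the kernel code: refuse to hit
# a bucket whose width is below a type-dependent epsilon, so buckets that only
# exist through rounding error cannot be selected.
import torch

def in_bucket_eps(sample, prev_bucket, cur_bucket, eps):
    return (prev_bucket <= sample < cur_bucket) and (cur_bucket - prev_bucket > eps)

eps = torch.finfo(torch.float32).eps  # would have to depend on the element type T
# values from the trace above: a bucket whose width is pure rounding error
print(in_bucket_eps(0.99999994, 0.99999994, 1.0, eps))  # -> False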

@t-vi changed the title from "Fix corner case with torch.multinomial" to "[not ready yet] Fix corner case with torch.multinomial" Aug 2, 2018
@t-vi (Collaborator, Author) commented Aug 2, 2018

So I think the root cause is the inclusive prefix sum (in the comment). There, the summation order is totally different for adjacent elements, which might lead to the observed non-monotonicity and other effects.
Now, how to fix this? I see the following options:

  • There is the minimum bucket size @rfejgin suggests in #4858 (CUDA multinomial with replacement can select zero-probability events).
  • One elaborate option (which I'm not 100% certain would suffice) would be to improve the summing -- I tested using the powers of two in the bit pattern as summation units and it looks like it might help.
  • Rounding the inputs would also likely work.
  • My favourite solution would be to just test dist > 0 in the same place where Roy would check the distances. I implemented that (see the sketch below).

A final observation: the prefix sum is done using T instead of AccT. I'm not sure why, given that shared memory of AccT type is available, but I didn't change it for now.
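A plain-Python sketch of that last option (illustrative only; in the actual change the check sits in the CUDA kernel's bucket test):

# Illustrative sketch, not the kernel: keep the half-open bucket test, but also
# require the bucket's own probability mass to be strictly positive.
def in_bucket_positive(sample, prev_bucket, cur_bucket, dist_val):
    return (prev_bucket <= sample < cur_bucket) and dist_val > 0

# with the values from the trace above, the spurious bucket has weight 0, so it
# is rejected even though the sample falls between its (rounded) edges
print(in_bucket_positive(0.99999994, 0.99999994, 1.0, 0.0))  # -> False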

@t-vi (Collaborator, Author) commented Aug 2, 2018

So to me it looks like this test case captures the essence of @rfejgin's latest example and works as expected with the latest commit.

import torch
dist_size = 2048
batch_size = 1024
torch.cuda.manual_seed(41134)
weights = torch.zeros(dist_size, device='cuda')
weights[[1024,1025,1026,1027]] = torch.tensor([0.47056111693382263, 0.568510890007019, 0.5519880056381226, 0.476242333650589], device='cuda')
weights = weights[None].expand(batch_size, dist_size)
s = weights.multinomial(1).squeeze(1)
selected_probs = weights[torch.arange(weights.shape[0]), s]
assert not (selected_probs==0).any()

If @rfejgin is reasonably happy with this, I'll add the above as a test case and remove the not ready yet label.

@rfejgin commented Aug 2, 2018

Thanks Thomas. I agree with you on the root cause. I also like the idea of explicitly checking if the distribution value equals zero, and was having similar thoughts after I posted. The explicit check avoids the need to tune a threshold (or one per type of T), which I've actually tried and found fragile. So your fix looks good to me!

One minor comment is that I would consider putting the memory allocation fix in a different commit. In the unlikely case that that change causes instability it would be good to be able to separate the two changes.

@t-vi changed the title from "[not ready yet] Fix corner case with torch.multinomial" to "Fix corner case with torch.multinomial" Aug 2, 2018
@t-vi (Collaborator, Author) commented Aug 2, 2018

OK. My impression is that the PR process is heavy-handed enough that tacking the shared memory fix onto this one is reasonable, but I'll gladly remove it if the reviewer prefers.

@rfejgin commented Aug 2, 2018

@t-vi: it's really up to you - this is my first pytorch PR review, so I'm not super familiar with the conventions - just a suggestion.

On another topic: Now that we're explicitly handling the zero-probability case, do you think that the change in the condition from gt to ge is still necessary? https://github.com/pytorch/pytorch/pull/9960/files#diff-a5d44fe2c50c7cc4616e927154fb7e30R259

@t-vi (Collaborator, Author) commented Aug 2, 2018

You could argue that 0 is exceedingly rare, but I think it is good practice to cover the half-open interval [0,1) with equally half-open bins [a,b), which is what the ge achieves (and the gt seems to stem from the time when the uniform random number was in fact in (0,1]).

@rfejgin commented Aug 2, 2018

Ah, I see the logic now. Makes sense.

@rfejgin commented Aug 3, 2018

I just wanted to mention that my availability will be very limited in the next two weeks. But it looks like we're pretty close to resolving this -- nice :).

@yf225 (Contributor) commented Aug 14, 2018

@rfejgin @t-vi What should be the next step for this PR?

@t-vi (Collaborator, Author) commented Aug 15, 2018 via email

@facebook-github-bot (Contributor) left a comment

ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ailzhang self-assigned this Aug 15, 2018
zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 15, 2018
Summary:
In the shortcut for n_sample=1, when category 0 has 0 weight,
we should not map the (uniform) sample 0 to category 0.
The conversion uniform->multinomial was apparently written to work on
a (0,1] range (like curand uses), but PyTorch uses a [0,1) range.

Fixes: #4858. Thank you, Roy Fejgin for reporting.
Pull Request resolved: pytorch/pytorch#9960

Reviewed By: soumith

Differential Revision: D9341793

Pulled By: ailzhang

fbshipit-source-id: 6b1a96419a7bc58cc594f761f34c6408ff6354cf