Skip to content

Conversation

@ailzhang
Copy link
Contributor

@ailzhang ailzhang commented Oct 9, 2018

Fixes #12260 #2896

torch.multinomial(torch.FloatTensor([0, 1, 0, 0]), 3, replacement=False)

The old behavior is that we return 0 after we run out of postive categories. Now we raise an error based on discussion in the issue thread.

  • Add testcase for cpu & cuda case, in cuda case n_samples=1 is a simple special case, so we test against n_sample=2 instead.

2,
"invalid multinomial distribution (sum of probabilities <= 0)");
THArgCheckWithCleanup((n_categories - n_zeros >= n_sample),
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ailzhang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 11, 2018
Summary:
Fixes #12260 #2896

```
torch.multinomial(torch.FloatTensor([0, 1, 0, 0]), 3, replacement=False)
```
The old behavior is that we return `0` after we run out of postive categories. Now we raise an error based on discussion in the issue thread.

- Add testcase for cpu & cuda case, in cuda case `n_samples=1` is a simple special case, so we test against `n_sample=2` instead.
Pull Request resolved: pytorch/pytorch#12490

Differential Revision: D10278794

Pulled By: ailzhang

fbshipit-source-id: d04de7a60f60d0c0d648b975db3f3961fcf42db1
@matteorr
Copy link

The documentation at https://pytorch.org/docs/master/torch.html#torch.multinomial does not reflect the change in behavior. The example should be updated to reflect versions 1.x.

ezyang pushed a commit to ezyang/pytorch that referenced this pull request Feb 19, 2019
Summary:
Update documentation to raise awareness of the fix in pytorch#12490. Thanks matteorr for pointing this out!
Pull Request resolved: pytorch#17269

Reviewed By: ezyang

Differential Revision: D14138421

Pulled By: ailzhang

fbshipit-source-id: 6433f9807a6ba1d871eba8e9d37aa6b78fa1e1fd
facebook-github-bot pushed a commit that referenced this pull request Feb 19, 2019
Summary:
Update documentation to raise awareness of the fix in #12490. Thanks matteorr for pointing this out!
Pull Request resolved: #17269

Reviewed By: ezyang

Differential Revision: D14138421

Pulled By: ailzhang

fbshipit-source-id: 6433f9807a6ba1d871eba8e9d37aa6b78fa1e1fd
@ezyang ezyang added the merged label Jun 25, 2019
@asanakoy
Copy link
Contributor

I think it makes sens to print a meaningful error message. I just get RuntimeError: CUDA error: device-side assert triggered

pytorch version 1.1.0

Code to reproduce:

(py36) ~  ipython      
Python 3.6.7 |Anaconda, Inc.| (default, Oct 23 2018, 19:16:44) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.2.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch                                                                                                                                                                     

In [2]: weights = torch.zeros(5, 6); weights[:, 3] = 1; weights = weights.cuda()                                                                                                         

In [3]: torch.multinomial(weights, num_samples=1, replacement=False)                                                                                                                     
Out[3]: 
tensor([[3],
        [3],
        [3],
        [3],
        [3]], device='cuda:0')

In [4]: torch.multinomial(weights, num_samples=2, replacement=False)                                                                                                                     
Out[4]: ---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/anaconda2/envs/py36/lib/python3.6/site-packages/IPython/core/formatters.py in __call__(self, obj)
    700                 type_pprinters=self.type_printers,
    701                 deferred_pprinters=self.deferred_printers)
--> 702             printer.pretty(obj)
    703             printer.flush()
    704             return stream.getvalue()

~/anaconda2/envs/py36/lib/python3.6/site-packages/IPython/lib/pretty.py in pretty(self, obj)
    400                         if cls is not object \
    401                                 and callable(cls.__dict__.get('__repr__')):
--> 402                             return _repr_pprint(obj, self, cycle)
    403 
    404             return _default_pprint(obj, self, cycle)

~/anaconda2/envs/py36/lib/python3.6/site-packages/IPython/lib/pretty.py in _repr_pprint(obj, p, cycle)
    695     """A pprint that just redirects to the normal repr function."""
    696     # Find newlines and replace them with p.break_()
--> 697     output = repr(obj)
    698     for idx,output_line in enumerate(output.splitlines()):
    699         if idx:

~/anaconda2/envs/py36/lib/python3.6/site-packages/torch/tensor.py in __repr__(self)
     69         # characters to replace unicode characters with.
     70         if sys.version_info > (3,):
---> 71             return torch._tensor_str._str(self)
     72         else:
     73             if hasattr(sys.stdout, 'encoding'):

~/anaconda2/envs/py36/lib/python3.6/site-packages/torch/_tensor_str.py in _str(self)
    284                 tensor_str = _tensor_str(self.to_dense(), indent)
    285             else:
--> 286                 tensor_str = _tensor_str(self, indent)
    287 
    288     if self.layout != torch.strided:

~/anaconda2/envs/py36/lib/python3.6/site-packages/torch/_tensor_str.py in _tensor_str(self, indent)
    199     if self.dtype is torch.float16:
    200         self = self.float()
--> 201     formatter = _Formatter(get_summarized_data(self) if summarize else self)
    202     return _tensor_str_with_formatter(self, indent, formatter, summarize)
    203 

~/anaconda2/envs/py36/lib/python3.6/site-packages/torch/_tensor_str.py in __init__(self, tensor)
     81         if not self.floating_dtype:
     82             for value in tensor_view:
---> 83                 value_str = '{}'.format(value)
     84                 self.max_width = max(self.max_width, len(value_str))
     85 

~/anaconda2/envs/py36/lib/python3.6/site-packages/torch/tensor.py in __format__(self, format_spec)
    384     def __format__(self, format_spec):
    385         if self.dim() == 0:
--> 386             return self.item().__format__(format_spec)
    387         return object.__format__(self, format_spec)
    388 

RuntimeError: CUDA error: device-side assert triggered

@soumith
Copy link
Contributor

soumith commented Jul 22, 2019

@asanakoy due to limitations with CUDA, we cannot print a better error message

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.multinomial without replacement returns repetitive values when all non-zero items are exhausted

6 participants