
Conversation

@ssnl (Collaborator) commented Feb 6, 2018

The CPU code previously used resize2d, which does not safely preserve data when the input is not contiguous. This PR changes both the CPU and GPU functions to use unsqueeze1d, and adds strided support to the two functions. As a result, distributions.Categorical no longer needs to call contiguous() when the probabilities are not batched.

Also in this PR are

  1. updated tests for torch.multinomial,
  2. updated docs for distributions.OneHotCategorical and distributions.Categorical, and
  3. shape tests for distributions.OneHotCategorical and distributions.Categorical in unbatched mode.
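The core issue can be sketched without PyTorch. The following NumPy illustration (not the PR's code; `expand_dims` stands in for unsqueeze1d, and slicing the raw buffer stands in for a resize-style storage reinterpretation) shows why resizing the storage of a non-contiguous 1-D view loses the view's values, while adding a leading dimension preserves them:

```python
import numpy as np

# A non-contiguous 1-D view: every other element of a buffer.
buf = np.arange(6)            # storage: [0 1 2 3 4 5]
v = buf[::2]                  # view: [0 2 4], stride 2 -> not contiguous
assert not v.flags['C_CONTIGUOUS']

# A resize2d-style call reinterprets the raw storage, so it reads the
# first v.size elements of the buffer, not the view's values:
raw = buf[:v.size].reshape(1, 3)   # [[0 1 2]] -- data NOT preserved

# Adding a leading size-1 dim (the unsqueeze analogue) keeps the view's
# strides, so the values survive:
ok = np.expand_dims(v, 0)          # [[0 2 4]] -- data preserved
```

This is why the unbatched (1-D) probabilities path needed either a contiguous() copy before resizing, or, as in this PR, an unsqueeze that respects strides.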

Fixes #5062

cc @neerajprad @alicanb for distributions.* changes review.

@soumith (Contributor) left a comment


Minor typo changes needed.

Code looks good and correct.

TH_API void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size);
TH_API void THTensor_(expandNd)(THTensor **rets, THTensor **ops, int count);

// resize* methods simply resize the storage, so they may not retain the current data at the current indices.



// resize* methods simply resize the storage, so they may not retain the current data at the current
// indices. This is especially likely to happen when the tensor is not contiguous. In general, if you
// still need the values, do not use resize* unless you are doing size and stride tricks.


THC_API void THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes);
THC_API void THCTensor_(expandNd)(THCState *state, THCTensor **rets, THCTensor **ops, int count);

// resize* methods simply resize the storage, so they may not retain the current data at the current indices.


@neerajprad (Contributor) left a comment

Thanks for the doc fixes and for adding the unbatched-params tests. Looks great!

@soumith soumith merged commit 47ee867 into pytorch:master Feb 7, 2018
@ssnl ssnl deleted the distns branch February 7, 2018 03:12


Successfully merging this pull request may close these issues.

Potential bug when sampling from categorical distribution
