
Commit 06762b4

fritzo authored and facebook-github-bot committed
Fix distributions.Categorical.sample bug from .view() (#23328)
Summary: This modernizes distributions code by replacing a few uses of `.contiguous().view()` with `.reshape()`, fixing a sample bug in the `Categorical` distribution. The bug is exercised by the following test:

```py
batch_shape = (1, 2, 1, 3, 1)
sample_shape = (4,)
cardinality = 2
logits = torch.randn(batch_shape + (cardinality,))
dist.Categorical(logits=logits).sample(sample_shape)
# RuntimeError: invalid argument 2: view size is not compatible with
# input tensor's size and stride (at least one dimension spans across
# two contiguous subspaces). Call .contiguous() before .view().
# at ../aten/src/TH/generic/THTensor.cpp:203
```

I have verified this works locally, but I have not added this as a regression test because it is unlikely to regress (the code is now simpler).

Pull Request resolved: #23328
Differential Revision: D16510678
Pulled By: colesbury
fbshipit-source-id: c125c1a37d21d185132e8e8b65241c86ad8ad04b
1 parent be644d8 commit 06762b4
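The `RuntimeError` quoted in the commit message comes from calling `.view()` on a non-contiguous tensor. `.reshape()` behaves like `.view()` when the strides allow a zero-copy reinterpretation and silently copies otherwise, which is why it is the safer call here. A minimal sketch of the difference (the tensor shapes below are illustrative, not from the patch):

```python
import torch

# expand() produces a non-contiguous tensor (stride 0 along the
# broadcast dimension), so view() cannot reinterpret its memory.
t = torch.randn(1, 2, 1, 3, 1, 2).expand(4, 2, 1, 3, 1, 2)
assert not t.is_contiguous()

try:
    t.view(-1, 2)
except RuntimeError:
    print("view() failed on the non-contiguous input")

# reshape() succeeds: it copies when a zero-copy view is impossible.
flat = t.reshape(-1, 2)
print(flat.shape)  # torch.Size([24, 2])
```

This is exactly the situation `Categorical.sample` hits: `self.probs.expand(param_shape)` returns a non-contiguous tensor whenever the sample shape broadcasts over the batch shape.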

File tree

2 files changed: +3 -6 lines changed


torch/distributions/categorical.py

Lines changed: 2 additions & 5 deletions

```diff
@@ -103,12 +103,9 @@ def sample(self, sample_shape=torch.Size()):
         sample_shape = self._extended_shape(sample_shape)
         param_shape = sample_shape + torch.Size((self._num_events,))
         probs = self.probs.expand(param_shape)
-        if self.probs.dim() == 1 or self.probs.size(0) == 1:
-            probs_2d = probs.view(-1, self._num_events)
-        else:
-            probs_2d = probs.contiguous().view(-1, self._num_events)
+        probs_2d = probs.reshape(-1, self._num_events)
         sample_2d = torch.multinomial(probs_2d, 1, True)
-        return sample_2d.contiguous().view(sample_shape)
+        return sample_2d.reshape(sample_shape)

     def log_prob(self, value):
         if self._validate_args:
```
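With the branching removed, `sample()` handles arbitrary batch shapes uniformly. The repro from the commit message runs cleanly on a PyTorch build that includes this fix:

```python
import torch
import torch.distributions as dist

batch_shape = (1, 2, 1, 3, 1)
sample_shape = (4,)
cardinality = 2
logits = torch.randn(batch_shape + (cardinality,))

# Previously raised "view size is not compatible with input tensor's
# size and stride"; reshape() handles the non-contiguous case.
s = dist.Categorical(logits=logits).sample(sample_shape)
print(s.shape)  # torch.Size([4, 1, 2, 1, 3, 1])
```

The result shape is `sample_shape + batch_shape`, as the `Distribution.sample` contract requires.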

torch/distributions/transforms.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -555,7 +555,7 @@ def _call(self, x):
         return torch.stack([self._call_on_event(flat_x[i]) for i in range(flat_x.size(0))]).view(x.shape)

     def _inverse(self, y):
-        flat_y = y.contiguous().view((-1,) + y.shape[-2:])
+        flat_y = y.reshape((-1,) + y.shape[-2:])
         return torch.stack([self._inverse_on_event(flat_y[i]) for i in range(flat_y.size(0))]).view(y.shape)

```
