Commit 06762b4
Fix distributions.Categorical.sample bug from .view() (#23328)
Summary:
This modernizes distributions code by replacing a few uses of `.contiguous().view()` with `.reshape()`, fixing a sample bug in the `Categorical` distribution.
The bug is exercised by the following test:
```py
batch_shape = (1, 2, 1, 3, 1)
sample_shape = (4,)
cardinality = 2
logits = torch.randn(batch_shape + (cardinality,))
dist.Categorical(logits=logits).sample(sample_shape)
# RuntimeError: invalid argument 2: view size is not compatible with
# input tensor's size and stride (at least one dimension spans across
# two contiguous subspaces). Call .contiguous() before .view().
# at ../aten/src/TH/generic/THTensor.cpp:203
```
I have verified this works locally, but I have not added this as a regression test because it is unlikely to regress (the code is now simpler).
Pull Request resolved: #23328
Differential Revision: D16510678
Pulled By: colesbury
fbshipit-source-id: c125c1a37d21d185132e8e8b65241c86ad8ad04b1 parent be644d8 commit 06762b4
2 files changed
+3
-6
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
103 | 103 | | |
104 | 104 | | |
105 | 105 | | |
106 | | - | |
107 | | - | |
108 | | - | |
109 | | - | |
| 106 | + | |
110 | 107 | | |
111 | | - | |
| 108 | + | |
112 | 109 | | |
113 | 110 | | |
114 | 111 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
555 | 555 | | |
556 | 556 | | |
557 | 557 | | |
558 | | - | |
| 558 | + | |
559 | 559 | | |
560 | 560 | | |
561 | 561 | | |
| |||
0 commit comments