Skip to content

Commit 136dadd

Browse files
XiaobingSuperpytorchmergebot
authored andcommitted
fix norrow_copy correctness issue for non-contiguous input for cpu path (#91789)
Fix #91690. Pull Request resolved: #91789 Approved by: https://github.com/jgong5, https://github.com/lezcano
1 parent 8cec433 commit 136dadd

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

aten/src/ATen/native/TensorShape.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1217,7 +1217,7 @@ Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t
12171217
// Should just use narrow_copy_out, but this API is used internally at Meta:
12181218
// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
12191219
Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
1220-
auto output = at::empty_like(self);
1220+
auto output = at::empty({0}, self.options());
12211221
return narrow_copy_dense_cpu_out(self, dim, start, length, output);
12221222
}
12231223

test/test_torch.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2971,6 +2971,13 @@ def test_narrow_empty(self, device):
29712971
sz[d] = 0
29722972
self.assertEqual(sz, y.size())
29732973

2974+
def test_narrow_copy_non_contiguous(self, device):
2975+
# see https://github.com/pytorch/pytorch/issues/91690.
2976+
inp = torch.randn(10, 2, device=device).movedim(-1, 0)
2977+
expected = torch.narrow_copy(inp.contiguous(), 1, 0, 10)
2978+
actual = torch.narrow_copy(inp, 1, 0, 10)
2979+
self.assertEqual(expected, actual)
2980+
29742981
# FIXME: move to indexing test suite
29752982
@parametrize("reduce", ['prod', 'amin', 'amax', 'mean'])
29762983
@dtypes(*all_types_and(torch.half, torch.bfloat16))

0 commit comments

Comments
 (0)