Skip to content

Commit 0670a82

Browse files
DenisVieriu97pytorchmergebot
authored andcommitted
Register unfold key for MPS (#134)
1 parent 55749b9 commit 0670a82

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9439,7 +9439,7 @@
94399439
device_check: NoCheck
94409440
device_guard: False
94419441
dispatch:
9442-
CPU, CUDA, Meta: unfold
9442+
CPU, CUDA, Meta, MPS: unfold
94439443
QuantizedCPU, QuantizedCUDA: unfold
94449444

94459445
- func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor

test/test_mps.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,7 +2039,29 @@ def test_as_strided(self):
20392039
strided_mps_out = strided_mps1 - strided_mps2
20402040
self.assertEqual(strided_cpu_out, strided_mps_out)
20412041

2042+
def test_unfold(self):
2043+
x = torch.arange(1., 8)
2044+
x_mps = torch.arange(1., 8, device="mps")
20422045

2046+
y = x.unfold(0, 2, 1)
2047+
y_mps = x_mps.unfold(0, 2, 1)
2048+
2049+
self.assertEqual(y, y_mps)
2050+
2051+
def test_unfold_all_devices_and_dtypes(self):
2052+
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
2053+
for dt in supported_dtypes:
2054+
x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
2055+
self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
2056+
2057+
def test_unfold_scalars(self):
2058+
x = torch.tensor(0.5, device="mps")
2059+
# unfold on a 0-dimensional tensor should always return a 1-d dimensional
2060+
# tensor of shape [size] (i.e., the second parameter to unfold)
2061+
2062+
self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
2063+
self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
2064+
self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))
20432065

20442066
def test_sum_backward(self):
20452067
def helper(n, c):
@@ -5726,14 +5748,13 @@ def test_T_view(self, device="mps"):
57265748
v[0, 1] = 0
57275749
self.assertEqual(t[1, 0], v[0, 1])
57285750

5729-
# requires aten::unfold
5730-
# def test_unfold_view(self, device="mps"):
5731-
# t = torch.ones(10, device=device)
5732-
# v = t.unfold(0, 3, 2)
5733-
# self.assertTrue(self.is_view_of(t, v))
5751+
def test_unfold_view(self, device="mps"):
5752+
t = torch.ones(10, device=device)
5753+
v = t.unfold(0, 3, 2)
5754+
self.assertTrue(self.is_view_of(t, v))
57345755

5735-
# v[1, 0] = 0
5736-
# self.assertEqual(t[2], v[1, 0])
5756+
v[1, 0] = 0
5757+
self.assertEqual(t[2], v[1, 0])
57375758

57385759
def test_squeeze_view(self, device="mps"):
57395760
t = torch.ones(5, 1, 5, device=device)

0 commit comments

Comments
 (0)