@@ -2039,7 +2039,29 @@ def test_as_strided(self):
20392039 strided_mps_out = strided_mps1 - strided_mps2
20402040 self .assertEqual (strided_cpu_out , strided_mps_out )
20412041
2042+ def test_unfold (self ):
2043+ x = torch .arange (1. , 8 )
2044+ x_mps = torch .arange (1. , 8 , device = "mps" )
20422045
2046+ y = x .unfold (0 , 2 , 1 )
2047+ y_mps = x_mps .unfold (0 , 2 , 1 )
2048+
2049+ self .assertEqual (y , y_mps )
2050+
2051+ def test_unfold_all_devices_and_dtypes (self ):
2052+ supported_dtypes = [torch .float32 , torch .float16 , torch .int64 , torch .int32 , torch .int16 , torch .uint8 ]
2053+ for dt in supported_dtypes :
2054+ x = torch .empty ((0 , 1 , 3 , 0 ), dtype = dt , device = "mps" )
2055+ self .assertEqual ((0 , 1 , 1 , 0 , 3 ), x .unfold (2 , 3 , 2 ).shape )
2056+
2057+ def test_unfold_scalars (self ):
2058+ x = torch .tensor (0.5 , device = "mps" )
2059+ # unfold on a 0-dimensional tensor should always return a 1-d dimensional
2060+ # tensor of shape [size] (i.e., the second parameter to unfold)
2061+
2062+ self .assertEqual (torch .empty (0 , device = "mps" ), x .unfold (0 , 0 , 1 ))
2063+ self .assertEqual (torch .empty (0 , device = "mps" ), x .unfold (0 , 0 , 2 ))
2064+ self .assertEqual (torch .tensor ([0.5 ], device = "mps" ), x .unfold (0 , 1 , 1 ))
20432065
20442066 def test_sum_backward (self ):
20452067 def helper (n , c ):
@@ -5726,14 +5748,13 @@ def test_T_view(self, device="mps"):
57265748 v [0 , 1 ] = 0
57275749 self .assertEqual (t [1 , 0 ], v [0 , 1 ])
57285750
5729- # requires aten::unfold
5730- # def test_unfold_view(self, device="mps"):
5731- # t = torch.ones(10, device=device)
5732- # v = t.unfold(0, 3, 2)
5733- # self.assertTrue(self.is_view_of(t, v))
5751+ def test_unfold_view (self , device = "mps" ):
5752+ t = torch .ones (10 , device = device )
5753+ v = t .unfold (0 , 3 , 2 )
5754+ self .assertTrue (self .is_view_of (t , v ))
57345755
5735- # v[1, 0] = 0
5736- # self.assertEqual(t[2], v[1, 0])
5756+ v [1 , 0 ] = 0
5757+ self .assertEqual (t [2 ], v [1 , 0 ])
57375758
57385759 def test_squeeze_view (self , device = "mps" ):
57395760 t = torch .ones (5 , 1 , 5 , device = device )
0 commit comments