@@ -5429,6 +5429,43 @@ def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
54295429 ii [dim ] = slice (0 , idx .size (dim ) + 1 )
54305430 idx [tuple (ii )] = torch .randperm (dim_size )[0 :elems_per_row ]
54315431
5432+ def test_flatten (self ):
5433+ src = torch .randn (5 , 5 , 5 , 5 )
5434+ flat = src .flatten (0 , - 1 )
5435+ self .assertEqual (flat .shape , torch .Size ([625 ]))
5436+ self .assertEqual (src .view (- 1 ), flat .view (- 1 ))
5437+
5438+ flat = src .flatten (0 , 2 )
5439+ self .assertEqual (flat .shape , torch .Size ([125 , 5 ]))
5440+ self .assertEqual (src .view (- 1 ), flat .view (- 1 ))
5441+
5442+ flat = src .flatten (0 , 1 )
5443+ self .assertEqual (flat .shape , torch .Size ([25 , 5 , 5 ]))
5444+ self .assertEqual (src .view (- 1 ), flat .view (- 1 ))
5445+
5446+ flat = src .flatten (1 , 2 )
5447+ self .assertEqual (flat .shape , torch .Size ([5 , 25 , 5 ]))
5448+ self .assertEqual (src .view (- 1 ), flat .view (- 1 ))
5449+
5450+ flat = src .flatten (2 , 3 )
5451+ self .assertEqual (flat .shape , torch .Size ([5 , 5 , 25 ]))
5452+ self .assertEqual (src .view (- 1 ), flat .view (- 1 ))
5453+
5454+ flat = src .flatten (- 2 , - 1 )
5455+ self .assertEqual (flat .shape , torch .Size ([5 , 5 , 25 ]))
5456+ self .assertEqual (src .view (- 1 ), flat .view (- 1 ))
5457+
5458+ flat = src .flatten (2 , 2 )
5459+ self .assertEqual (flat , src )
5460+
5461+ # out of bounds index
5462+ with self .assertRaisesRegex (RuntimeError , 'dimension out of range' ):
5463+ src .flatten (5 , 10 )
5464+
5465+ # invalid start and end
5466+ with self .assertRaisesRegex (RuntimeError , 'start_dim cannot come after end_dim' ):
5467+ src .flatten (2 , 0 )
5468+
54325469 @staticmethod
54335470 def _test_gather (self , cast , test_bounds = True ):
54345471 m , n , o = random .randint (10 , 20 ), random .randint (10 , 20 ), random .randint (10 , 20 )
0 commit comments