
Commit cc6b046

li-roy authored and soumith committed

Implement flatten function (#8578)

* Implement flatten function
* address comments
* allow start_dim=end_dim
* undo submodule change
1 parent 065fdbd commit cc6b046

File tree

7 files changed: +97 −5 lines changed

aten/src/ATen/native/TensorShape.cpp

Lines changed: 22 additions & 0 deletions
@@ -600,6 +600,28 @@ Tensor & unsqueeze_(Tensor& self, int64_t dim) {
   return self.as_strided_(std::get<0>(g), std::get<1>(g));
 }
 
+Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
+  start_dim = maybe_wrap_dim(start_dim, self.dim());
+  end_dim = maybe_wrap_dim(end_dim, self.dim());
+  AT_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");
+
+  if (start_dim == end_dim) {
+    return self;
+  }
+
+  std::vector<int64_t> shape;
+  shape.reserve(self.dim() - end_dim + start_dim);
+  for (int64_t i = 0; i < start_dim; i++) {
+    shape.push_back(self.size(i));
+  }
+  shape.push_back(-1);
+  for (int64_t i = end_dim + 1; i < self.dim(); i++) {
+    shape.push_back(self.size(i));
+  }
+
+  return self.reshape(shape);
+}
+
 Tensor view_as(const Tensor& self, const Tensor& other) {
   return self.view(other.sizes());
 }
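
The new kernel keeps every dim before start_dim and after end_dim, and collapses the range in between into a single -1 entry that reshape then infers from the element count. A minimal Python sketch of the same shape computation (flatten_shape is a hypothetical helper, not part of this commit):

def flatten_shape(sizes, start_dim, end_dim):
    # Hypothetical mirror of the C++ loop above: dims outside
    # [start_dim, end_dim] survive unchanged; the flattened range
    # becomes a single -1 that reshape infers from the numel.
    return list(sizes[:start_dim]) + [-1] + list(sizes[end_dim + 1:])

# e.g. a (5, 5, 5, 5) tensor flattened over dims 1..2: [5, -1, 5] -> (5, 25, 5)
assert flatten_shape((5, 5, 5, 5), 1, 2) == [5, -1, 5]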

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -556,6 +556,8 @@
   variants: function
   deprecated: true
 
+- func: flatten(Tensor self, int64_t start_dim=0, int64_t end_dim=-1) -> Tensor
+
 - func: fill_(Tensor self, Scalar value) -> Tensor
 
 - func: fill_(Tensor self, Tensor value) -> Tensor
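
With both defaults in the schema, the declaration exposes flatten as a free function and as a Tensor method (the tests and docstrings below use both forms). A quick usage sketch of what the defaults imply:

import torch

x = torch.randn(2, 3, 4)
print(torch.flatten(x).shape)        # torch.Size([24]): start_dim=0, end_dim=-1
print(x.flatten(start_dim=1).shape)  # torch.Size([2, 12]): method form, same kernel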

test/test_torch.py

Lines changed: 37 additions & 0 deletions
@@ -5429,6 +5429,43 @@ def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
             ii[dim] = slice(0, idx.size(dim) + 1)
             idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
 
+    def test_flatten(self):
+        src = torch.randn(5, 5, 5, 5)
+        flat = src.flatten(0, -1)
+        self.assertEqual(flat.shape, torch.Size([625]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(0, 2)
+        self.assertEqual(flat.shape, torch.Size([125, 5]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(0, 1)
+        self.assertEqual(flat.shape, torch.Size([25, 5, 5]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(1, 2)
+        self.assertEqual(flat.shape, torch.Size([5, 25, 5]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(2, 3)
+        self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(-2, -1)
+        self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(2, 2)
+        self.assertEqual(flat, src)
+
+        # out of bounds index
+        with self.assertRaisesRegex(RuntimeError, 'dimension out of range'):
+            src.flatten(5, 10)
+
+        # invalid start and end
+        with self.assertRaisesRegex(RuntimeError, 'start_dim cannot come after end_dim'):
+            src.flatten(2, 0)
+
     @staticmethod
     def _test_gather(self, cast, test_bounds=True):
         m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
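
The negative-dim cases above go through ATen's maybe_wrap_dim, so on the 4-d src, flatten(-2, -1) is the same call as flatten(2, 3), and flatten(5, 10) fails before any shape work. A small sketch of the wrapping rule (wrap_dim is a hypothetical stand-in for maybe_wrap_dim):

def wrap_dim(dim, ndim):
    # Hypothetical helper: negative dims count from the end; anything
    # outside [-ndim, ndim) errors, matching the test's
    # 'dimension out of range' expectation.
    if not -ndim <= dim < ndim:
        raise RuntimeError('dimension out of range')
    return dim + ndim if dim < 0 else dim

assert (wrap_dim(-2, 4), wrap_dim(-1, 4)) == (2, 3)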

tools/autograd/gen_variable_type.py

Lines changed: 4 additions & 4 deletions (the codegen's tracer helper flatten is renamed to flatten_tensor_args, presumably to avoid clashing with the newly added flatten operator)

@@ -364,14 +364,14 @@ def reference_args(args):
 
 def get_trace_outputs(declaration):
     if declaration['return_type'] == 'std::vector<Tensor>':
-        return 'flatten({})'.format(declaration['returns'][0]['name'])
+        return 'flatten_tensor_args({})'.format(declaration['returns'][0]['name'])
     elif name.endswith('_out'):
         output_args = [arg['name'] for arg in arguments
                        if arg.get('output', False)]
         return '{' + ', '.join(output_args) + '}'
     trace_outs = [r['name'] for r in declaration['returns']]
     if any(ret['dynamic_type'] == 'TensorList' for ret in declaration['returns']):
-        return CodeTemplate("flatten( ${outs} )").substitute(outs=trace_outs)
+        return CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=trace_outs)
     else:
         return CodeTemplate("{ ${outs} }").substitute(outs=trace_outs)

@@ -408,7 +408,7 @@ def emit_record_trace(env):
     local['tensor_args'] = [arg['name'] for arg in tensor_args]
     if any(arg['simple_type'] == 'TensorList' for arg in tensor_args):
         # Allocate a temporary vector with flatten and pass it in
-        local['trace_inputs'] = CodeTemplate("flatten( $tensor_args )").substitute(local)
+        local['trace_inputs'] = CodeTemplate("flatten_tensor_args( $tensor_args )").substitute(local)
     else:
         local['trace_inputs'] = CodeTemplate("{ ${tensor_args} }").substitute(local)

@@ -496,7 +496,7 @@ def emit_history():
     fn = 'rebase' if modifies_arguments and not is_view else 'set'
     output_names = [r['name'] for r in differentiable_outputs]
     # TODO: flatten allocates a std::vector, which could be expensive
-    outs = CodeTemplate("flatten( ${outs} )").substitute(outs=output_names)
+    outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=output_names)
     return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)

tools/autograd/templates/VariableType.cpp

Lines changed: 1 addition & 1 deletion
@@ -406,7 +406,7 @@ struct Flatten : IterArgs<Flatten> {
   }
 };
 
-template<typename... Args> inline variable_list flatten(Args&&... args) {
+template<typename... Args> inline variable_list flatten_tensor_args(Args&&... args) {
   variable_list out;
   out.reserve(count_tensors(std::forward<Args>(args)...));
   Flatten(out).apply(std::forward<Args>(args)...);
torch/_tensor_docs.py

Lines changed: 7 additions & 0 deletions
@@ -810,6 +810,13 @@ def add_docstr_all(method, docstr):
 In-place version of :meth:`~Tensor.frac`
 """)
 
+add_docstr_all('flatten',
+               r"""
+flatten(input, start_dim=0, end_dim=-1) -> Tensor
+
+see :func:`torch.flatten`
+""")
+
 add_docstr_all('gather',
                r"""
 gather(dim, index) -> Tensor

torch/_torch_docs.py

Lines changed: 24 additions & 0 deletions
@@ -1591,6 +1591,30 @@ def parse_kwargs(desc):
 array([-1, 2, 3])
 """)
 
+add_docstr(torch.flatten,
+           r"""
+flatten(input, start_dim=0, end_dim=-1) -> Tensor
+
+Flattens a contiguous range of dims in a tensor.
+
+Args:
+    input (Tensor): the input tensor
+    start_dim (int): the first dim to flatten
+    end_dim (int): the last dim to flatten
+
+Example::
+
+    >>> t = torch.tensor([[[1, 2],
+                           [3, 4]],
+                          [[5, 6],
+                           [7, 8]]])
+    >>> torch.flatten(t)
+    tensor([1, 2, 3, 4, 5, 6, 7, 8])
+    >>> torch.flatten(t, start_dim=1)
+    tensor([[1, 2, 3, 4],
+            [5, 6, 7, 8]])
+""")
+
 add_docstr(torch.gather,
            r"""
 gather(input, dim, index, out=None) -> Tensor
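
For completeness, one more worked example combining both arguments (not part of the docstring above, but it follows directly from the semantics): flattening dims 0 through 1 of the same 2x2x2 tensor merges the first two dims into one of size 4.

>>> torch.flatten(t, start_dim=0, end_dim=1)
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])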
