
Commit d1590d3

Author: Roy Li
Implement flatten function
1 parent 0a5fe55 commit d1590d3

9 files changed: 81 additions, 6 deletions


aten/src/ATen/native/TensorShape.cpp

Lines changed: 17 additions & 0 deletions
@@ -600,6 +600,23 @@ Tensor & unsqueeze_(Tensor& self, int64_t dim) {
   return self.as_strided_(std::get<0>(g), std::get<1>(g));
 }
 
+Tensor flatten(const Tensor& self, int64_t start, int64_t end) {
+  start = maybe_wrap_dim(start, self.dim());
+  end = maybe_wrap_dim(end, self.dim());
+  AT_CHECK(start < end, "start dim must be before end dim");
+
+  std::vector<int64_t> shape;
+  for (int i = 0; i < start; i++) {
+    shape.push_back(self.sizes()[i]);
+  }
+  shape.push_back(-1);
+  for (int i = end + 1; i < self.dim(); i++) {
+    shape.push_back(self.sizes()[i]);
+  }
+
+  return self.view(IntList(shape));
+}
+
 Tensor view_as(const Tensor& self, const Tensor& other) {
   return self.view(other.sizes());
 }
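For readers skimming the C++: the new function keeps the dims before `start`, collapses the `[start, end]` range into a single inferred (-1) dim, and keeps the dims after `end`, then asks `view` to resolve the -1. A rough Python sketch of the same shape computation (the helper name and snippet are illustrative, not part of the commit):

import torch

def flatten_sketch(t, start, end):
    # Illustrative mirror of the C++ logic above (not the actual binding):
    # keep dims before `start`, collapse [start, end] into one inferred dim,
    # keep dims after `end`.
    start = start % t.dim()   # rough stand-in for maybe_wrap_dim
    end = end % t.dim()
    assert start < end, "start dim must be before end dim"
    shape = list(t.shape[:start]) + [-1] + list(t.shape[end + 1:])
    return t.view(shape)

src = torch.randn(5, 5, 5, 5)
print(flatten_sketch(src, 1, 2).shape)   # torch.Size([5, 25, 5])
print(flatten_sketch(src, 0, -1).shape)  # torch.Size([625])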

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -556,6 +556,8 @@
   variants: function
   deprecated: true
 
+- func: flatten(Tensor self, int64_t start, int64_t end) -> Tensor
+
 - func: fill_(Tensor self, Scalar value) -> Tensor
 
 - func: fill_(Tensor self, Tensor value) -> Tensor
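Together with the docstrings added further down, this declaration exposes the op both as a torch-level function and as a Tensor method. A quick interactive check, assuming a build that includes this commit:

import torch

t = torch.randn(2, 3, 4)
# Both spellings should produce the same [6, 4] view.
print(torch.flatten(t, 0, 1).shape)  # torch.Size([6, 4])
print(t.flatten(0, 1).shape)         # torch.Size([6, 4])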

test/test_torch.py

Lines changed: 34 additions & 0 deletions
@@ -5429,6 +5429,40 @@ def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
                     ii[dim] = slice(0, idx.size(dim) + 1)
                     idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
 
+    def test_flatten(self):
+        src = torch.randn(5,5,5,5)
+        flat = src.flatten(0, -1)
+        self.assertEqual(flat.shape, torch.Size([625]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(0, 2)
+        self.assertEqual(flat.shape, torch.Size([125, 5]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(0, 1)
+        self.assertEqual(flat.shape, torch.Size([25, 5, 5]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(1, 2)
+        self.assertEqual(flat.shape, torch.Size([5, 25, 5]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(2, 3)
+        self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        flat = src.flatten(-2, -1)
+        self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
+        self.assertEqual(src.view(-1), flat.view(-1))
+
+        # out of bounds index
+        with self.assertRaisesRegex(RuntimeError, 'dimension out of range'):
+            src.flatten(5, 10)
+
+        # invalid start and end
+        with self.assertRaisesRegex(RuntimeError, 'start dim must be before end dim'):
+            src.flatten(2, 0)
+
     @staticmethod
     def _test_gather(self, cast, test_bounds=True):
         m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
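One case the new tests do not exercise: because the AT_CHECK in TensorShape.cpp requires start to be strictly less than end, a no-op range such as flatten(2, 2) is also rejected in this version. A quick illustration, assuming a build with this commit:

import torch

src = torch.randn(5, 5, 5, 5)
try:
    src.flatten(2, 2)  # start == end fails the strict `start < end` check
except RuntimeError as e:
    print(e)  # start dim must be before end dim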

third_party/onnx

Submodule onnx updated 47 files

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 0 deletions
@@ -265,6 +265,9 @@
   self: zeros_like(grad)
   value: grad.sum()
 
+- name: flatten(Tensor self, int64_t start, int64_t end)
+  self: grad.view(self.sizes())
+
 - name: floor(Tensor self)
   self: zeros_like(grad)
 
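Since flatten only produces a reshaped view, the backward simply reshapes the incoming gradient back to the input's sizes (the grad.view(self.sizes()) entry above). A small autograd sanity check, assuming a build with this commit:

import torch

x = torch.randn(3, 4, 5, requires_grad=True)
y = x.flatten(0, 1)              # shape [12, 5]
y.backward(torch.ones_like(y))   # gradient flows back through the view
print(x.grad.shape)              # torch.Size([3, 4, 5])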

tools/autograd/gen_variable_type.py

Lines changed: 4 additions & 4 deletions
@@ -364,14 +364,14 @@ def reference_args(args):
 
 def get_trace_outputs(declaration):
     if declaration['return_type'] == 'std::vector<Tensor>':
-        return 'flatten({})'.format(declaration['returns'][0]['name'])
+        return 'flatten_tensor({})'.format(declaration['returns'][0]['name'])
     elif name.endswith('_out'):
         output_args = [arg['name'] for arg in arguments
                        if arg.get('output', False)]
         return '{' + ', '.join(output_args) + '}'
     trace_outs = [r['name'] for r in declaration['returns']]
     if any(ret['dynamic_type'] == 'TensorList' for ret in declaration['returns']):
-        return CodeTemplate("flatten( ${outs} )").substitute(outs=trace_outs)
+        return CodeTemplate("flatten_tensor( ${outs} )").substitute(outs=trace_outs)
     else:
         return CodeTemplate("{ ${outs} }").substitute(outs=trace_outs)
 
@@ -408,7 +408,7 @@ def emit_record_trace(env):
     local['tensor_args'] = [arg['name'] for arg in tensor_args]
     if any(arg['simple_type'] == 'TensorList' for arg in tensor_args):
         # Allocate a temporary vector with flatten and pass it in
-        local['trace_inputs'] = CodeTemplate("flatten( $tensor_args )").substitute(local)
+        local['trace_inputs'] = CodeTemplate("flatten_tensor( $tensor_args )").substitute(local)
     else:
         local['trace_inputs'] = CodeTemplate("{ ${tensor_args} }").substitute(local)
 
@@ -496,7 +496,7 @@ def emit_history():
     fn = 'rebase' if modifies_arguments and not is_view else 'set'
     output_names = [r['name'] for r in differentiable_outputs]
     # TODO: flatten allocates a std::vector, which could be expensive
-    outs = CodeTemplate("flatten( ${outs} )").substitute(outs=output_names)
+    outs = CodeTemplate("flatten_tensor( ${outs} )").substitute(outs=output_names)
     return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)
 
 def emit_save_outputs():

tools/autograd/templates/VariableType.cpp

Lines changed: 1 addition & 1 deletion
@@ -406,7 +406,7 @@ struct Flatten : IterArgs<Flatten> {
   }
 };
 
-template<typename... Args> inline variable_list flatten(Args&&... args) {
+template<typename... Args> inline variable_list flatten_tensor(Args&&... args) {
   variable_list out;
   out.reserve(count_tensors(std::forward<Args>(args)...));
   Flatten(out).apply(std::forward<Args>(args)...);

torch/_tensor_docs.py

Lines changed: 7 additions & 0 deletions
@@ -810,6 +810,13 @@ def add_docstr_all(method, docstr):
 In-place version of :meth:`~Tensor.frac`
 """)
 
+add_docstr_all('flatten',
+               r"""
+flatten(input, start, end) -> Tensor
+
+see :func:`torch.flatten`
+""")
+
 add_docstr_all('gather',
                r"""
 gather(dim, index) -> Tensor

torch/_torch_docs.py

Lines changed: 12 additions & 0 deletions
@@ -1591,6 +1591,18 @@ def parse_kwargs(desc):
 array([-1, 2, 3])
 """)
 
+add_docstr(torch.flatten,
+           r"""
+flatten(input, start, end) -> Tensor
+
+Flattens a contiguous range of dims in a tensor.
+
+Args:
+    input (Tensor): the input tensor
+    start (int): the first dim to flatten
+    end (int): the last dim to flatten
+""")
+
 add_docstr(torch.gather,
            r"""
 gather(input, dim, index, out=None) -> Tensor
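The new docstring has no Example section; a doctest-style illustration consistent with it (not part of the commit) might read:

>>> import torch
>>> t = torch.randn(4, 2, 3)
>>> torch.flatten(t, 1, 2).shape
torch.Size([4, 6])
>>> torch.flatten(t, 0, -1).shape
torch.Size([24])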
