Commit 2153c5e

Author: Roy Li
Commit message: address comments
1 parent d1590d3 · commit 2153c5e

8 files changed: +35 -22 lines changed

aten/src/ATen/native/TensorShape.cpp
Lines changed: 10 additions & 9 deletions

@@ -600,21 +600,22 @@ Tensor & unsqueeze_(Tensor& self, int64_t dim) {
   return self.as_strided_(std::get<0>(g), std::get<1>(g));
 }
 
-Tensor flatten(const Tensor& self, int64_t start, int64_t end) {
-  start = maybe_wrap_dim(start, self.dim());
-  end = maybe_wrap_dim(end, self.dim());
-  AT_CHECK(start < end, "start dim must be before end dim");
+Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
+  start_dim = maybe_wrap_dim(start_dim, self.dim());
+  end_dim = maybe_wrap_dim(end_dim, self.dim());
+  AT_CHECK(start_dim < end_dim, "start_dim must be before end_dim");
 
   std::vector<int64_t> shape;
-  for (int i = 0; i < start; i++) {
-    shape.push_back(self.sizes()[i]);
+  shape.reserve(self.dim() - end_dim + start_dim);
+  for (int64_t i = 0; i < start_dim; i++) {
+    shape.push_back(self.size(i));
   }
   shape.push_back(-1);
-  for (int i = end + 1; i < self.dim(); i++) {
-    shape.push_back(self.sizes()[i]);
+  for (int64_t i = end_dim + 1; i < self.dim(); i++) {
+    shape.push_back(self.size(i));
   }
 
-  return self.view(IntList(shape));
+  return self.reshape(shape);
 }
 
 Tensor view_as(const Tensor& self, const Tensor& other) {
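The target shape built above keeps every dim before start_dim, inserts a single -1 for the collapsed range, and keeps every dim after end_dim; reshape then infers the collapsed size. A minimal Python sketch of the same computation (flatten_shape is a hypothetical helper used only for illustration, not part of this patch):

    >>> import torch
    >>> def flatten_shape(t, start_dim=0, end_dim=-1):
    ...     # mirror maybe_wrap_dim: wrap negative dims into [0, t.dim())
    ...     start_dim, end_dim = start_dim % t.dim(), end_dim % t.dim()
    ...     shape = [t.size(i) for i in range(start_dim)]
    ...     shape.append(-1)  # collapsed range start_dim..end_dim
    ...     shape += [t.size(i) for i in range(end_dim + 1, t.dim())]
    ...     return shape
    >>> t = torch.randn(2, 3, 4, 5)
    >>> flatten_shape(t, 1, 2)
    [2, -1, 5]
    >>> t.flatten(1, 2).shape
    torch.Size([2, 12, 5])

Switching the final call from view to reshape also lets flatten accept non-contiguous inputs: reshape returns a view when the strides allow it and copies otherwise, whereas view would raise.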

aten/src/ATen/native/native_functions.yaml
Lines changed: 1 addition & 1 deletion

@@ -556,7 +556,7 @@
   variants: function
   deprecated: true
 
-- func: flatten(Tensor self, int64_t start, int64_t end) -> Tensor
+- func: flatten(Tensor self, int64_t start_dim=0, int64_t end_dim=-1) -> Tensor
 
 - func: fill_(Tensor self, Scalar value) -> Tensor
 
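With the new defaults (start_dim=0, end_dim=-1), flatten can be called with no arguments to collapse the whole tensor, or with just a start dim; a quick doctest-style illustration of what the defaults imply (not part of the patch):

    >>> x = torch.randn(2, 3, 4)
    >>> x.flatten().shape       # start_dim=0, end_dim=-1
    torch.Size([24])
    >>> x.flatten(1).shape      # keep dim 0, flatten the rest
    torch.Size([2, 12])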

test/test_torch.py
Lines changed: 2 additions & 2 deletions

@@ -5430,7 +5430,7 @@ def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
         idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
 
     def test_flatten(self):
-        src = torch.randn(5,5,5,5)
+        src = torch.randn(5, 5, 5, 5)
         flat = src.flatten(0, -1)
         self.assertEqual(flat.shape, torch.Size([625]))
         self.assertEqual(src.view(-1), flat.view(-1))
@@ -5460,7 +5460,7 @@ def test_flatten(self):
             src.flatten(5, 10)
 
         # invalid start and end
-        with self.assertRaisesRegex(RuntimeError, 'start dim must be before end dim'):
+        with self.assertRaisesRegex(RuntimeError, 'start_dim must be before end_dim'):
            src.flatten(2, 0)
 
     @staticmethod

tools/autograd/derivatives.yaml
Lines changed: 1 addition & 1 deletion

@@ -265,7 +265,7 @@
   self: zeros_like(grad)
   value: grad.sum()
 
-- name: flatten(Tensor self, int64_t start, int64_t end)
+- name: flatten(Tensor self, int64_t start_dim, int64_t end_dim)
   self: grad.view(self.sizes())
 
 - name: floor(Tensor self)
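The backward formula simply views the incoming gradient back to the input's sizes, since flatten only rearranges the shape; a small sanity check of that behavior (illustration only, not part of the patch):

    >>> x = torch.randn(2, 3, 4, requires_grad=True)
    >>> y = x.flatten(1)        # shape [2, 12]
    >>> y.sum().backward()
    >>> x.grad.shape            # gradient viewed back to the input's sizes
    torch.Size([2, 3, 4])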

tools/autograd/gen_variable_type.py
Lines changed: 4 additions & 4 deletions

@@ -364,14 +364,14 @@ def reference_args(args):
 
 def get_trace_outputs(declaration):
     if declaration['return_type'] == 'std::vector<Tensor>':
-        return 'flatten_tensor({})'.format(declaration['returns'][0]['name'])
+        return 'flatten_tensor_args({})'.format(declaration['returns'][0]['name'])
     elif name.endswith('_out'):
         output_args = [arg['name'] for arg in arguments
                        if arg.get('output', False)]
         return '{' + ', '.join(output_args) + '}'
     trace_outs = [r['name'] for r in declaration['returns']]
     if any(ret['dynamic_type'] == 'TensorList' for ret in declaration['returns']):
-        return CodeTemplate("flatten_tensor( ${outs} )").substitute(outs=trace_outs)
+        return CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=trace_outs)
     else:
         return CodeTemplate("{ ${outs} }").substitute(outs=trace_outs)
 
@@ -408,7 +408,7 @@ def emit_record_trace(env):
     local['tensor_args'] = [arg['name'] for arg in tensor_args]
     if any(arg['simple_type'] == 'TensorList' for arg in tensor_args):
         # Allocate a temporary vector with flatten and pass it in
-        local['trace_inputs'] = CodeTemplate("flatten_tensor( $tensor_args )").substitute(local)
+        local['trace_inputs'] = CodeTemplate("flatten_tensor_args( $tensor_args )").substitute(local)
     else:
         local['trace_inputs'] = CodeTemplate("{ ${tensor_args} }").substitute(local)
 
@@ -496,7 +496,7 @@ def emit_history():
     fn = 'rebase' if modifies_arguments and not is_view else 'set'
     output_names = [r['name'] for r in differentiable_outputs]
     # TODO: flatten allocates a std::vector, which could be expensive
-    outs = CodeTemplate("flatten_tensor( ${outs} )").substitute(outs=output_names)
+    outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=output_names)
     return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)
 
 def emit_save_outputs():
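For context, CodeTemplate here just substitutes the ${...} placeholders, so this rename only changes the function name in the generated C++ call sites. A rough illustration using Python's string.Template as a stand-in (assuming CodeTemplate fills placeholders the same way):

    >>> from string import Template
    >>> Template("flatten_tensor_args( ${outs} )").substitute(outs="result")
    'flatten_tensor_args( result )'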

tools/autograd/templates/VariableType.cpp
Lines changed: 1 addition & 1 deletion

@@ -406,7 +406,7 @@ struct Flatten : IterArgs<Flatten> {
   }
 };
 
-template<typename... Args> inline variable_list flatten_tensor(Args&&... args) {
+template<typename... Args> inline variable_list flatten_tensor_args(Args&&... args) {
   variable_list out;
   out.reserve(count_tensors(std::forward<Args>(args)...));
   Flatten(out).apply(std::forward<Args>(args)...);

torch/_tensor_docs.py
Lines changed: 1 addition & 1 deletion

@@ -812,7 +812,7 @@ def add_docstr_all(method, docstr):
 
 add_docstr_all('flatten',
                r"""
-flatten(input, start, end) -> Tensor
+flatten(input, start_dim=0, end_dim=-1) -> Tensor
 
 see :func:`torch.flatten`
 """)

torch/_torch_docs.py
Lines changed: 15 additions & 3 deletions

@@ -1593,14 +1593,26 @@ def parse_kwargs(desc):
 
 add_docstr(torch.flatten,
            r"""
-flatten(input, start, end) -> Tensor
+flatten(input, start_dim=0, end_dim=-1) -> Tensor
 
 Flattens a contiguous range of dims in a tensor.
 
 Args:
     input (Tensor): the input tensor
-    start (int): the first dim to flatten
-    end (int): the last dim to flatten
+    start_dim (int): the first dim to flatten
+    end_dim (int): the last dim to flatten
+
+Example::
+
+    >>> t = torch.tensor([[[1, 2],
+                           [3, 4]],
+                          [[5, 6],
+                           [7, 8]]])
+    >>> torch.flatten(t)
+    tensor([1, 2, 3, 4, 5, 6, 7, 8])
+    >>> torch.flatten(t, start_dim=1)
+    tensor([[1, 2, 3, 4],
+            [5, 6, 7, 8]])
 """)
 
 add_docstr(torch.gather,
