Merged
22 changes: 22 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -600,6 +600,28 @@ Tensor & unsqueeze_(Tensor& self, int64_t dim) {
  return self.as_strided_(std::get<0>(g), std::get<1>(g));
}

Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
  start_dim = maybe_wrap_dim(start_dim, self.dim());
  end_dim = maybe_wrap_dim(end_dim, self.dim());
  AT_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");

  if (start_dim == end_dim) {
    return self;
  }

  // Build the target shape: the dims before start_dim, a -1 placeholder for
  // the flattened range, then the dims after end_dim; reshape infers the -1.
  std::vector<int64_t> shape;
  shape.reserve(self.dim() - end_dim + start_dim);
  for (int64_t i = 0; i < start_dim; i++) {
    shape.push_back(self.size(i));
  }
  shape.push_back(-1);
  for (int64_t i = end_dim + 1; i < self.dim(); i++) {
    shape.push_back(self.size(i));
  }

  return self.reshape(shape);
}

Tensor view_as(const Tensor& self, const Tensor& other) {
  return self.view(other.sizes());
}
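The shape the new function builds is just "the dims before start_dim, a -1 placeholder, the dims after end_dim", with reshape inferring the collapsed size. A minimal Python sketch of that computation (mine, not part of the PR):

    import torch

    src = torch.randn(5, 5, 5, 5)
    start_dim, end_dim = 1, 2  # flatten dims 1..2
    shape = list(src.shape[:start_dim]) + [-1] + list(src.shape[end_dim + 1:])
    print(shape)                     # [5, -1, 5]
    print(src.reshape(shape).shape)  # torch.Size([5, 25, 5])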
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -556,6 +556,8 @@
  variants: function
  deprecated: true

- func: flatten(Tensor self, int64_t start_dim=0, int64_t end_dim=-1) -> Tensor

- func: fill_(Tensor self, Scalar value) -> Tensor

- func: fill_(Tensor self, Tensor value) -> Tensor
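A quick sanity check (my sketch, not in the PR) that the schema's defaults surface as expected at the Python level:

    import torch

    x = torch.zeros(2, 3, 4)
    assert x.flatten().shape == (24,)      # both defaults: start_dim=0, end_dim=-1
    assert x.flatten(1).shape == (2, 12)   # only end_dim left at its default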
37 changes: 37 additions & 0 deletions test/test_torch.py
@@ -5429,6 +5429,43 @@ def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
                    ii[dim] = slice(0, idx.size(dim) + 1)
                    idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]

    def test_flatten(self):
        src = torch.randn(5, 5, 5, 5)
        flat = src.flatten(0, -1)
        self.assertEqual(flat.shape, torch.Size([625]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(0, 2)
        self.assertEqual(flat.shape, torch.Size([125, 5]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(0, 1)
        self.assertEqual(flat.shape, torch.Size([25, 5, 5]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(1, 2)
        self.assertEqual(flat.shape, torch.Size([5, 25, 5]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(2, 3)
        self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(-2, -1)
        self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(2, 2)
        self.assertEqual(flat, src)

        # out of bounds index
        with self.assertRaisesRegex(RuntimeError, 'dimension out of range'):
            src.flatten(5, 10)

        # invalid start and end
        with self.assertRaisesRegex(RuntimeError, 'start_dim cannot come after end_dim'):
            src.flatten(2, 0)

    @staticmethod
    def _test_gather(self, cast, test_bounds=True):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
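The negative-index cases above lean on maybe_wrap_dim; a hedged sketch of that wrapping rule (the function name is mine):

    def wrap_dim(dim, ndim):
        # negative dims count from the end: -1 -> ndim - 1, -2 -> ndim - 2, ...
        return dim + ndim if dim < 0 else dim

    assert wrap_dim(-2, 4) == 2  # so src.flatten(-2, -1) matches src.flatten(2, 3)
    assert wrap_dim(-1, 4) == 3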
8 changes: 4 additions & 4 deletions tools/autograd/gen_variable_type.py
@@ -364,14 +364,14 @@ def reference_args(args):

def get_trace_outputs(declaration):
    if declaration['return_type'] == 'std::vector<Tensor>':
-       return 'flatten({})'.format(declaration['returns'][0]['name'])
+       return 'flatten_tensor_args({})'.format(declaration['returns'][0]['name'])
    elif name.endswith('_out'):
        output_args = [arg['name'] for arg in arguments
                       if arg.get('output', False)]
        return '{' + ', '.join(output_args) + '}'
    trace_outs = [r['name'] for r in declaration['returns']]
    if any(ret['dynamic_type'] == 'TensorList' for ret in declaration['returns']):
-       return CodeTemplate("flatten( ${outs} )").substitute(outs=trace_outs)
+       return CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=trace_outs)
    else:
        return CodeTemplate("{ ${outs} }").substitute(outs=trace_outs)

@@ -408,7 +408,7 @@ def emit_record_trace(env):
    local['tensor_args'] = [arg['name'] for arg in tensor_args]
    if any(arg['simple_type'] == 'TensorList' for arg in tensor_args):
        # Allocate a temporary vector with flatten_tensor_args and pass it in
-       local['trace_inputs'] = CodeTemplate("flatten( $tensor_args )").substitute(local)
+       local['trace_inputs'] = CodeTemplate("flatten_tensor_args( $tensor_args )").substitute(local)
    else:
        local['trace_inputs'] = CodeTemplate("{ ${tensor_args} }").substitute(local)

@@ -496,7 +496,7 @@ def emit_history():
    fn = 'rebase' if modifies_arguments and not is_view else 'set'
    output_names = [r['name'] for r in differentiable_outputs]
    # TODO: flatten_tensor_args allocates a std::vector, which could be expensive
-   outs = CodeTemplate("flatten( ${outs} )").substitute(outs=output_names)
+   outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=output_names)
    return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)

def emit_save_outputs():
2 changes: 1 addition & 1 deletion tools/autograd/templates/VariableType.cpp
@@ -406,7 +406,7 @@ struct Flatten : IterArgs<Flatten> {
  }
};

-template<typename... Args> inline variable_list flatten(Args&&... args) {
+template<typename... Args> inline variable_list flatten_tensor_args(Args&&... args) {
  variable_list out;
  out.reserve(count_tensors(std::forward<Args>(args)...));
  Flatten(out).apply(std::forward<Args>(args)...);
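The helper was presumably renamed so the unqualified call no longer collides with the newly added at::flatten. Functionally it gathers every Tensor from a mix of Tensor and TensorList arguments into one flat variable_list; a rough Python analogue (mine, not part of the diff):

    def flatten_tensor_args_py(*args):
        out = []
        for a in args:
            if isinstance(a, (list, tuple)):
                out.extend(a)   # a TensorList contributes each of its tensors
            else:
                out.append(a)   # a single Tensor contributes itself
        return out

    # flatten_tensor_args_py(t1, [t2, t3]) -> [t1, t2, t3]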
7 changes: 7 additions & 0 deletions torch/_tensor_docs.py
@@ -810,6 +810,13 @@ def add_docstr_all(method, docstr):
In-place version of :meth:`~Tensor.frac`
""")

add_docstr_all('flatten',
               r"""
flatten(start_dim=0, end_dim=-1) -> Tensor

See :func:`torch.flatten`
""")

add_docstr_all('gather',
               r"""
gather(dim, index) -> Tensor
24 changes: 24 additions & 0 deletions torch/_torch_docs.py
@@ -1591,6 +1591,30 @@ def parse_kwargs(desc):
array([-1, 2, 3])
""")

add_docstr(torch.flatten,
           r"""
flatten(input, start_dim=0, end_dim=-1) -> Tensor

Flattens a contiguous range of dims in a tensor.

Args:
    input (Tensor): the input tensor
    start_dim (int): the first dim to flatten
    end_dim (int): the last dim to flatten

Example::

    >>> t = torch.tensor([[[1, 2],
                           [3, 4]],
                          [[5, 6],
                           [7, 8]]])
    >>> torch.flatten(t)
    tensor([1, 2, 3, 4, 5, 6, 7, 8])
    >>> torch.flatten(t, start_dim=1)
    tensor([[1, 2, 3, 4],
            [5, 6, 7, 8]])
""")

add_docstr(torch.gather,
r"""
gather(input, dim, index, out=None) -> Tensor
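One more usage note (my example, extending the docstring): start_dim and end_dim together can flatten an interior range of dims, matching the C++ shape logic above:

    import torch

    t = torch.arange(24).reshape(2, 3, 4)
    print(torch.flatten(t, 1, 2).shape)  # torch.Size([2, 12]) -- dims 1..2 collapsed
    print(torch.flatten(t, 0, 1).shape)  # torch.Size([6, 4])  -- dims 0..1 collapsed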