[pytorch] [feature request] Flatten convenience method #7743

@vadimkantorov

Minor suggestion (trivial to implement in user code, but having it in the library would improve code brevity). The purpose is to flatten specific trailing dimensions by passing a negative dimension index.

This can be useful for aggregating across multiple trailing dimensions, until mean/max etc. support reducing over multiple dimensions.

A flatten exists in numpy/tensorflow/onnx, but the semantics there don't allow flattening only specific dimensions.

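import torch

def flatten(x, dim):
    # Collapse dimension `dim` and all dimensions after it into a single one.
    return x.view(x.size()[:dim] + (-1,))

flatten(torch.rand(2, 3, 4, 5, 6), dim=-2).shape
# torch.Size([2, 3, 4, 30])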
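
As a rough illustration of the aggregation use case mentioned above, here is a small sketch that reduces over the last two dimensions by flattening them first. The input tensor and the names `mean_trailing` and `max_trailing` are made up for the example. The helper is the one from this issue, except that the size prefix is converted to a plain tuple first, which is an assumption made for robustness rather than something the proposal requires.

```python
import torch

def flatten(x, dim):
    # Helper from the issue; tuple(...) makes the concatenation with (-1,)
    # an ordinary tuple regardless of how torch.Size handles `+`.
    return x.view(tuple(x.size()[:dim]) + (-1,))

# Illustrative input: values 0..23 arranged as a (2, 3, 4) tensor.
x = torch.arange(24, dtype=torch.float32).view(2, 3, 4)

# Mean over the last two dimensions: flatten them, then reduce once.
mean_trailing = flatten(x, dim=-2).mean(dim=-1)

# Max over the last two dimensions; max(dim=...) returns (values, indices).
max_trailing, _ = flatten(x, dim=-2).max(dim=-1)

print(mean_trailing)  # tensor([ 5.5000, 17.5000])
print(max_trailing)   # tensor([11., 23.])
```

Note that `view` only works when the flattened dimensions are contiguous in memory, so a transposed or otherwise non-contiguous input would need `.contiguous()` (or `reshape`) first.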
