LSTM::permute_hidden breaks Liskov substitution principle #43072

@guilhermeleobas

Description

While working on PR #43068, I found a method on LSTM that breaks the Liskov substitution principle. LSTM inherits from RNNBase and overrides permute_hidden so that it takes and returns a pair of tensors, the hidden state and the cell state (ref):

class LSTM(RNNBase):
    ...
    def permute_hidden(self, hx, permutation):
        # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
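
Here is a minimal pure-Python sketch of what the override does. It assumes, as in PyTorch's `rnn.py`, that `apply_permutation` reorders a hidden-state tensor along its batch dimension (dim 1). Plain nested lists stand in for tensors, and `apply_permutation_list` is a hypothetical helper, not a PyTorch function.

```python
from typing import List, Optional, Sequence, Tuple

# Stand-in for a hidden-state tensor of shape (num_layers, batch, hidden_size).
Hidden = List[List[List[float]]]


def apply_permutation_list(hx: Hidden, permutation: Sequence[int]) -> Hidden:
    """Reorder the batch dimension (dim 1) of every layer, like index_select(1, permutation)."""
    return [[layer[i] for i in permutation] for layer in hx]


def lstm_permute_hidden(
    hx: Tuple[Hidden, Hidden], permutation: Optional[Sequence[int]]
) -> Tuple[Hidden, Hidden]:
    # An LSTM carries two states, (h, c), so both are permuted with the same order.
    if permutation is None:
        return hx
    return apply_permutation_list(hx[0], permutation), apply_permutation_list(hx[1], permutation)


h = [[[0.0], [1.0], [2.0]]]     # 1 layer, batch of 3, hidden size 1
c = [[[10.0], [11.0], [12.0]]]
new_h, new_c = lstm_permute_hidden((h, c), [2, 0, 1])
print(new_h)  # [[[2.0], [0.0], [1.0]]]
print(new_c)  # [[[12.0], [10.0], [11.0]]]
```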

RNNBase, by contrast, defines the same method to take and return a single tensor (ref):

class RNNBase(torch.nn.Module):
    ...
    def permute_hidden(self, hx, permutation):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)
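
To make the conflict concrete, here is a hedged sketch with hypothetical stand-in classes `Base` and `Lstm` (not the real PyTorch modules), using 1-D lists in place of tensors. Code written against the base class's contract (one tensor in, one tensor out) breaks when it is handed the subclass.

```python
from typing import List, Optional, Sequence, Tuple

Vec = List[float]  # stand-in for a tensor; a permutation reorders its entries


def apply_permutation(t: Vec, permutation: Sequence[int]) -> Vec:
    return [t[i] for i in permutation]


class Base:
    def permute_hidden(self, hx: Vec, permutation: Optional[Sequence[int]]) -> Vec:
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)


class Lstm(Base):
    # mypy reports this override as incompatible with Base:
    # both the argument type and the return type differ.
    def permute_hidden(self, hx: Tuple[Vec, Vec], permutation: Optional[Sequence[int]]) -> Tuple[Vec, Vec]:
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)


def reorder(module: Base, hx: Vec, permutation: Sequence[int]) -> Vec:
    # Written against Base's contract: one tensor in, one tensor out.
    return module.permute_hidden(hx, permutation)


print(reorder(Base(), [1.0, 2.0, 3.0], [2, 0, 1]))  # [3.0, 1.0, 2.0]
try:
    # Lstm treats hx[0] (a float here) as a tensor, so substitution fails.
    reorder(Lstm(), [1.0, 2.0, 3.0], [2, 0, 1])
except TypeError as exc:
    print("Lstm broke the Base contract:", exc)
```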

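Given this difference, permute_hidden cannot be typed correctly. LSTM's override changes both the argument type and the return type from Tensor to Tuple[Tensor, Tensor], so code written against RNNBase cannot safely receive an LSTM in its place, which is exactly what the Liskov substitution principle forbids; a type checker such as mypy rejects the override as incompatible with the supertype. Should LSTM not inherit from RNNBase?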
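
One possible alternative to dropping the inheritance, offered only as an assumption-laden sketch, is to make the base class generic over the hidden-state type so each subclass declares its own. `HiddenT`, `Base`, `Rnn`, `Lstm` and the list-based `Vec` are hypothetical stand-ins, and the sketch does not check whether TorchScript could compile this pattern.

```python
from typing import Generic, List, Optional, Sequence, Tuple, TypeVar

Vec = List[float]  # stand-in for a tensor
HiddenT = TypeVar("HiddenT")


def apply_permutation(t: Vec, permutation: Sequence[int]) -> Vec:
    return [t[i] for i in permutation]


class Base(Generic[HiddenT]):
    # The hidden-state type is a parameter, so subclasses may differ without violating LSP.
    def permute_hidden(self, hx: HiddenT, permutation: Optional[Sequence[int]]) -> HiddenT:
        raise NotImplementedError


class Rnn(Base[Vec]):
    # Single hidden state: one tensor in, one tensor out.
    def permute_hidden(self, hx: Vec, permutation: Optional[Sequence[int]]) -> Vec:
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)


class Lstm(Base[Tuple[Vec, Vec]]):
    # (h, c) pair in, (h, c) pair out; matches Base[Tuple[Vec, Vec]].
    def permute_hidden(
        self, hx: Tuple[Vec, Vec], permutation: Optional[Sequence[int]]
    ) -> Tuple[Vec, Vec]:
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)


print(Rnn().permute_hidden([1.0, 2.0], [1, 0]))                # [2.0, 1.0]
print(Lstm().permute_hidden(([1.0, 2.0], [3.0, 4.0]), [1, 0]))  # ([2.0, 1.0], [4.0, 3.0])
```

The trade-off is that code typed against the shared base must now name the hidden-state type (for example `Base[Vec]`), since the bare base class no longer promises a single-tensor signature.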

cc @ezyang @malfet @rgommers @zou3519

Metadata

    Labels

    module: rnn (issues related to RNN support: LSTM, GRU, etc.)
    module: typing (related to mypy type annotations)
    triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
