
[feature proposal] adaptive softmax #4659

@elanmart

Hi, some time ago Adam mentioned that you were considering adding the adaptive softmax from [1] to PyTorch.

I wanted to ask whether you would be interested in an nn.Module implementation of that paper.
From my benchmarks on the text8 dataset (vocabulary of 44,371 tokens), with a batch size of 128 and a BPTT length of 20, I get the following timings per batch:

| | standard | adaptive |
|---|---|---|
| forward [ms] | 58.1 | 13.5 |
| backward [ms] | 90.74 | 24.03 |

I think implementing the backward pass by hand would require moving this to C++, since the computation needs to call softmax_backward. Is that correct?
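For reference, the gradient such a hand-written backward would compute for a single log-softmax is short: if `y = log_softmax(x)` and `g = dL/dy`, then `dL/dx_j = g_j - softmax(x)_j * sum_i g_i`, and `softmax(x) = exp(y)`, so only the forward output is needed. The adaptive softmax applies this once to the head and once to each tail cluster that is actually used. This is a minimal pure-Python sketch that checks the formula against central finite differences; the function names are hypothetical and this is not PyTorch's internal softmax_backward.

```python
import math


def log_softmax(x):
    # Subtract the max before exponentiating for numerical stability.
    m = max(x)
    log_z = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - log_z for v in x]


def log_softmax_backward(grad_output, output):
    """dL/dx given dL/dy and y = log_softmax(x) (hypothetical helper)."""
    # dy_i/dx_j = delta_ij - p_j with p = exp(y) = softmax(x), hence
    # dL/dx_j = g_j - p_j * sum_i g_i.
    total = sum(grad_output)
    return [g - math.exp(y) * total for g, y in zip(grad_output, output)]


def numerical_grad(x, grad_output, eps=1e-6):
    # Central differences of L(x) = sum_i grad_output[i] * log_softmax(x)[i].
    def loss(v):
        return sum(g * y for g, y in zip(grad_output, log_softmax(v)))

    grads = []
    for j in range(len(x)):
        plus = list(x)
        plus[j] += eps
        minus = list(x)
        minus[j] -= eps
        grads.append((loss(plus) - loss(minus)) / (2 * eps))
    return grads


x = [0.2, -1.0, 3.0]
g = [1.0, 0.0, -0.5]
analytic = log_softmax_backward(g, log_softmax(x))
numeric = numerical_grad(x, g)
print(max(abs(a - n) for a, n in zip(analytic, numeric)) < 1e-6)  # True
```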

Please find below the API that I was thinking about.

[1] Efficient softmax approximation for GPUs

```python
import torch.nn as nn


class AdaptiveLogSoftmax(nn.Module):
    def __init__(self, in_features, n_classes, cutoffs):
        """
        Parameters
        ----------
        in_features : int
            dimensionality of input data
        n_classes : int
            number of classes
        cutoffs : List[int]
            increasing cutoff values for clusters (labels are 0-indexed and
            sorted by decreasing frequency).
            e.g. cutoffs = [100, 1000] means that the 100 most frequent labels
            (0-99) will end up in the main cluster, labels 100-999 will end up
            in the first sub-cluster, and labels 1000 to n_classes - 1 will
            end up in the second sub-cluster
        """
        super().__init__()
        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = list(cutoffs)

    def reset_parameters(self):
        pass

    def forward(self, input, target):
        """
        Parameters
        ----------
        input : torch.FloatTensor of size (batch_size x self.in_features)
        target : torch.LongTensor of size (batch_size)

        Returns
        -------
        output : torch.FloatTensor of size (batch_size)
            each entry is the log probability of the corresponding `target` value
        """
        pass

    def get_log_proba(self, input):
        """Computes log probabilities for all `self.n_classes` classes."""
        pass
```
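To make the intended semantics of `cutoffs`, `forward`, and `get_log_proba` concrete, here is a framework-free sketch of the factorization from [1]: the head scores the shortlist plus one entry per tail cluster, and a tail label's log probability is the head's log probability of its cluster plus its log probability within that cluster. It is a sketch under assumptions: the logits are taken as already computed (the head and tail linear layers, including the paper's reduced-dimension tail projections, are omitted), and `locate`, `adaptive_log_prob`, and `all_log_probs` are hypothetical helpers, not part of the proposed module.

```python
import math
from bisect import bisect_right


def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def locate(target, cutoffs):
    """Map a label to (cluster, offset). Cluster 0 is the head shortlist,
    cluster k >= 1 is the k-th tail cluster (hypothetical helper)."""
    cluster = bisect_right(cutoffs, target)
    start = 0 if cluster == 0 else cutoffs[cluster - 1]
    return cluster, target - start


def adaptive_log_prob(head_logits, tail_logits, target, cutoffs):
    """log p(target) = log p(cluster) + log p(target | cluster)."""
    # head_logits: cutoffs[0] shortlist entries + one entry per tail cluster.
    head = log_softmax(head_logits)
    cluster, offset = locate(target, cutoffs)
    if cluster == 0:
        return head[offset]
    # Head entry cutoffs[0] + k - 1 scores tail cluster k.
    return head[cutoffs[0] + cluster - 1] + log_softmax(tail_logits[cluster - 1])[offset]


def all_log_probs(head_logits, tail_logits, cutoffs):
    """Log probabilities for every class (the role of get_log_proba)."""
    head = log_softmax(head_logits)
    out = head[:cutoffs[0]]
    for k, logits in enumerate(tail_logits):
        out.extend(head[cutoffs[0] + k] + lp for lp in log_softmax(logits))
    return out


# Toy setup (hypothetical numbers): n_classes = 6, cutoffs = [2, 4]
# -> labels 0-1 in the head, 2-3 in tail cluster 1, 4-5 in tail cluster 2.
cutoffs = [2, 4]
head_logits = [1.0, 0.5, -0.2, 0.3]      # 2 shortlist labels + 2 cluster entries
tail_logits = [[0.1, -0.4], [2.0, 0.0]]  # one list of logits per tail cluster
log_probs = all_log_probs(head_logits, tail_logits, cutoffs)
print(round(sum(math.exp(lp) for lp in log_probs), 6))  # 1.0
print(adaptive_log_prob(head_logits, tail_logits, 4, cutoffs) == log_probs[4])  # True
```

Because the head softmax sums to 1 over the shortlist and cluster entries, and each tail softmax sums to 1, the full distribution also sums to 1. Scoring a single target only touches the head and at most one tail cluster, which is where the savings over a full softmax come from.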
