Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions torch/nn/modules/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,16 @@ def extra_repr(self):
return s.format(**self.__dict__)

@classmethod
def from_pretrained(cls, embeddings, freeze=True):
def from_pretrained(cls, embeddings, freeze=True, sparse=False):

This comment was marked as off-topic.

r"""Creates Embedding instance from given 2-dimensional FloatTensor.

Args:
embeddings (Tensor): FloatTensor containing weights for the Embedding.
First dimension is being passed to Embedding as 'num_embeddings', second as 'embedding_dim'.
freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
sparse (bool, optional): If ``True``, gradient w.r.t. weight matrix will be a sparse tensor.
See Notes for more details regarding sparse gradients.

Examples::

Expand All @@ -144,7 +146,12 @@ def from_pretrained(cls, embeddings, freeze=True):
assert embeddings.dim() == 2, \
'Embeddings parameter is expected to be 2-dimensional'
rows, cols = embeddings.shape
embedding = cls(num_embeddings=rows, embedding_dim=cols, _weight=embeddings)
embedding = cls(
num_embeddings=rows,
embedding_dim=cols,
_weight=embeddings,
sparse=sparse,
)
embedding.weight.requires_grad = not freeze
return embedding

Expand Down