Conversation

@martinraison (Contributor) commented Feb 13, 2017:

This pull request adds more support for sparse operations in PyTorch.

The original goals:

  • ability to propagate sparse updates in a network (e.g. for updating an embedding matrix)
  • ability to efficiently compute "bag-of-words" sentence embeddings (e.g. weighted average of word embeddings)

This PR implements the following features to achieve those goals:

  • enable backpropagation of sparse gradients without conversion to dense tensors. In most cases a runtime exception is thrown when mixing different gradient types for the same variable
  • add some methods for THSTensor: zero, elementwise add and mul, scalar mul and div
  • make addcmul method of THTensor compatible with sparse operands
  • make spmm method accessible from Python (I had to use the name dsmm since smm was already taken. Maybe we should rename the current smm to ssmm to follow the convention)
  • sparse_mask method on THTensor. This produces a sparse tensor from a dense tensor, using a sparse tensor as a mask: a value is present in the output sparse tensor only if it also exists in the mask. See the changes to adagrad.py for an example of why this is needed, and the sketch after this list.
  • update Adagrad code to use sparse updates when possible. I was hoping the optimizers wouldn't require any modification, but it looks like they do (let me know if you see a better option). In addition, not every optimizer supports sparse updates easily. For example it looks like Adam is accumulating moments over time, which effectively makes the updates dense. For the same reason, Adagrad doesn't support having both sparse updates + weight decay.
  • leave a Variable's gradient as None by default (this required updating a few tests). This is because there is no canonical zero gradient anymore (it could be dense or sparse, and if it is sparse we don't know how many dimensions are sparse)
  • I also added the basic glue code to hook up the existing THCS (cuda sparse) tensor implementation to Python (I did pretty much the same as for TH, THC, THS). I did that mainly so that existing tests keep working even when cuda sparse gradients are involved. Most of the THCS operations are still stubs and will throw an exception with an error message, but it means the only thing remaining for GPU sparse ops support is to fill in the appropriate functions.

...and last but not least: N-dimensional values for sparse tensors. This one is a slightly bigger item. For things like applying sparse updates to embedding matrices, only the first dimension (the one that corresponds to the word index) is sparse; the other dimension is always dense (only whole embedding vectors are updated). An elegant solution is to make the values tensor N-dimensional instead of 1-dimensional. For an embedding matrix, the sparse gradient will have a values tensor of size nnz x embedding_size instead of just nnz. I had to update a few existing functions to make that work, but otherwise the changes are actually not that big. Existing use cases with scalar values should all work as usual.
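As an illustration only, here is a minimal sketch of this hybrid layout and of sparse_mask, written against the current PyTorch API (torch.sparse_coo_tensor and Tensor.sparse_mask); the sizes are made up, and at the time of this PR the Python-level constructor would have been the type-specific torch.sparse.FloatTensor(indices, values, size) instead:

```python
import torch

# Hypothetical sizes for illustration only.
num_words, embedding_size = 10, 4
indices = torch.tensor([[1, 4, 7]])        # shape (1, nnz): one sparse dimension
values = torch.randn(3, embedding_size)    # shape (nnz, embedding_size): one dense dimension

# Embedding-style sparse gradient: sparse along the word index, dense along the vector.
grad = torch.sparse_coo_tensor(indices, values, (num_words, embedding_size)).coalesce()
print(grad.to_dense().shape)               # torch.Size([10, 4])

# sparse_mask: read from a dense tensor only the entries present in the sparse mask,
# e.g. the rows of Adagrad's accumulated squared-gradient state that actually change.
state_sum = torch.zeros(num_words, embedding_size)
masked = state_sum.sparse_mask(grad)       # sparse tensor sharing grad's indices
print(masked.to_dense().shape)             # torch.Size([10, 4])
```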

@martinraison force-pushed the sparse branch 5 times, most recently from 25ead1d to 8b65c9d (February 14, 2017, 17:31)

@martinraison force-pushed the sparse branch 5 times, most recently from 0e957a9 to 08635d7 (February 15, 2017, 00:36)
@soumith (Contributor) commented Feb 15, 2017:

@pytorchbot add to whitelist

@martinraison force-pushed the sparse branch 4 times, most recently from 6565285 to 165d7bf (February 15, 2017, 12:50)
@martinraison force-pushed the sparse branch 9 times, most recently from 06312ce to ed2a0d8 (February 28, 2017, 14:15)
@martinraison (Contributor Author):

I rebased this PR on top of the latest changes, polished a few things, and added some tests. I think it is ready for review/merging; let me know what you think.

@martinraison (Contributor Author):

There are a few changes compared to the original PR description:

  • there is no explicit sparse_grad flag anymore. Whether a Variable's gradient is sparse is not an intrinsic property of the variable itself; instead, it is determined by the functions that are applied to it. Sparse gradients are therefore propagated whenever possible, and an error is raised if incompatible types of gradients are accumulated for the same variable (we could relax this constraint and fall back to dense gradients when accumulating, but I'm not sure that's better from a user perspective). See the sketch after this list.
  • I removed some functions that don't make much sense or have unclear semantics: scalar add (which would turn a sparse matrix into a dense one), and elementwise div and addcdiv (which cause division-by-zero errors unless the denominator's index set is a subset of the numerator's)
  • I added some glue code to bridge THCS and Python to make existing tests pass and lay the groundwork for future sparse cuda tensor support (for the time being I didn't implement any additional sparse cuda ops)
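To illustrate the first point, here is a rough sketch of sparse gradient propagation through an embedding lookup, written against the modern API; the sparse=True flag on nn.Embedding and the exact error raised when gradient types are mixed are assumptions relative to the code in this PR:

```python
import torch
import torch.nn as nn

# Minimal sketch (modern API; `sparse=True` may not match this PR's original interface).
emb = nn.Embedding(num_embeddings=1000, embedding_dim=16, sparse=True)
out = emb(torch.tensor([3, 17, 17, 42]))   # only a handful of rows are touched
out.sum().backward()

g = emb.weight.grad
print(g.is_sparse)   # True: the gradient stays sparse instead of being densified
print(g._nnz())      # 4 entries (uncoalesced), one per looked-up index
```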

@apaszke (Contributor) left a review comment:

Overall looks good, but I have some questions and comments. The most important is that I don't understand why we need separate sizes and dimension fields for indices and values.

I'll have to think about whether choosing the grad type dynamically and disallowing changing it later won't be too limiting. I'm a bit afraid that it could cause some surprising errors with no easy workarounds.

@martinraison (Contributor Author) commented Mar 2, 2017:

Thanks @apaszke for the feedback! I will address your various comments.

Regarding the difference between nDimensionI and nDimensionV: the idea is to allow values of arbitrary dimension instead of just scalars. This happens for example with word embedding matrices. The matrix will be of size num_words x embedding_size, and updates to that matrix will be sparse, but only along the first dimension (only a few words are used in each batch; however, whenever a word's embedding changes, the whole embedding vector changes). In this specific case, nDimensionI=1 (num_words) and nDimensionV=1 (embedding_size). If I call to_dense on a sparse matrix, the result has dimension nDimensionI + nDimensionV. So the "indices" and "values" dimensions are complementary, and can of course be different (for example in the typical sparse matrix case, nDimensionI=2 and nDimensionV=0). Let me know if I'm still unclear.
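A small sketch of this dimension bookkeeping, assuming the Python-level sparse_dim()/dense_dim() of current PyTorch correspond to the C-level nDimensionI/nDimensionV described here:

```python
import torch

# Embedding-gradient case: one sparse dimension (word index), one dense dimension (vector).
i = torch.tensor([[1, 4, 7]])
v = torch.randn(3, 5)
t = torch.sparse_coo_tensor(i, v, (10, 5))
print(t.sparse_dim(), t.dense_dim())   # 1 1
print(t.to_dense().dim())              # 2 == sparse_dim + dense_dim

# Typical sparse-matrix case: two sparse dimensions, scalar values.
m = torch.sparse_coo_tensor(torch.tensor([[0, 1], [2, 0]]), torch.randn(2), (3, 3))
print(m.sparse_dim(), m.dense_dim())   # 2 0
```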

@apaszke (Contributor) commented Mar 3, 2017:

So sparse tensors now have nDimensionI sparse dimensions, followed by nDimensionV dense dimensions, right? That's nice. Were the sparse tensors we had previously equivalent to nDimensionV == 1 or nDimensionV == 0?

@martinraison (Contributor Author) commented Mar 3, 2017:

That's exactly right. The sparse tensors we had previously were equivalent to nDimensionV == 0 (which I think makes more sense than nDimensionV == 1)

@martinraison (Contributor Author):

@apaszke the test test_variable_sequence_cuda was failing on my branch because x.grad never gets initialized in the CUDA case (the CPU case works fine). I debugged it on master, and it looks like x.grad remains zero during the backward pass with CUDA (not with CPU), so something seems off with the test. I commented it out on my branch, but it would be good to check why gradients aren't being computed correctly.

@apaszke (Contributor) commented Mar 3, 2017:

I found the issue and it's because of this line. On CUDA x is not a leaf, so its grad will never be initialized. Can you please fix that test too?

@apaszke merged commit f17cfe4 into pytorch:master on Mar 3, 2017
@apaszke (Contributor) commented Mar 3, 2017:

Phew, that was a big diff! Thanks! 🎉

@martinraison (Contributor Author):

Thanks for taking the time to review all this :)

@martinraison deleted the sparse branch on March 9, 2017