
Conversation

@ngimel (Collaborator) commented Oct 12, 2018

This PR gets rid of an unnecessary copy of weight gradients in the cudnn RNN. It also removes an unnecessary check on input size when deciding whether to use the persistent RNN algorithm, and adds a doc string explaining when the persistent RNN can be used. cc @ezyang

@soumith (Contributor) commented Oct 12, 2018

test failures are unrelated

@facebook-github-bot (Contributor) left a comment

soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@houseroad (Member) commented

@pytorchbot retest this please

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 12, 2018
Summary:
This PR gets rid of an unnecessary copy of weight gradients in the cudnn RNN. It also removes an unnecessary check on input size when deciding whether to use the persistent RNN algorithm, and adds a doc string explaining when the persistent RNN can be used. cc ezyang
Pull Request resolved: pytorch/pytorch#12600

Differential Revision: D10359981

Pulled By: soumith

fbshipit-source-id: 0fce11b527d543fabf21e6e9213fb2879853d7fb
if (copy) {
  // Copy values from the source parameter into the destination slot, viewing the
  // source as the destination's shape in case the two shapes differ.
  param_to.copy_(param_from.view_as(param_to));
} else {
  // No data copy: only adjust the source parameter's shape to match the destination.
  param_from.resize_as_(param_to);
}
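
For intuition, here is a minimal Python-level sketch (my own illustration, not the ATen code above) of the distinction the `copy` flag controls: copying a gradient out of a flat buffer versus exposing it as a view that aliases the buffer.

```python
import torch

# Hypothetical stand-in for the flat buffer cudnn writes all weight gradients into.
flat_grad_buf = torch.randn(20)

# Copy semantics: the per-parameter gradient gets its own storage (one extra read/write).
grad_copy = torch.empty(4, 5)
grad_copy.copy_(flat_grad_buf.view_as(grad_copy))

# View semantics: the per-parameter gradient aliases the flat buffer, so no copy is made.
grad_view = flat_grad_buf.view(4, 5)
assert grad_view.data_ptr() == flat_grad_buf.data_ptr()
```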

3) input data has dtype ``torch.float16``,
4) a V100 GPU is used,
5) input data is not in ``PackedSequence`` format,
the persistent algorithm can be selected to improve performance.
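
As a rough usage example (sizes and module choice are mine, not from the PR), an LSTM call satisfying the quoted conditions on a suitable GPU might look like:

```python
import torch

# Hypothetical example meeting the quoted conditions: float16 input on the GPU,
# plain tensor input (not a PackedSequence). Requires a CUDA build with cudnn.
rnn = torch.nn.LSTM(input_size=256, hidden_size=256, num_layers=2).cuda().half()
x = torch.randn(50, 32, 256, device="cuda", dtype=torch.float16)  # (seq, batch, feature)
output, (h_n, c_n) = rnn(x)
```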

@ngimel deleted the rnn_copies branch January 16, 2019 19:51
facebook-github-bot pushed a commit that referenced this pull request Aug 18, 2020
Summary:
Should close #36428.

The cudnn RNN API expects weights to occupy a flat buffer in memory with a particular layout.  This PR implements a "speed of light" fix:  [`_cudnn_rnn_cast_reflatten`](https://github.com/pytorch/pytorch/pull/42385/files#diff-9ef93b6a4fb5a06a37c562b83737ac6aR327) (the autocast wrapper assigned to `_cudnn_rnn`) copies weights to the right slices of a flat FP16 buffer with a single read/write per weight (as opposed to casting them to FP16 individually then reflattening the individual FP16 weights, which would require 2 read/writes per weight).
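
For intuition only, a hedged Python-level sketch of the two strategies described above (the function names, shapes, and pure-Python loops are mine; the actual change lives in the C++ autocast binding):

```python
import torch

def cast_then_reflatten(weights):
    # Two reads/writes per weight: first cast each weight to FP16, then copy the
    # resulting FP16 tensors into the flat buffer cudnn expects.
    halves = [w.half() for w in weights]                       # read/write #1 per weight
    flat = torch.empty(sum(w.numel() for w in weights), dtype=torch.float16)
    offset = 0
    for h in halves:
        flat[offset:offset + h.numel()].copy_(h.reshape(-1))   # read/write #2 per weight
        offset += h.numel()
    return flat

def cast_into_flat_buffer(weights):
    # One read/write per weight: copy_ casts FP32 -> FP16 while writing directly
    # into the destination slice of the flat buffer.
    flat = torch.empty(sum(w.numel() for w in weights), dtype=torch.float16)
    offset = 0
    for w in weights:
        flat[offset:offset + w.numel()].copy_(w.reshape(-1))   # fused cast + copy
        offset += w.numel()
    return flat

weights = [torch.randn(4, 3), torch.randn(4)]
assert torch.equal(cast_then_reflatten(weights), cast_into_flat_buffer(weights))
```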

It isn't pretty but IMO it doesn't make rnn bindings much more tortuous than they already are.

The [test](https://github.com/pytorch/pytorch/pull/42385/files#diff-e68a7bc6ba14f212e5e7eb3727394b40R2683) tries a forward under autocast and a backward for the full cross product of RNN options and input/weight/hidden dtypes.  As with all FP16-list autocast tests, forward output and backward grads are checked against a control where inputs (including RNN module weights in this case) are pre-cast to FP16 on the Python side.
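
A hedged sketch of that comparison pattern (the module, sizes, and tolerances are mine, not the actual test):

```python
import copy
import torch

# Hypothetical sketch: forward/backward under autocast vs. a control whose module
# weights and inputs were pre-cast to FP16 by hand. Requires a CUDA build with cudnn.
rnn = torch.nn.LSTM(input_size=16, hidden_size=16).cuda()
x = torch.randn(8, 4, 16, device="cuda", requires_grad=True)

with torch.cuda.amp.autocast():
    out_autocast, _ = rnn(x)       # cudnn RNN runs in FP16 via the autocast wrapper
out_autocast.sum().backward()

control_rnn = copy.deepcopy(rnn).half()
control_rnn.flatten_parameters()   # re-pack the copied FP16 weights contiguously
x_control = x.detach().clone().half().requires_grad_()
out_control, _ = control_rnn(x_control)
out_control.sum().backward()

assert torch.allclose(out_autocast.float(), out_control.float(), atol=1e-3)
assert torch.allclose(x.grad.float(), x_control.grad.float(), atol=1e-3)
```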

Not sure who to ask for review, tagging ezyang and ngimel because Ed wrote this file (almost 2 years ago) and Natalia did the most recent major [surgery](#12600).

Side quests discovered:
- Should we update [persistent RNN heuristics](https://github.com/pytorch/pytorch/blob/dbdd28207c5cf6c4a35ceb1de0811c4812e8882c/aten/src/ATen/native/cudnn/RNN.cpp#L584) to include compute capability 8.0 (see the capability-check sketch after this list)?  Could be another PR but seems easy enough to include.
- Many (maybe all?!) of the raw cudnn API calls in [RNN.cpp](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp) are deprecated in cudnn 8.  I don't mind taking the action item to update them since my mental cache is full of RNN stuff, but that would be a substantial separate PR.
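
A hedged sketch of the capability check mentioned in the first bullet (this only shows how the capability would be queried from Python; it is not the RNN.cpp heuristic itself):

```python
import torch

# Hedged sketch: a compute capability 8.0 update to the heuristic would key on the
# device's (major, minor) capability, which can be queried like this.
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
is_volta = major == 7            # the V100-class devices the current doc string mentions
is_ampere_or_newer = major >= 8  # what a compute capability 8.0 check would add
print(f"sm_{major}{minor}: volta={is_volta}, ampere_or_newer={is_ampere_or_newer}")
```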

Pull Request resolved: #42385

Reviewed By: zhangguanheng66

Differential Revision: D23077782

Pulled By: ezyang

fbshipit-source-id: a2afb1bdab33ba0442879a703df13dc87f03ec2e