Conversation

@mruberry (Collaborator) commented Jun 12, 2018

This PR addresses issue #7601.

Currently models that use streams explicitly in forward have to do a lot of extra work to make backwards respect those streams. This PR extends the (recently added) input tracing (see TypeAndShape) to record the devices and streams of inputs. The autograd engine then uses this metadata to enact the expected stream parallelism without extra work from the user.

For example, a model with forward declared like (original example courtesy of @ngimel):

def forward(self, x):
    x0 = x.clone()
    torch._C._cuda_setStream(self.stream1._cdata)
    y0 = self.fc1(x0)
    self.event1.record(stream=torch.cuda.current_stream())

    torch._C._cuda_setStream(self.stream2._cdata)
    y1 = self.fc2(x)
    self.event2.record(stream=torch.cuda.current_stream())
    self.stream2.wait_event(self.event1)
    return y0 + y1

currently will backward on a single stream. With this change the kernels will go on the streams they are assigned in forward and both forward and backward will (for appropriate sizes) run the fc1 and fc2 kernels simultaneously.
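
For comparison, here is a rough sketch of the same forward written against the public stream API rather than the internal `torch._C._cuda_setStream` binding. This is only an illustration, not code from this PR; it assumes `self.stream1`, `self.stream2`, and `self.event1` were created in `__init__` as `torch.cuda.Stream()` and `torch.cuda.Event()` objects, and, like the original example, it does not synchronize the side streams with the stream that produced x.

```
def forward(self, x):
    x0 = x.clone()
    # Run fc1 on stream1 and mark its completion with an event.
    with torch.cuda.stream(self.stream1):
        y0 = self.fc1(x0)
        self.event1.record()  # records on the current stream (stream1)
    # Run fc2 on stream2, then make stream2 wait for fc1 before combining.
    with torch.cuda.stream(self.stream2):
        y1 = self.fc2(x)
        self.stream2.wait_event(self.event1)
        return y0 + y1
```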

The crux of this change is, as mentioned, an expansion of the TypeAndShape tracing and a relatively simple change to the autograd engine to use cuda events for stream synchronization. To make this efficient I also added a new AutoGPUAndStream class, exposed getting and setting streams on devices, and removed InputBuffer's AutoGPU (it's now redundant). While making these modifications I also fixed AutoGPU to check before setting the GPU when it's destroyed and to use THCudaCheck instead of its custom error handler. These changes mean that an often excessive cudaSetDevice() is not being called when inputs are added to a buffer.

In addition to allowing users to easily set and use streams that are respected in both forward and backward, this change may encourage modules to do the same and the expanded tracing might allow further optimizations in the autograd engine. (@apaszke, for example, now after initial enumeration we know the number of devices that will be used by a graph task, which might help provide a sense of the "level of parallelism" we should expect.)

@mruberry (Collaborator, Author) commented:

The failures appear relevant and I will diagnose tomorrow.

@mruberry (Collaborator, Author) commented:

This PR is now ready for review.

The prior failures were caused by a module's gradient accumulator metadata being out of date and triggering an assertion.

Metadata being out of date is not a new issue. Calling set_data on a Variable, for instance, will reset that variable's grad_accumulator_ since its metadata may be out of date. To resolve this particular issue I added an additional check to Variable::Impl::get_grad_accumulator(). This check is only for device metadata, and it's possible that the other metadata staleness checks could be merged into this path, too. In general the connections between functions, edges, and variables appear to do a poor job of reusing code paths, however, and a more general cleanup and consolidation may be warranted (and separate from this PR).

@apaszke (Contributor) left a comment:

I'm a bit ambivalent about this PR:

  1. The engine complexity grows by a lot, and the amount of logic devoted to handling streams might make us even more CPU-bound.
  2. I'm also worried about recording more and more metadata for every single place in the graph. It's a lot.
  3. If we want to use backward streams somehow, I don't think that replaying them at a granularity of every op is a good idea.
  4. This will scale very poorly to other backends like ROCm.

Do you have any benchmarks?

@mruberry (Collaborator, Author) commented Jun 13, 2018

Great feedback. Let me try to break it down here:

  1. Code complexity: let me pretty up the code per your suggestions and I think it'll be much clearer and simpler. Shouldn't be an issue afterwards.

  2. Recording metadata: we are adding at most 128 CPU bytes to an existing tracing structure, so the CPU cost and memory overhead should be minimal. As an aside, python_function.cpp already has a VariableInfo struct that records type, device, size, and whether the variable requires grad, and this is a comparable expansion of the analogous struct for Functions (a rough sketch of the kind of per-input record involved appears after this list).

  3. Backwards streaming design (generally). While there are always better ways to stream backwards, I think this is a natural place to plug in streaming with no downside and considerable upside. This PR means that PyTorch users simply stream in forward as they wish and reap the benefits again in backward. While it may seem like we're hammering the stream every op, in practice we're usually just setting the same stream across multiple ops, and getting and setting streams is fast. The per-op tracing is also helpful for future PRs (I have one planned that I expect will make use of it), and I think it'll be interesting to let engineers experiment with multistream modules. cuDNN, for example, makes heavy use of streams, and this change will allow engineers working only in Python to exploit that feature.

  4. Performance. Remember that I removed an (almost always) excessive setDevice on every call of evaluateFunction. setDevice is EXPENSIVE. Everything we're doing on the CPU is very fast, in contrast.
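
To make the metadata discussion in item 2 concrete, here is a minimal Python sketch of the kind of per-input record being described. The names and field types are illustrative only; the actual tracing lives in a C++ struct inside the autograd engine, not in Python.

```
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class InputTrace:
    """Illustrative stand-in for the per-input metadata discussed above."""
    dtype: torch.dtype                   # element type of the input
    shape: Tuple[int, ...]               # sizes, e.g. for building grads of the right shape
    device: torch.device                 # device the input lived on during forward
    stream: Optional[torch.cuda.Stream]  # CUDA stream that produced it (None for CPU inputs)

# In spirit, the engine records one such entry per function input during forward
# and replays the device/stream when it later runs that function's backward.
```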

All that said, I do actually have a follow-up PR to address CPU latency and simplify the autograd engine code, but I'm loath to complicate this PR by including it (it's a large and separable piece of work) and don't want to go down the rabbit hole of discussing it too much. While I think this PR is perf neutral, my follow-up actually does give a performance improvement.

Let me take your code suggestions and gather some appropriate numbers and I think things will look much improved.

@mruberry (Collaborator, Author) commented Jun 14, 2018

I incorporated your feedback and I think this latest version is significantly more streamlined.

As for timing, I have tested the mnist, word_language, and time_series examples and there does not appear to be any regression from master in the timings. On the simple example I described above, the timing difference vs. running without streams is what one might expect. If tensors are too small, CPU latency prevents overlap. If tensors are too large, the GPU cannot overlap them (much) even though they are on different streams. There is a goldilocks zone, however. When x is a 2^11 x 2^11 tensor there is nontrivial overlap: 200 reps of the model take about 3.8s without streaming and 3.6s with streaming, an improvement of about 5% end to end. The actual backwards pass speeds up by 9% in this case.

A 9% backwards speedup is nontrivial, and I'm sure more elaborate examples could be constructed. This example is about as simple, straightforward, and un-contrived as it gets, however, short of writing some new RNN modules (although I do look forward to doing that).
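
For context, a comparison like the one described can be reproduced with a simple wall-clock harness along these lines (my own sketch, not the harness behind the numbers quoted above; `streamed_model` and `unstreamed_model` are assumed to be instances of the module from the PR description with and without the stream calls):

```
import time

import torch


def time_forward_backward(model, x, reps=200):
    """Wall-clock seconds for `reps` forward+backward passes."""
    torch.cuda.synchronize()  # drain any prior work on all streams
    t0 = time.time()
    for _ in range(reps):
        out = model(x)
        out.sum().backward()
    torch.cuda.synchronize()  # wait for every stream before stopping the clock
    return time.time() - t0


# x = torch.randn(2 ** 11, 2 ** 11, device="cuda", requires_grad=True)
# print(time_forward_backward(streamed_model, x), time_forward_backward(unstreamed_model, x))
```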

@mruberry (Collaborator, Author) commented:

I'm not familiar with Caffe2 but the failure there seems unrelated:

04:30:53 [ RUN ] NetTest.ChainingForForkJoin
04:30:53 I0614 04:30:53.294577 253 net_dag_utils.cc:102] Operator graph pruning prior to chain compute took: 3.584e-06 secs
04:30:53 I0614 04:30:53.294605 253 net_async_base.cc:417] Using specified CPU pool size: 4; NUMA node id: -1
04:30:53 I0614 04:30:53.294608 253 net_async_base.cc:422] Created new CPU pool, size: 4; NUMA node id: -1
04:30:53 E0614 04:30:53.295459 265 net_async_base.cc:345] [enforce fail at event.cc:93] wrapper->status_ == EventStatus::EVENT_INITIALIZED || wrapper->status_ == EventStatus::EVENT_SCHEDULED. Calling SetFinished on finished event , op NetTestDummy
04:30:53 E0614 04:30:53.295459 263 net_async_base.cc:345] [enforce fail at event.cc:93] wrapper->status_ == EventStatus::EVENT_INITIALIZED || wrapper->status_ == EventStatus::EVENT_SCHEDULED. Calling SetFinished on finished event , op NetTestDummy
04:30:53 unknown file: Failure
04:30:53 C++ exception with description "[enforce fail at event.cc:93] wrapper->status_ == EventStatus::EVENT_INITIALIZED || wrapper->status_ == EventStatus::EVENT_SCHEDULED. Calling SetFinished on finished event " thrown in the test body.
04:30:53 [ FAILED ] NetTest.ChainingForForkJoin (1 ms)

@mruberry (Collaborator, Author) commented Jun 15, 2018

Updated to use USE_CUDA instead of WITH_CUDA, merged with master.

The Python 2.7 lint failure appears unrelated.

@mruberry (Collaborator, Author) commented Jun 16, 2018

Additional changes in master to AutoGPU are easy to incorporate and have no material effect on the PR. Waiting for review before making any more changes.

@mruberry (Collaborator, Author) commented:

@apaszke Any word on this? I was hoping to get started on some follow-ups.

@apaszke (Contributor) commented Jun 19, 2018

Yes, sorry to keep you waiting. I promise I'll review it today (give me ~2h).

@apaszke (Contributor) left a comment:

cc @colesbury for opinions on extending input metadata, and on adding more code in the critical path of the engine. Maybe you have a better idea for stream management in backward.

Also cc @ezyang for an explanation of what we should use to manage streams (is THCStream still OK, or are we moving to something else?).


I'm still unsure if the added complexity and code in the hot path is worth it. Can you share what kinds of use cases you have in mind exactly, or at least what kinds of speedups you see from replaying stream usage in backwards?

@mruberry (Collaborator, Author) commented Jun 19, 2018

These latest refinements sound smart and also like the appropriate way to merge with master.

The only open issue on the code itself is, I think, what invariant we're establishing. Please review what I wrote and see if it fits with your thinking.

I appreciate the general concern about performance, but I think the CPU latency impact of these changes is very low. Again, the tracing is analogous to what's already in python_function (see VariableInfo). With your suggested changes we're only adding a few conditionals and a couple of set-stream calls per function in backwards (these are fast and just hit thread-local storage; we make the same number of get/set GPU calls as current master), and we're adding a couple of conditionals, an increment, and 128 CPU bytes in forward.

As for benefits, as mentioned, even for the simple case of parallel linear layers there is a 9% backwards speedup when streaming at "goldilocks" sizes. What I am really looking forward to, however, is using a stream-aware backwards to implement "fast and flexible RNNs." See (longstanding) issue #711. To get cuDNN-like speeds we're going to have to stream forward, and we'll then want to stream backward (see Appleyard). If we have to write custom code to stream that backwards pass, that'll (1) be a pain and (2) limit the flexibility.

There are additional independent opportunities for the tracing, too. For example, the autograd engine could be updated so that specialized code like CopyBackwards is not needed (I actually think this would be a really nice change). Currently that can't be done because a function has no notion of what device it expects inputs on, so we don't know whether we need to transfer a tensor or not. It also means that synchronization points can be identified upfront in the initial enumeration of the graph. This latter point is extremely interesting but, unfortunately, the current potential upside there is limited by support for reentrancy.

Now if we're really concerned about CPU latency vs. feature value, let's talk about reentrancy. If the initial graph enumeration were actually the complete graph, then we could use this tracing knowledge to enqueue and encode all synchronization points in advance, getting rid of the current mutex structure and the ref counting done in backward(). That would likely be a significant speedup. If you like, I'll commit to a PR that adds an "allow_reentrancy" flag to .backward() (true by default) that, when set to false, uses this alternative fast path. Callbacks, by the way, are also "broken" for reentrant backwards because they're not associated with any GraphTask (and there is no mechanism for them to be), and calling backward() from within a backward() will clear them. Ensuring only one backward() happens at a time seems pretty natural (and we could add a torch.autograd.backward() that takes a list of tensors for multiple simultaneous backward() calls, too). Largely a separate issue, I admit, but my point is that tracing is potentially very useful for reducing the exact issue we're concerned it could exacerbate.

(Disallowing reentrancy would also allow us to easily reuse the main thread in backward, btw.)

@mruberry force-pushed the stream_respect branch 2 times, most recently from 6af70a0 to 2200575, on June 20, 2018.
@mruberry (Collaborator, Author) commented Jun 21, 2018

I incorporated the feedback and merged with master. Between taking advantage of changes in master and @apaszke's suggestions, I think the code is now in excellent shape. In particular:

(1) I reverted any changes to auto_gpu.h since it had been updated in master (to DeviceGuard)
(2) I incorporated a DeviceGuard in AutoGPUStream and extended AutoGPUStream to allow calling set()
(3) The engine.cpp code was further simplified by an update to input_buffer adding a method that returns an index to its first valid cuda tensor
(4) I updated the name of CUDAEvent
(5) I merged the existing out of date metadata check in variable.cpp's set_data with the check I added when attempting to acquire a variable's gradient accumulator
(6) I simplified the code for checking for out of date metadata by adding a new method to InputMetadata

There are two failures in the CI and both appear unrelated:

(1) The failure on caffe2-py2-gcc4.8-ubuntu14.04-test appears to be an unrelated timeout. It is also happening in other PRs.
(2) The failure on pytorch-macos-10.13-py3 also appears to be an unrelated complaint about data loader. I don't believe the failing test touches any of the changed code, even. This build also appears to be failing in other PRs.

@colesbury (Member) commented:

I like the idea behind this PR and the code is clean and well written. I think it needs tests for the stream behavior. Here's one test that does not pass:

https://gist.github.com/colesbury/a5b5540329b267cfeb76f7349f40881b

However this passes:

https://gist.github.com/colesbury/fe870eca7c914cb94349fabb1870465c

There seems to be a missing cudaStreamWaitEvent. I think x.grad should be available in the same stream as x.

@mruberry (Collaborator, Author) commented Jun 21, 2018

@colesbury Thanks for taking such a deep dive. I think you're proposing a smart design change but let's verify we're on the same page.

First, totally agree on the testing. I'll add a few as well that we can look at.

Now for design. Currently in master if you're working on your own stream and call backward() then you are responsible for syncing with the streams backward() uses (the default stream on each device).

Currently in this PR the same is true, except backward() is no longer always using the default stream. In particular, if you move the default_stream.wait_stream(stream) in the failing test to after output.sum().backward() it will pass. You don't need to wait after the apply().

In this way the PR is more natural as backward() reuses the streams you used in forward and you don't need to worry about synchronizing with streams you may not have thought you were using (the default streams).

What may be surprising, however, is the particular relationship between a tensor created in one stream and backward(). I think what you're saying is that users expect this tensor's gradient to be available to the stream it was created in without additional synchronization? That's totally reasonable. However, it will still mean that if you create a tensor on Stream X, then forward/backward on Stream Y, you need to sync Y with X after backward if you use Y to access the gradient. (Edit: actually this will probably just work. You only need to sync Y with X if Y is not part of X's dependency chain.)

I like this idea a lot because it ensures consistency in what needs to be synced with. So, just to be clear, this change will not eliminate the need to sync (which exists today in an even worse way). But it will eliminate the need to sync in these tests because backward() will naturally sync with the stream the tensor was created on.

Cool? I'll start working on the fix now.
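
To make the semantics under discussion concrete, here is a hedged sketch of the conservative usage pattern using today's public stream API. The explicit waits are shown for safety; the point of this thread is which of them backward() should make implicit. This is an illustration, not a test from the PR.

```
import torch

side = torch.cuda.Stream()

# x is created on the current (default) stream.
x = torch.randn(1024, 1024, device="cuda", requires_grad=True)

with torch.cuda.stream(side):
    side.wait_stream(torch.cuda.default_stream())  # make sure x is ready on `side`
    out = (x * 2).sum()
    out.backward()  # with this PR, backward reuses the streams used in forward

# Conservative: have the consuming stream wait on the producing stream
# before reading the gradient.
torch.cuda.default_stream().wait_stream(side)
print(x.grad.sum())
```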

@apaszke (Contributor) left a comment:

Looks nice now! I have a few final observations/opportunities for simplification, and I'm curious why we changed the GradAccumulator invalidation strategy; the older one seemed more efficient.

@yf225 (Contributor) commented Jul 10, 2018

@mruberry Do you mind rebasing this PR onto current master? Thanks!

@mruberry (Collaborator, Author) commented:

@yf225 Yep, working on that now. The rebase depends on another PR that just went in on Sunday, and I'm also changing the functionality per the review with @colesbury, so it may take a day.

@zdevito (Contributor) left a comment:

The code quality seems good. I'd like someone more familiar with the task logic and the input_metadata to approve though. It is not clear to me whether this is right or wrong.

@ezyang (Contributor) commented Sep 5, 2019

> If the latter streams are different from the former, however, then they previously had to sync with the default streams to run properly. The default sync keeps these (vanishingly small) networks running properly.

OK, this explains something I was a bit puzzled about: why you are using the default stream specifically. But now I have the following question: suppose I have changed my local stream, and I call x.backward(). Shouldn't backward synchronize with the non-default stream that was active at the backward() call site, not the default stream?

@mruberry (Collaborator, Author) commented Sep 5, 2019

> If the latter streams are different from the former, however, then they previously had to sync with the default streams to run properly. The default sync keeps these (vanishingly small) networks running properly.
>
> OK, this explains something I was a bit puzzled about: why you are using the default stream specifically. But now I have the following question: suppose I have changed my local stream, and I call x.backward(). Shouldn't backward synchronize with the non-default stream that was active at the backward() call site, not the default stream?

YES! Actually it should do both: sync with the default streams for back compat (Sam's point) and sync with the current streams because it's expected (like how PyTorch ops that use other streams should sync with the current streams). I was actually planning to put that feature in the follow-up PR that enables the non-default stream test flag. I could add it to this PR if you like, though?
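
Illustratively, the question is about this call pattern (a sketch only; `model` is assumed to be any CUDA module, and the explicit wait at the end is the conservative choice while the implicit syncs described above are still being settled):

```
import torch

s = torch.cuda.Stream()
x = torch.randn(1024, 1024, device="cuda", requires_grad=True)

with torch.cuda.stream(s):
    s.wait_stream(torch.cuda.default_stream())  # x was produced on the default stream
    loss = model(x).sum()
    loss.backward()  # should this sync with `s` (current at the call site),
                     # with the default streams, or with both?

torch.cuda.current_stream().wait_stream(s)  # explicit sync before consuming gradients
```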

@ezyang (Contributor) commented Sep 5, 2019

> Sync with the default streams for back compat (Sam's point)

I missed that! I guess eventually someone is going to ask for the ability to turn this sync off, but doesn't have to be this diff.

> I was actually planning to put that feature in the follow-up PR that enables the non-default stream test flag. I could add it to this PR if you like, though?

It's up to you. This PR is not too big yet so if it's not blocking anything, feel free to roll in the changes here. But this PR has also been cooking for a long time, so I'm also inclined to move it along.

@mruberry (Collaborator, Author) commented Sep 6, 2019

The pr/py2-clang7-rocmdeb-ubuntu16.04 failure is unrelated and is happening on other submissions (see https://ci.pytorch.org/jenkins/job/pytorch-builds/job/py2-clang7-rocmdeb-ubuntu16.04-test/36541/console).

The caffe2_onnx_py2_gcc5_ubuntu16_04_test failures also appear unrelated.

@li-roy (Contributor) commented Sep 6, 2019

@pytorchbot retest this please

@pytorchbot added the module: cpu and module: operators labels on Sep 6, 2019.
@ezyang (Contributor) commented Sep 9, 2019

@pytorchbot rebase this please

@facebook-github-bot (Contributor) commented:
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry deleted the stream_respect branch on September 11, 2019 06:55.
zdevito pushed a commit to zdevito/ATen that referenced this pull request Sep 11, 2019
Summary:
This PR addresses issue pytorch/pytorch#7601.

Currently models that use streams explicitly in forward have to do a lot of extra work to make backwards respect those streams. This PR extends the (recently added) input tracing (see TypeAndShape) to record the devices and streams of inputs. The autograd engine then uses this metadata to enact the expected stream parallelism without extra work from the user.

For example, a model with forward declared like (original example courtesy of ngimel):

```
def forward(self,x):
        x0 = x.clone()
        torch._C._cuda_setStream(self.stream1._cdata)
        y0 = self.fc1(x0)
        self.event1.record(stream = torch.cuda.current_stream())

        torch._C._cuda_setStream(self.stream2._cdata)
        y1 = self.fc2(x)
        self.event2.record(stream = torch.cuda.current_stream())
        self.stream2.wait_event(self.event1)
        return y0 + y1
```

currently will backward on a single stream. With this change the kernels will go on the streams they are assigned in forward and both forward and backward will (for appropriate sizes) run the fc1 and fc2 kernels simultaneously.

The crux of this change is, as mentioned, an expansion of the TypeAndShape tracing and a relatively simple change to the autograd engine to use cuda events for stream synchronization. To make this efficient I also added a new AutoGPUAndStream class, exposed getting and setting streams on devices, and removed InputBuffer's AutoGPU (it's now redundant). While making these modifications I also fixed AutoGPU to check before setting the GPU when it's destroyed and to use THCudaCheck instead of its custom error handler. These changes mean that an often excessive cudaSetDevice() is not being called when inputs are added to a buffer.

In addition to allowing users to easily set and use streams that are respected in both forward and backward, this change may encourage modules to do the same and the expanded tracing might allow further optimizations in the autograd engine. (apaszke, for example, now after initial enumeration we know the number of devices that will be used by a graph task, which might help provide a sense of the "level of parallelism" we should expect.)
Pull Request resolved: pytorch/pytorch#8354

Test Plan: Two tests were added specifically for this behavior.

Differential Revision: D17275980

Pulled By: mruberry

fbshipit-source-id: 92bd50ac782ffa973b159fcbbadb7a083802e45d
Labels: module: autograd, module: cpu, module: cuda, module: internals, module: rocm, open source
