Conversation

@awgu (Collaborator) commented Feb 28, 2022

Stack from ghstack:

Overview

  • This adds FSDP gradient accumulation without no_sync(), which requires more network bandwidth but less GPU memory per worker than accumulating inside no_sync() (see the usage sketch below).
  • This fixes a bug in the no_sync() testing, where the CPU offloading and backward prefetch arguments were not propagating to the FullyShardedDataParallel constructor.
  • This adds p_assert() (taken from Fairscale), which prints the assert error message before raising the AssertionError. It is meant to be used when running in the autograd backward context, since otherwise the error message is swallowed, giving an unhelpful error like:
<built-in method run_backward of torch._C._EngineBase object at 0x7f1fd518dc80> returned NULL without setting an error

NOTE: Gradient accumulation without no_sync() is not currently compatible with CPU offloading.
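
For context, here is a minimal usage sketch (mine, not code from this PR) contrasting the two accumulation patterns; `fsdp_model`, `optim`, `batches`, and `loss_fn` are placeholders:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def accumulate_inside_no_sync(fsdp_model: FSDP, optim, batches, loss_fn):
    # Gradients accumulate locally in the unsharded p.grad and are only
    # reduce-scattered on the final backward: less communication, more memory.
    with fsdp_model.no_sync():
        for batch in batches[:-1]:
            loss_fn(fsdp_model(batch)).backward()
    loss_fn(fsdp_model(batches[-1])).backward()  # gradients synchronize here
    optim.step()
    optim.zero_grad()

def accumulate_outside_no_sync(fsdp_model: FSDP, optim, batches, loss_fn):
    # Every backward reduce-scatters, and FSDP accumulates the resulting
    # sharded gradients across iterations: more communication, less memory.
    for batch in batches:
        loss_fn(fsdp_model(batch)).backward()
    optim.step()
    optim.zero_grad()
```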

Test Plan
I augmented the tests to cover gradient accumulation that interleaves iterations accumulating with and without no_sync().

Differential Revision: D34533546

@pytorch-bot (bot) commented Feb 28, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/22cfb1fac06126a6da784e002486b12b189b51d7/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

| Workflow | Labels (bold enabled) | Status |
| --- | --- | --- |
| **Triggered Workflows** | | |
| linux-binary-conda | ciflow/binaries, ciflow/binaries_conda, ciflow/default | ✅ triggered |
| linux-binary-libtorch-cxx11-abi | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| linux-binary-libtorch-pre-cxx11 | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| linux-binary-manywheel | ciflow/binaries, ciflow/binaries_wheel, ciflow/default | ✅ triggered |
| linux-bionic-py3.7-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk | ✅ triggered |
| linux-bionic-rocm4.5-py3.7 | ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk | ✅ triggered |
| linux-docs | ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-vulkan-bionic-py3.7-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan | ✅ triggered |
| linux-xenial-cuda11.3-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-cuda11.3-py3.7-gcc7-bazel-test | ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-py3-clang5-mobile-build | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk | ✅ triggered |
| linux-xenial-py3-clang5-mobile-custom-build-static | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-clang7-asan | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-clang7-onnx | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-gcc7 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-gcc7-no-ops | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| macos-arm64-binary-conda | ciflow/binaries, ciflow/binaries_conda, ciflow/default | ✅ triggered |
| macos-arm64-binary-wheel | ciflow/binaries, ciflow/binaries_wheel, ciflow/default | ✅ triggered |
| macos-binary-conda | ciflow/binaries, ciflow/binaries_conda, ciflow/default | ✅ triggered |
| macos-binary-libtorch-cxx11-abi | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| macos-binary-libtorch-pre-cxx11 | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| macos-binary-wheel | ciflow/binaries, ciflow/binaries_wheel, ciflow/default | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| win-vs2019-cpu-py3 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win | ✅ triggered |
| win-vs2019-cuda11.3-py3 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win | ✅ triggered |
| windows-binary-libtorch-cxx11-abi | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| windows-binary-libtorch-pre-cxx11 | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| windows-binary-wheel | ciflow/binaries, ciflow/binaries_wheel, ciflow/default | ✅ triggered |
| **Skipped Workflows** | | |
| caffe2-linux-xenial-py3.7-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk | 🚫 skipped |
| docker-builds | ciflow/all, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-arm64 | ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled | 🚫 skipped |
| ios-12-5-1-arm64-coreml | ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled | 🚫 skipped |
| ios-12-5-1-arm64-custom-ops | ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled | 🚫 skipped |
| ios-12-5-1-arm64-metal | ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled | 🚫 skipped |
| ios-12-5-1-x86-64 | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-x86-64-coreml | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| libtorch-linux-xenial-cuda10.2-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk | 🚫 skipped |
| libtorch-linux-xenial-cuda11.3-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk | 🚫 skipped |
| linux-bionic-cuda10.2-py3.9-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk | 🚫 skipped |
| linux-docs-push | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| linux-xenial-cuda11.3-py3.7-gcc7-no-ops | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk | 🚫 skipped |
| macos-10-15-py3-arm64 | ciflow/all, ciflow/macos, ciflow/trunk | 🚫 skipped |
| macos-10-15-py3-lite-interpreter-x86-64 | ciflow/all, ciflow/macos, ciflow/trunk | 🚫 skipped |
| macos-11-py3-x86-64 | ciflow/all, ciflow/macos, ciflow/trunk | 🚫 skipped |
| parallelnative-linux-xenial-py3.7-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk | 🚫 skipped |
| periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-bionic-cuda11.5-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck | 🚫 skipped |
| periodic-linux-xenial-cuda11.3-py3.7-gcc7-debug | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-win-vs2019-cuda11.5-py3 | ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win | 🚫 skipped |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build | ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk | 🚫 skipped |
| pytorch-xla-linux-bionic-py3.7-clang8 | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk, ciflow/xla | 🚫 skipped |

@facebook-github-bot (Contributor) commented Feb 28, 2022

💊 CI failures summary and remediations

As of commit ab03d0e (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.
Please report bugs/suggestions to the (internal) Dr. CI Users group.

@facebook-github-bot added the oncall: distributed label (Add this issue/PR to distributed oncall triage queue) on Feb 28, 2022
desertfire pushed a commit that referenced this pull request Feb 28, 2022
ghstack-source-id: d32c434
Pull Request resolved: #73535
desertfire pushed a commit that referenced this pull request Feb 28, 2022
ghstack-source-id: 8f046d4
Pull Request resolved: #73535

@awgu (Collaborator, Author) commented Feb 28, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@awgu marked this pull request as ready for review February 28, 2022 22:12
@rohan-varma (Contributor) left a comment

Thanks for turning this around so quickly! A couple of small comments, will stamp after those are addressed!

```python
    )
    param._saved_grad_shard.data += output.data  # type: ignore[attr-defined]
else:
    param._saved_grad_shard = output.data  # type: ignore[attr-defined]
```
Contributor:

are we switching from using output to output.data in non-grad accumulation use case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. This is good for me to clarify. Was there any reason to use output before?

In my understanding, it does not matter since we are not supporting taking the gradient of the gradient, so we do not need to record any operations on output in the autograd graph. Fairscale uses output.data in both code paths, though I do not see why either way is better for the non-gradient accumulation case.
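
As a quick standalone illustration (mine, not from the PR), operations routed through `.data` leave no autograd history either way:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x * 2              # tracked by autograd

a = y + 1              # recorded: a.grad_fn is not None
b = y.data + 1         # not recorded: b.grad_fn is None, b.requires_grad is False

print(a.grad_fn is not None, b.grad_fn is None, b.requires_grad)  # True True False
```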

```python
    [2, 4],
    "configs",
    [
        [_GradAccConfig(True, 4)],
```
Contributor:

nit: if you want, maybe pass in named args here, so future developers know what it is
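
For illustration, a hypothetical version with keyword arguments; the field names below are guesses, not the PR's actual `_GradAccConfig` fields:

```python
from dataclasses import dataclass

@dataclass
class _GradAccConfig:  # hypothetical stand-in; the real field names may differ
    use_no_sync: bool
    num_iters: int

configs = [
    [_GradAccConfig(use_no_sync=True, num_iters=4)],
    [_GradAccConfig(use_no_sync=False, num_iters=4)],
    [
        _GradAccConfig(use_no_sync=True, num_iters=2),
        _GradAccConfig(use_no_sync=False, num_iters=2),
        _GradAccConfig(use_no_sync=True, num_iters=2),
    ],
]
```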

```python
        [_GradAccConfig(True, 4)],
        [_GradAccConfig(False, 4)],
        [_GradAccConfig(True, 2), _GradAccConfig(False, 2), _GradAccConfig(True, 2)],
        [_GradAccConfig(False, 2), _GradAccConfig(True, 2), _GradAccConfig(False, 2)],
```
Contributor:

Wondering why we have duplicated configs?

@awgu (Collaborator, Author):

I wanted to test interleaving both ways:

  1. with no_sync() -> without no_sync() -> with no_sync()
  2. without no_sync() -> with no_sync() -> without no_sync()

The reason I wanted two separate tests is that it could matter which accumulation mode was used last right before gradient synchronization (see the sketch below).
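
A hedged sketch of the two interleavings (`fsdp_model`, `run_fwd_bwd`, and `batches` are illustrative placeholders, not names from the test):

```python
def interleave_ending_inside_no_sync(fsdp_model, run_fwd_bwd, batches):
    # with no_sync() -> without no_sync() -> with no_sync(); a final
    # synchronizing backward would still be needed before the optimizer step.
    with fsdp_model.no_sync():
        run_fwd_bwd(batches[0])
    run_fwd_bwd(batches[1])
    with fsdp_model.no_sync():
        run_fwd_bwd(batches[2])

def interleave_ending_outside_no_sync(fsdp_model, run_fwd_bwd, batches):
    # without no_sync() -> with no_sync() -> without no_sync(); the last
    # accumulation mode before the optimizer step is "without no_sync()".
    run_fwd_bwd(batches[0])
    with fsdp_model.no_sync():
        run_fwd_bwd(batches[1])
    run_fwd_bwd(batches[2])
```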

Contributor:

Good call! Let's keep it then and add a small comment so no one removes it in the future.

```python
    [2, 4],
    "configs",
    [
        [_GradAccConfig(True, 4)],
```
Contributor:

to reduce # of tests, can we just have:

use_context = true, interval={2,4}
use_context = false, interval={2,4}

@awgu (Collaborator, Author):

Given your comments last time, I thought about how to keep the set of tests minimal. I feel like these 4 configs test distinct things:

  1. Gradient accumulation with no_sync()
  2. Gradient accumulation without no_sync()
  3. Gradient accumulation interleaving no_sync() and without no_sync(), where the last iteration before synchronizing gradients is in no_sync()
  4. Gradient accumulation interleaving no_sync() and without no_sync(), where the last iteration before synchronizing gradients is outside no_sync()

My concern is that I do not want to overfit to the current working implementation. It is not obvious to me that 1) and 2) imply 3) and 4), and if we only had 3) and 4) and they broke, then we would probably end up rewriting 1) and 2).

Contributor:

Sounds good! Considering you've gone through the tradeoff, this makes sense to leave as is.

```python
# Average grad by world_size for consistency with PyTorch DDP.
output.div_(self.gradient_postdivide_factor)
param.grad.data = output
accumulate_grad = getattr(param, "_saved_grad_shard", None) is not None
```
Contributor:

Might be useful to add a small comment about how gradient accumulation without no_sync() is implemented. From my understanding (step 3 is sketched after the list):

  1. During backward, we point an attribute _saved_grad_shard to the gradient shard
  2. If we are accumulating gradients, we accumulate it on _saved_grad_shard
  3. When finalizing backwards before running the optimizer, we point p.grad to the saved grad shard so that the optimizer works on the right accumulated gradient.
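
A minimal sketch of that last step (illustrative names, not the PR's exact code):

```python
import torch

def finalize_param_grad(p: torch.nn.Parameter) -> None:
    # Illustrative only: hand the accumulated shard to the optimizer by
    # pointing p.grad at it (roughly step 3 above).
    if hasattr(p, "_saved_grad_shard"):
        p.grad = p._saved_grad_shard  # type: ignore[attr-defined]
        del p._saved_grad_shard
```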

```python
can_accumulate_grad = p.grad.device == p.data.device and \
    p.grad.size() == p._local_shard.shape  # type: ignore[attr-defined]
if can_accumulate_grad:
    p._saved_grad_shard = p.grad.data  # type: ignore[attr-defined]
```
Contributor:

add comment here that this is what makes the gradient accumulation work?

```python
    p.grad.size() != p._orig_size  # type: ignore[attr-defined]
    or p.grad.device != p.device
):
    can_accumulate_grad = p.grad.device == p.data.device and \
```
Contributor:

Can we have a unittest where can_accumulate_grad = False and the user tries to accumulate grads and we raise an appropriate error?

@awgu (Collaborator, Author):

After some investigation, I think that having a non-silent error for the case of gradient accumulation outside no_sync() while using CPU offloading requires some non-trivial re-design. It may be easiest to leave this error as silent for now and work on adding the compatibility itself.

(Some clarifying questions and comments)
Suppose we are using CPU offloading.

  • Outside no_sync(), if we want to accumulate gradients, are we performing the addition between the existing gradient and the newly-reduced gradient on CPU or on GPU? If on CPU, then we would need a device-to-host transfer of the reduced gradient. If on GPU, then we would need a host-to-device transfer of the existing gradient and a device-to-host transfer to re-offload the result. (Both options are sketched below.)
  • Inside no_sync(), the existing implementation does not offload any gradients to CPU. Rather, the gradients are held in GPU memory until the first iteration outside no_sync(), which performs the gradient synchronization. At the end of that iteration's backward pass, the synchronized gradient shard is offloaded to CPU. Should we include any warning or message at runtime to explain this behavior to the user? I added a note about it to the no_sync() docstring.
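
For concreteness, a standalone sketch (mine, not the PR's code; requires a CUDA device) of the two placements described in the first bullet:

```python
import torch

# Stand-ins: an offloaded gradient shard on CPU and a newly reduce-scattered
# shard on GPU (illustrative tensors only).
existing_cpu = torch.zeros(4)
reduced_gpu = torch.ones(4, device="cuda")

# Option A: accumulate on CPU -> one device-to-host copy of the reduced shard.
accumulated_cpu = existing_cpu + reduced_gpu.to("cpu")

# Option B: accumulate on GPU -> host-to-device copy of the existing shard,
# then a device-to-host copy to re-offload the result.
accumulated_cpu = (existing_cpu.to("cuda") + reduced_gpu).to("cpu")
```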

@awgu (Collaborator, Author):

The challenge behind a non-silent error is that the no_sync() + CPU offloading case conflicts with the non-no_sync() + CPU offloading case.

  • The crux is that accumulating gradients using no_sync() follows the pattern of: accumulate for N-1 iterations inside no_sync() and execute 1 normal iteration outside no_sync().
  • That final normal iteration is indistinguishable from a non-no_sync() iteration unless we track something like a bool flag indicating that the last iteration was inside no_sync().
    • I am reluctant to add such a flag since it is solely a patch for one case and may be hiding the underlying design problem, but I am open to your thoughts.
  • For the final iteration coming out of no_sync(), the gradient is still on GPU, so performing the accumulation computation param._saved_grad_shard.data += output.data has no issue.
  • For a non-no_sync() iteration (after the first), the gradient was previously offloaded to CPU, so performing the accumulation computation param._saved_grad_shard.data += output.data has conflicting devices.

Contributor:

When we are using gradient accumulation outside of no_sync + CPU offload, don't we already raise an appropriate error on L1336 of this PR? And is it possible to add a unittest for this?

@awgu (Collaborator, Author):

That is an internal assert. Given the current implementation, that assert will never get triggered. I added it because the logic is quite complicated, and I wanted to demonstrate that if we are in the non-no_sync() + CPU offloading case, then we should never be in that branch.

@awgu (Collaborator, Author):

(or more directly, the contrapositive: if we are in that branch, we should not be in the non-no_sync() + CPU offloading case, meaning that in particular we must not be CPU offloading)

@awgu (Collaborator, Author) commented Mar 3, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

desertfire pushed a commit that referenced this pull request Mar 3, 2022
ghstack-source-id: e550f57
Pull Request resolved: #73535

@awgu (Collaborator, Author) commented Mar 3, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

```python
# try to accumulate gradients. FSDP accumulates gradients in
# the separate variable `p._saved_grad_shard` to leave `p.grad`
# for the per-iteration gradient.
if prev_iter_outside_no_sync:
```
@awgu (Collaborator, Author):

Since this PR is getting a bit cluttered, I wanted to specifically point this part out. Previously, I presented the logic here incorrectly, but hopefully this should make sense now.

I think the precise condition of when to use p._saved_grad_shard is if the previous iteration was outside no_sync().

Suppose we have
(1) some iterations outside no_sync() ->
(2) some iterations inside no_sync() ->
(3) one iteration outside no_sync().

  • In the pre-backward hook of (3), the FSDP instance holds an unsharded gradient in p.grad, which is the result of accumulating gradients from (2).
  • It computes that iteration's gradient, which is accumulated with the existing p.grad from (2) via the autograd engine and still stored in p.grad.
  • In the post-backward hook, it reduce-scatters that accumulated gradient stored in p.grad.
  • After the reduce-scatter, it accumulates the accumulated gradient from (2) and (3) with the accumulated gradient from (1) saved in _saved_grad_shard. This "super-accumulated" gradient is stored in _saved_grad_shard.
    • This step shows why as long as the previous iteration was outside no_sync(), there may be a gradient to accumulate on the first future iteration also outside no_sync().

I had misunderstood the comments from Fairscale (see here). I did not realize that the conditioning on inside/outside no_sync() referred to the previous iteration.
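
In rough pseudocode, the branch described above looks like this (illustrative names, not the PR's exact code):

```python
import torch

def post_backward_accumulate(param: torch.nn.Parameter,
                             reduced_shard: torch.Tensor,
                             prev_iter_outside_no_sync: bool) -> None:
    # Illustrative only: called after reduce-scatter produced `reduced_shard`.
    if prev_iter_outside_no_sync and hasattr(param, "_saved_grad_shard"):
        # The previous iteration, also outside no_sync(), left a sharded
        # gradient behind, so accumulate the new shard into it.
        param._saved_grad_shard += reduced_shard
    else:
        # Either this is the first synchronized iteration, or the accumulation
        # from inside no_sync() already happened in the unsharded p.grad
        # before the reduce-scatter; just save the new shard.
        param._saved_grad_shard = reduced_shard
```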

```python
    f"existing grad shape={param._saved_grad_shard.shape} "
    f"new grad shape={output.shape}"  # type: ignore[attr-defined]
)
p_assert(
```
@awgu (Collaborator, Author):

One more thing to point out: I previously had an assert like not self.cpu_offload.offload_params just as an internal assert to make sure that CPU offloading never takes this code path.

However, I changed it to a more direct assert here in case we distinguish between offloading parameters and offloading gradients in the future before we solve gradient accumulation with CPU offloading.
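
For reference, a minimal sketch of a p_assert-style helper as described in the PR summary (the actual signature in the PR may differ):

```python
def p_assert(cond: bool, msg: str) -> None:
    # Print the message before asserting so it is not swallowed when the
    # failure happens inside the autograd engine's backward pass.
    if not cond:
        print(msg)
        raise AssertionError(msg)
```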

@zhaojuanmao (Contributor) left a comment

solid tests! also left some minor comments

```python
    f"existing grad device={param._saved_grad_shard.device} "
    f"new grad device={output.device}"  # type: ignore[attr-defined]
)
param._saved_grad_shard.data += output.data  # type: ignore[attr-defined]
```
Contributor:

will it work if removing the .data?

@awgu (Collaborator, Author):

Yup, it still works. I will remove the .data for both in this line.

Comment on lines 1626 to 1630
```python
# FSDP currently does not support gradient accumulation
# outside `no_sync()` when using CPU offloading. Trying to
# do so yields incorrect results since FSDP will use the
# newly-reduced gradient instead of accumulating with any
# existing gradient.
```
Contributor:

could we add a github issue to support grad accumulation in cpu offloading?

@awgu (Collaborator, Author):

Done: #73784

```python
# newly-reduced gradient instead of accumulating with any
# existing gradient.
if not offloaded:
    p._saved_grad_shard = p.grad.data  # type: ignore[attr-defined]
```
Contributor:

Could we add a warning for this case, along the lines of the comment above: "FSDP currently does not support gradient accumulation outside no_sync() when using CPU offloading. Trying to do so yields incorrect results since FSDP will use the newly-reduced gradient instead of accumulating with any existing gradient."? (A sketch of such a warning follows below.)
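
A hedged sketch of what such a warning could look like; the `offloaded` flag and `p` are taken from the snippet above, and the wording and placement are illustrative:

```python
import warnings

def save_or_warn(p, offloaded: bool) -> None:
    # Illustrative fragment mirroring the snippet above, with the suggested warning.
    if not offloaded:
        p._saved_grad_shard = p.grad.data  # type: ignore[attr-defined]
    else:
        warnings.warn(
            "FSDP does not support gradient accumulation outside `no_sync()` "
            "when using CPU offloading; the newly-reduced gradient will be "
            "used instead of being accumulated with any existing gradient."
        )
```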

@awgu (Collaborator, Author) commented Mar 4, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@awgu (Collaborator, Author) commented Mar 4, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

3 similar comments from @awgu followed (Mar 4 and Mar 7, 2022).

facebook-github-bot pushed a commit that referenced this pull request Mar 7, 2022
Summary:
Pull Request resolved: #73535

**Overview**
- This adds FSDP gradient accumulation without `no_sync()`, which comparatively has more network bandwidth demand but less GPU memory requirement per worker.
- This fixes a bug in the `no_sync()` testing, where the CPU offloading and backward prefetch arguments were not propagating to the `FullyShardedDataParallel` constructor.
- This adds `p_assert()` (taken from Fairscale), which prints the assert error message before raising the `AssertionError`. It is meant to be used when running in the autograd backward context since otherwise the error message is swallowed, giving an unhelpful error like:
```
<built-in method run_backward of torch._C._EngineBase object at 0x7f1fd518dc80> returned NULL without setting an error
```

NOTE: Gradient accumulation without `no_sync()` is not currently compatible with CPU offloading.

**Test Plan**
I augmented the tests to test gradient accumulation interleaving iterations accumulating with and without `no_sync()`.

After this diff:
- QPS (ResNet): f328439897
- QPS (RoBERTa): f328440141
- Accuracy: f328442119

Before this diff (trunk):
- QPS (ResNet): f328432756
- QPS (RoBERTa): f328436766
- Accuracy: f328437896

Test Plan: Imported from OSS

Reviewed By: zhaojuanmao

Differential Revision: D34533546

Pulled By: awgu

fbshipit-source-id: 821d762dfad5f2b1e59adcb8e5cb7c277399040c

cyyever pushed two commits to cyyever/pytorch_private that referenced this pull request Mar 9, 2022, both cherry-picked from commit 746a5ea2720dcf87c376229b405a318396fe5769 and carrying the same commit message as the one above.

@rohan-varma (Contributor):
@awgu I was thinking maybe we should add some documentation somewhere around the gradient accumulation that is supported in FSDP, how to use it and what are the tradeoffs?

@facebook-github-bot deleted the gh/awgu/8/head branch March 11, 2022 15:17

@awgu (Collaborator, Author) commented Mar 11, 2022

> @awgu I was thinking maybe we should add some documentation somewhere around the gradient accumulation that is supported in FSDP, how to use it and what are the tradeoffs?

I think this is a great idea, especially since right now I do not have good intuition for the tradeoffs either.

@rohan-varma (Contributor):
#74153

Labels: cla signed, oncall: distributed