
Conversation

@awgu
Collaborator

@awgu awgu commented Mar 16, 2022

Stack from ghstack:

Fixes #73890.

Differential Revision: D34937201

@pytorch-bot

pytorch-bot bot commented Mar 16, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/a8eee325fed2bdc155282dc161bcefcef03c648a/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
linux-binary-manywheel ciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk ✅ triggered
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build ciflow/all, ciflow/cpu, ciflow/default, ciflow/libtorch, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-arm64-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-arm64-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
macos-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
windows-binary-libtorch-debug ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
windows-binary-libtorch-release ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
windows-binary-wheel ciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-bionic-rocm4.5-py3.7-distributed ciflow/all, ciflow/linux, ciflow/rocm, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.3-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
pytorch-xla-linux-bionic-py3.7-clang8 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk, ciflow/xla 🚫 skipped

@facebook-github-bot
Contributor

facebook-github-bot commented Mar 16, 2022

🔗 Helpful links

💊 CI failures summary and remediations

As of commit acd6186 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


awgu pushed a commit that referenced this pull request Mar 16, 2022
ghstack-source-id: d153ef1
Pull Request resolved: #74333
@facebook-github-bot added the oncall: distributed (Add this issue/PR to distributed oncall triage queue) label on Mar 16, 2022
@awgu
Collaborator Author

awgu commented Mar 16, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@awgu awgu marked this pull request as ready for review March 16, 2022 20:35
"""
# Monkey patch `named_parameters()`
torch_named_parameters = torch.nn.Module.named_parameters
self.named_parameters = self._fsdp_named_parameters # type: ignore[assignment]
Collaborator Author

Adding the type: ignore[assignment] seems to be the best solution. See python/mypy#2427.
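For readers outside the diff, a minimal standalone sketch (toy `Module`/`Wrapper` classes, not the actual FSDP code) of why assigning a bound method to an instance attribute trips mypy and needs the ignore:

```python
# Shadowing a class method with an instance attribute: mypy reports an
# incompatible assignment (python/mypy#2427), hence `type: ignore[assignment]`.
from typing import Iterator, Tuple


class Module:
    def named_parameters(self) -> Iterator[Tuple[str, float]]:
        yield ("prefix.weight", 1.0)


class Wrapper(Module):
    def __init__(self) -> None:
        # The instance attribute takes precedence over the class method
        # on lookup, which is exactly what the monkey patch relies on.
        self.named_parameters = self._custom_named_parameters  # type: ignore[assignment]

    def _custom_named_parameters(self) -> Iterator[Tuple[str, float]]:
        # Delegate to the original class method, then clean the names.
        for name, param in Module.named_parameters(self):
            yield (name.replace("prefix.", ""), param)


w = Wrapper()
print(list(w.named_parameters()))  # [('weight', 1.0)]
```

At runtime the instance attribute wins because plain functions are non-data descriptors, so the instance `__dict__` is consulted first.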

Fixes #73890 by monkey patching `torch.nn.Module.named_parameters()`.

Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Mar 16, 2022
ghstack-source-id: 6a960ac
Pull Request resolved: #74333
@awgu
Collaborator Author

awgu commented Mar 16, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

List,
Optional,
Generator,
Iterator,
Collaborator Author

Just sorted imports.

@awgu
Collaborator Author

awgu commented Mar 17, 2022

To-do: look into overriding `named_parameters()` instead of monkey patching, using the summon-full-params training state to decide whether to use the special logic or simply call the normal `nn.Module.named_parameters()`.

Fixes #73890 by monkey patching `torch.nn.Module.named_parameters()`.

Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)

[ghstack-poisoned]
context manager.
"""
# Determine which logic to use based on the context at call time
if not hasattr(self, "training_state") or \
Collaborator Author

`nn.Module`s contained in an FSDP instance may not have `training_state` as an attribute yet still return `True` for `isinstance(module, FullyShardedDataParallel)`. If this `hasattr()` check seems too hacky, I will look for a more direct solution.

Contributor

nit: this can be simplified to `if getattr(self, "training_state", None) != TrainingState_.SUMMON_FULL_PARAMS:`
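A runnable sketch of the suggested check; the `TrainingState_` enum below is a stand-in mirroring the FSDP one, not the real class:

```python
# `getattr` with a default of None covers both the missing-attribute case
# (hasattr is False) and the wrong-state case in a single comparison.
import enum


class TrainingState_(enum.Enum):
    IDLE = enum.auto()
    SUMMON_FULL_PARAMS = enum.auto()


class Dummy:
    pass


m = Dummy()
# No `training_state` attribute yet: the default None makes the check pass.
print(getattr(m, "training_state", None) != TrainingState_.SUMMON_FULL_PARAMS)  # True

m.training_state = TrainingState_.SUMMON_FULL_PARAMS
print(getattr(m, "training_state", None) != TrainingState_.SUMMON_FULL_PARAMS)  # False
```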

Fixes #73890 by monkey patching `torch.nn.Module.named_parameters()`.

Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Mar 18, 2022
ghstack-source-id: 61d4b39
Pull Request resolved: #74333
@awgu
Collaborator Author

awgu commented Mar 18, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


def named_parameters(
self,
prefix: str = "",
Contributor

Let's take in `*args, **kwargs` so we don't have to change this if the API changes?

Collaborator Author

Would it be correct to say that PyTorch Core cannot add positional arguments before `prefix` without breaking backward compatibility? In that case, could we only add `**kwargs`?

Contributor

@fegin fegin Mar 18, 2022

Using `(*args, **kwargs)` is the usual Python convention for passing arguments through to the parent method without modification.

Collaborator Author

Oh, I misunderstood. I thought we were trying to additionally pass in `*args, **kwargs` instead of replacing `prefix` and `recurse`.

Fixed this now.
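A toy illustration of the forwarding convention under discussion (`Base`/`Derived` are hypothetical names, not the FSDP classes):

```python
# The override takes (*args, **kwargs) and forwards them unchanged, so it
# keeps working even if the parent signature gains or reorders parameters.
class Base:
    def named_parameters(self, prefix: str = "", recurse: bool = True):
        yield (prefix + "weight", recurse)


class Derived(Base):
    def named_parameters(self, *args, **kwargs):
        # Forward everything to the parent implementation as-is.
        yield from super().named_parameters(*args, **kwargs)


d = Derived()
print(list(d.named_parameters(prefix="m.")))  # [('m.weight', True)]
```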

…ummon_full_params()`"


Fixes #73890.

Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Mar 18, 2022
ghstack-source-id: cfdea1e
Pull Request resolved: #74333
@awgu
Collaborator Author

awgu commented Mar 18, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

…ummon_full_params()`"


Fixes #73890.

Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)

[ghstack-poisoned]
@awgu
Collaborator Author

awgu commented Mar 18, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

…ummon_full_params()`"


Fixes #73890.

Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Mar 18, 2022
ghstack-source-id: 6b9675e
Pull Request resolved: #74333
@awgu
Collaborator Author

awgu commented Mar 18, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

…ummon_full_params()`"


Fixes #73890.

Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)

[ghstack-poisoned]
@awgu
Collaborator Author

awgu commented Mar 18, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@rohan-varma rohan-varma self-requested a review March 18, 2022 17:36
Contributor

@rohan-varma rohan-varma left a comment

Looks good, just two minor questions for your consideration; will stamp after that.

# Remove any instances of the FSDP-specific prefix; there can
# be multiple in the case of nested FSDP modules
param_name = param_name.replace(FSDP_PREFIX, "")
yield (param_name, param)
Contributor

Can we do the following to avoid the duplicated for loop?

in_summon = (training_state == summon_full_params)
for name, param in named_parameters():
    name = name.replace(FSDP_PREFIX, "") if in_summon else name
    yield (name, param)

Collaborator Author

Great point.
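A runnable sketch of the single-loop version; `FSDP_PREFIX` and `named_parameters_clean` below are stand-in names for illustration, not the real FSDP constants:

```python
# One loop handles both modes: the prefix is stripped only when called
# inside summon_full_params(). str.replace removes every occurrence, which
# covers the multiple-prefix case from nested FSDP modules.
FSDP_PREFIX = "_fsdp_wrapped_module."  # stand-in value, not the real constant


def named_parameters_clean(pairs, in_summon: bool):
    for name, param in pairs:
        name = name.replace(FSDP_PREFIX, "") if in_summon else name
        yield (name, param)


raw = [
    ("_fsdp_wrapped_module.linear.weight", 0),
    ("_fsdp_wrapped_module._fsdp_wrapped_module.bias", 1),  # nested wrapping
]
print(list(named_parameters_clean(raw, in_summon=True)))
# [('linear.weight', 0), ('bias', 1)]
```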

"""
# Determine which logic to use based on the context at call time
if getattr(self, "training_state", None) != TrainingState_.SUMMON_FULL_PARAMS:
for param_name, param in torch.nn.Module.named_parameters(
Contributor

Wonder if we can use `super().named_parameters()` rather than a direct `torch.nn.Module.named_parameters()` call? If a user writes `MyModule` inheriting from `nn.Module` with a custom `named_parameters()`, will this still work?

Collaborator Author

Great point.
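A toy MRO example (hypothetical `Base`/`UserModule`/`Wrapper` classes, not the FSDP hierarchy) showing why `super()` matters when a user module overrides `named_parameters()`:

```python
# Cooperative super() walks the MRO and so honors a user override sitting
# between the wrapper and the base class; an explicit Base.named_parameters(self)
# call skips that override entirely.
class Base:
    def named_parameters(self):
        yield ("weight", 1.0)


class UserModule(Base):
    def named_parameters(self):
        # A user's custom override, e.g. renaming parameters.
        for name, p in super().named_parameters():
            yield ("custom." + name, p)


class Wrapper(UserModule):
    def via_super(self):
        return list(super().named_parameters())   # goes through UserModule

    def via_base(self):
        return list(Base.named_parameters(self))  # bypasses UserModule


w = Wrapper()
print(w.via_super())  # [('custom.weight', 1.0)]
print(w.via_base())   # [('weight', 1.0)]
```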

…ummon_full_params()`"


Fixes #73890.

Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Mar 18, 2022
ghstack-source-id: cdbb6ea
Pull Request resolved: #74333
@awgu
Collaborator Author

awgu commented Mar 18, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@rohan-varma rohan-varma self-requested a review March 18, 2022 20:29
@rohan-varma
Contributor

Lgtm

Contributor

@fegin fegin left a comment

LGTM

facebook-github-bot pushed a commit that referenced this pull request Mar 21, 2022
…params()` (#74333)

Summary: Pull Request resolved: #74333

Test Plan: Imported from OSS

Reviewed By: fegin

Differential Revision: D34937201

Pulled By: awgu

fbshipit-source-id: dfd64e94ff66a068910eebc56102fcd640804429
@github-actions
Contributor

Hey @awgu.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@facebook-github-bot facebook-github-bot deleted the gh/awgu/16/head branch March 24, 2022 14:17
shahofblah pushed a commit that referenced this pull request Mar 25, 2022
…params()` (#74333)

Summary: Pull Request resolved: #74333

Test Plan: Imported from OSS

Reviewed By: fegin

Differential Revision: D34937201

Pulled By: awgu

fbshipit-source-id: dfd64e94ff66a068910eebc56102fcd640804429
(cherry picked from commit c17f338)

Labels

cla signed, oncall: distributed (Add this issue/PR to distributed oncall triage queue)


5 participants