Skip to content

Conversation

@kshitij12345
Copy link
Collaborator

@kshitij12345 kshitij12345 commented Oct 11, 2022

Fixes #42824

  • Test
  • Doc

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 11, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86700

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0f4a32e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kshitij12345 kshitij12345 changed the title [nn] module: backward_pre_hook [WIP] [nn] module: backward_pre_hook Oct 11, 2022
@pytorch-bot pytorch-bot bot added the release notes: jit release notes category label Oct 12, 2022
@kshitij12345 kshitij12345 changed the title [WIP] [nn] module: backward_pre_hook [nn] module: backward_pre_hook Oct 12, 2022
@kshitij12345 kshitij12345 marked this pull request as ready for review October 12, 2022 18:19

def register_module_backward_hook(
hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the signature here was wrong.

Note: This change is not related to backward_pre_hook


def register_module_full_backward_hook(
hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the signature here was wrong.

Note: This change is not related to backward_pre_hook

Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@soulitzer can you take a look at this?

Copy link
Contributor

@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, look pretty good. Just have some small comments on the docs.

@kshitij12345 kshitij12345 changed the title [nn] module: backward_pre_hook [nn] module: full_backward_pre_hook Oct 13, 2022
output = bn(torch.randn(5, 5, requires_grad=True))
output.sum().backward()

@skipIfTorchDynamo("TorchDynamo does not work well with hooks")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fails on Dynamo. Have added skip similar to other hook related tests.

@kshitij12345
Copy link
Collaborator Author

@soulitzer have addressed all the points. PTAL Thanks :)

test/test_nn.py Outdated
output.sum().backward()

@skipIfTorchDynamo("TorchDynamo does not work well with hooks")
def test_hook_backward_pre_and_full(self):
Copy link
Contributor

@soulitzer soulitzer Oct 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we might want to rename this test now that both tests are "full"

Copy link
Collaborator Author

@kshitij12345 kshitij12345 Oct 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Renaming it to test_backward_hooks_interaction

Copy link
Contributor

@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 13, 2022
@soulitzer
Copy link
Contributor

@pytorchbot merge -g

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions
Copy link
Contributor

Hey @kshitij12345.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@kshitij12345 kshitij12345 deleted the develop/hooks/module-backward-pre-hook branch October 13, 2022 18:39
@soulitzer soulitzer added release notes: nn release notes category and removed release notes: jit release notes category labels Oct 13, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: nn release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[feature request] Pre-backward hooks for runtime profiling use case

5 participants