Skip to content

Conversation

@gs-olive
Copy link
Contributor

@gs-olive gs-olive commented Jan 12, 2023

Description

  • Remove existing dropout removal lowering pass implementation due to bug
  • Use adapted Torch JIT dropout removal lowering pass to resolve bug where nested dropouts resulted in invalid graph
  • Existing removal process left artifacts in graph which caused an internal assertion error
  • Add two regression tests to catch nested dropout bug

Fixes #1587

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • [ x ] My code follows the style guidelines of this project (You can use the linters)
  • [ x ] I have performed a self-review of my own code
  • [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ x ] I have made corresponding changes to the documentation
  • [ x ] I have added tests to verify my fix or my feature
  • [ x ] New and existing unit tests pass locally with my changes
  • [ x ] I have added the relevant labels to my PR in so that relevant reviewers are notified

@gs-olive gs-olive requested a review from narendasan January 12, 2023 22:05
@gs-olive gs-olive self-assigned this Jan 12, 2023
@github-actions github-actions bot added component: core Issues re: The core compiler component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests labels Jan 12, 2023
@github-actions github-actions bot requested review from bowang007 and peri044 January 12, 2023 22:05
@gs-olive gs-olive changed the title fix: Replace custom dropout-removal lowering pass with JIT pass fix: Replace RemoveDropout lowering pass implementation with JIT pass Jan 12, 2023
@gs-olive gs-olive changed the title fix: Replace RemoveDropout lowering pass implementation with JIT pass fix: Replace RemoveDropout lowering pass implementation with modified JIT pass Jan 20, 2023
@gs-olive gs-olive force-pushed the remove_dropout_bugfix branch 3 times, most recently from d84905f to f8385b6 Compare January 23, 2023 22:40
- Remove existing dropout removal lowering pass implementation due to
bug
- Use Torch JIT dropout removal lowering pass to resolve bug where
nested dropouts resulted in invalid graph
- Existing removal process left artifacts in graph which caused an
internal assertion error
- Add regression test to catch nested dropout bug
- Update tests to remove testing for `feature_alpha_dropout` and
`feature_alpha_dropout_`, which are not removed by the JIT lowering pass
and can be added in later
@gs-olive gs-olive force-pushed the remove_dropout_bugfix branch from f8385b6 to e7a469d Compare January 24, 2023 01:18
- Adapt JIT pass to remove dropout to accommodate multiple dropout
schemas
- Include additional test cases to verify new removal code
@gs-olive gs-olive force-pushed the remove_dropout_bugfix branch from e7a469d to 8698045 Compare January 24, 2023 05:36
Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@narendasan narendasan merged commit b2b6871 into pytorch:main Jan 25, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: core Issues re: The core compiler component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🐛 [Bug] Error when compiling Punctuation BERT model

3 participants