Add memory estimator #164738
Conversation
[ghstack-poisoned]
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164738
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure) As of commit 3a10033 with merge base b5e93ff. FLAKY - the following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Starting merge as part of PR stack under #164783
This reverts commit ab01a0d. Reverted #164738 on behalf of https://github.com/eellison because merge sets make this trickier (see comment on #164581).
Original work by ShatianWang, with lints applied. I am going to make a few changes and add tests in subsequent PRs, but I want to preserve the original commit first. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
- Update the Memory Estimator to use node storages for analysis, which simplifies bookkeeping compared to manually inspecting operator schemas. This will also allow me to reuse this component elsewhere.
- Factor it out into a separate class so the same logic can be used in scheduling (node allocations / aliasing / uses).
- Add tests for correctness - right now only on fwd/bwd individually, not with both together.

Pull Request resolved: #164783 Approved by: https://github.com/ruisizhang123 ghstack dependencies: #164738
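A minimal sketch of the storage-based idea described above, under simplified assumptions (the real estimator operates on FX graph nodes and their untyped storages; `Node`, `allocates`, and `uses` here are stand-ins): a storage becomes live when it is first allocated and is released after its last user, so no per-operator schema analysis is needed.

```python
# Simplified, self-contained illustration of storage-based memory estimation.
# Not the PR's implementation: storage ids and byte sizes are modeled directly.
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    # storages (by id) this node allocates, with their sizes in bytes
    allocates: dict[int, int] = field(default_factory=dict)
    # storages (by id) this node reads; aliasing ops reuse an existing id
    # instead of adding a new entry to `allocates`
    uses: set[int] = field(default_factory=set)


class MemoryEstimator:
    """Estimate peak live memory for a fixed node order."""

    def estimate(self, nodes: list[Node]) -> int:
        # Last index at which each storage is touched (alloc or use).
        last_use: dict[int, int] = {}
        for i, node in enumerate(nodes):
            for sid in list(node.allocates) + list(node.uses):
                last_use[sid] = i

        sizes: dict[int, int] = {}
        live = peak = 0
        for i, node in enumerate(nodes):
            for sid, nbytes in node.allocates.items():
                if sid not in sizes:          # aliases don't re-allocate
                    sizes[sid] = nbytes
                    live += nbytes
            peak = max(peak, live)
            for sid, last in list(last_use.items()):
                if last == i:                 # last use: storage is freed
                    live -= sizes.get(sid, 0)
                    del last_use[sid]
        return peak


if __name__ == "__main__":
    a = Node("a", allocates={1: 1024})
    b = Node("b", allocates={2: 2048}, uses={1})   # storage 1 dies here
    c = Node("c", allocates={3: 512}, uses={2})
    print(MemoryEstimator().estimate([a, b, c]))   # 3072: storages 1 and 2 live at b
```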
Respect max_coll_distance from the overlap scheduler in bucketing, and add an optimization in path searching. Pull Request resolved: #164944 Approved by: https://github.com/IvanKobzarev ghstack dependencies: #164738, #164783
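For illustration only, a sketch of what respecting a maximum collective distance during bucketing could look like; the field names `position` and `bucket_key` are assumptions, not the scheduler's actual data model.

```python
# Hypothetical sketch: only merge collectives into a bucket whose anchor is
# within `max_coll_distance` schedule positions and shares the same bucket key.
from dataclasses import dataclass


@dataclass
class Collective:
    name: str
    position: int      # index in the current schedule (assumed field)
    bucket_key: str    # e.g. process group + reduce op + dtype (assumed field)


def bucket_collectives(colls: list[Collective],
                       max_coll_distance: int) -> list[list[Collective]]:
    buckets: list[list[Collective]] = []
    for coll in sorted(colls, key=lambda c: c.position):
        for bucket in buckets:
            anchor = bucket[0]
            if (coll.bucket_key == anchor.bucket_key
                    and coll.position - anchor.position <= max_coll_distance):
                bucket.append(coll)
                break
        else:
            buckets.append([coll])   # start a new bucket
    return buckets
```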
Add a Memory Tracker utility, which tracks live memory given an alternate ordering of nodes. Pull Request resolved: #165059 Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev ghstack dependencies: #164738, #164783, #164944, #164945
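A rough sketch of what such a tracker might do, with made-up `size`/`inputs` dictionaries standing in for real graph nodes: replay a candidate order, free each buffer after its last consumer in that order, and record the live-memory curve so different orderings can be compared.

```python
# Hypothetical illustration of tracking live memory under an alternate node order.
from collections import defaultdict


def live_memory_trace(order, size, inputs):
    """order: node names in execution order; size[n]: output bytes of n;
    inputs[n]: names whose outputs n reads. Returns [(node, live_bytes)]."""
    # A buffer dies after its last consumer *in this particular order*.
    last_use = {n: i for i, n in enumerate(order)}
    for i, n in enumerate(order):
        for dep in inputs.get(n, ()):
            last_use[dep] = max(last_use[dep], i)

    pending_free = defaultdict(list)
    for n, i_last in last_use.items():
        pending_free[i_last].append(n)

    live, trace = 0, []
    for i, n in enumerate(order):
        live += size.get(n, 0)          # allocate n's output
        trace.append((n, live))
        for freed in pending_free[i]:   # release buffers with no later users
            live -= size.get(freed, 0)
    return trace


# Two valid orders of the same tiny graph, with different peak memory:
size = {"a": 4, "b": 8, "c": 2, "d": 1}
inputs = {"c": ["a"], "d": ["b", "c"]}
print(max(m for _, m in live_memory_trace(["a", "b", "c", "d"], size, inputs)))  # 14
print(max(m for _, m in live_memory_trace(["a", "c", "b", "d"], size, inputs)))  # 11
```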
Bucketing: a number of smallish improvements.
- Account for bucketing in the overlap calculation: if an in-flight collective exists with the same bucket key, reduce the new collective's estimated time by its latency.
- Update compute domination so we order based on compute index, as opposed to compute depth, so we never reorder compute. This makes it a bit easier to reason about memory and pre-fetching, although we can explore reordering in the future.
- When we wait on a collective, force all collectives on the same process group that were enqueued prior to it to wait as well.

Better memory handling:
- Pre-fetch limiting: when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.

TODO:
- For each compute node we know its original memory in the graph; we could limit pre-fetching that crosses peak memory.
- By scheduling off-path collectives for overlap, we reduce memory, but if there isn't enough compute for overlap, we need to proactively schedule them. Not an issue yet on the examples.
- Make the hard-coded constants configurable and clean up enablement (can do in a subsequent PR).

On small llama 2D backward: 578 of 618 potentially hideable collectives hidden; original mem 14.4GB, rescheduled mem 15.9GB. On forward: 254/256 potentially hideable collectives hidden; original mem 5.8GB, rescheduled mem 5.8GB.

WIP: adding tests

Pull Request resolved: pytorch#165318 Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev ghstack dependencies: pytorch#164738, pytorch#164783, pytorch#164944, pytorch#164945, pytorch#165059
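Purely as a sketch of two of the heuristics above (pre-fetch limiting, and scheduling waits when over a memory budget): every name, data shape, and the assumption that scheduling a wait releases the collective's staging buffer are hypothetical, not the scheduler's actual interface.

```python
# Hypothetical illustration: cap pre-fetch distance and drain waits when over budget.
from collections import deque


def schedule(compute_nodes, collectives, max_prefetch_distance, memory_budget):
    """compute_nodes: list of (name, mem_delta); collectives: list of
    (name, buffer_bytes, consumer_compute_idx), sorted by consumer index.
    Returns the chosen node order."""
    order, live = [], 0
    inflight = deque()            # (wait_name, buffer_bytes) of issued collectives
    pending = deque(collectives)

    for idx, (name, mem_delta) in enumerate(compute_nodes):
        # Over the budget: schedule waits first; each wait is modeled here as
        # releasing the buffer its collective was holding.
        while live > memory_budget and inflight:
            wait_name, nbytes = inflight.popleft()
            order.append(wait_name)
            live -= nbytes

        # Pre-fetch collectives, but only within the allowed distance of the
        # compute node that consumes them.
        while pending and pending[0][2] - idx <= max_prefetch_distance:
            coll, nbytes, _ = pending.popleft()
            order.append(coll)
            inflight.append((f"wait_{coll}", nbytes))
            live += nbytes

        order.append(name)
        live += mem_delta

    order.extend(w for w, _ in inflight)  # flush any remaining waits
    return order
```

The actual change also interacts with bucketing, latency estimates, and off-path collectives; this sketch only shows the distance cap and the over-budget wait draining.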
Original work by @ShatianWang, with lints applied. I am going to make a few changes and add tests in subsequent PRs, but I want to preserve the original commit first. Pull Request resolved: pytorch#164738 Approved by: https://github.com/IvanKobzarev ghstack dependencies: pytorch#164568, pytorch#164569, pytorch#164581
This reverts commit ab01a0d. Reverted pytorch#164738 on behalf of https://github.com/eellison because merge sets make this trickier (see comment on pytorch#164581).
Original work by @ShatianWang, with lints applied. I am going to make a few changes and add tests in subsequent PRs, but I want to preserve the original commit first. Pull Request resolved: pytorch#164738 Approved by: https://github.com/IvanKobzarev
- Update the Memory Estimator to use node storages for analysis, which simplifies bookkeeping compared to manually inspecting operator schemas. This will also allow me to reuse this component elsewhere.
- Factor it out into a separate class so the same logic can be used in scheduling (node allocations / aliasing / uses).
- Add tests for correctness - right now only on fwd/bwd individually, not with both together.

Pull Request resolved: pytorch#164783 Approved by: https://github.com/ruisizhang123 ghstack dependencies: pytorch#164738
Respect max_coll_distance from the overlap scheduler in bucketing, and add an optimization in path searching. Pull Request resolved: pytorch#164944 Approved by: https://github.com/IvanKobzarev ghstack dependencies: pytorch#164738, pytorch#164783
…ted (pytorch#164945) Pull Request resolved: pytorch#164945 Approved by: https://github.com/IvanKobzarev ghstack dependencies: pytorch#164738, pytorch#164783, pytorch#164944
Add a Memory Tracker utility, which tracks live memory given an alternate ordering of nodes. Pull Request resolved: pytorch#165059 Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev ghstack dependencies: pytorch#164738, pytorch#164783, pytorch#164944, pytorch#164945
ghstack-source-id: afb2d8c Pull Request resolved: pytorch#164738
Stack from ghstack (oldest at bottom):
Original work by @ShatianWang, with lints applied. I am going to make a few changes and add tests in subsequent PRs, but I want to preserve the original commit first.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben