
20x less memory use and 37.25% speedup in min_cut_rematerialization_partition when using the new dp knapsack solver, compared to existing default one (dp) #160914

Closed

jmaczan wants to merge 8 commits into pytorch:main from jmaczan:dp_knapsack_hirschberg
Conversation

@jmaczan (Contributor) commented Aug 18, 2025

The goal:
Reduce memory usage in min_cut_rematerialization_partition (https://github.com/pytorch/pytorch/blob/1f1900369435933a013df9a7d5e07c75c1cebb5d/torch/_functorch/partitioners.py#L2601) with a new knapsack implementation

Rationale:
A comment by @Chillee in the original dp_knapsack suggests improving the code with the Hirschberg algorithm and a sliding window (https://codeforces.com/blog/entry/47247?#comment-316200). In this PR, I add a new knapsack implementation that is based on the original one and uses Hirschberg + a sliding window.

The existing dp_knapsack implementation instantiates a full dp table of shape (n, W), where n is the number of memory elements and W is the quantized memory length. The new dp_knapsack_sliding_hirschberg uses only (2, W) memory - it needs only two dp profiles (1-dim tensors) to compute the same output as the original dp_knapsack.

This optimization is possible because at each step we only use the current and previous rows (the "dp profile"). We split the item indices (each item is a memory/runtime pair) into two halves, compute a dp profile for each half, find the split index at which the sum of runtimes from both halves is highest, use that split index to decide how much budget to give to the left and right halves, and then recurse on each half with its new memory budget.
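
To make the recursion concrete, here is a minimal, illustrative sketch of the idea in plain Python (integer memory weights, runtimes as values). It is not the PR's actual tensor-based dp_knapsack_sliding_hirschberg; the names dp_profile and hirschberg_knapsack are made up for illustration:

```python
# Illustrative sketch only - plain Python with integer weights, not the PR's
# tensor-based dp_knapsack_sliding_hirschberg; names here are made up.

def dp_profile(weights, values, capacity):
    # Classic 0/1 knapsack value pass keeping a single row:
    # dp[w] = best total value achievable within budget w.
    dp = [0.0] * (capacity + 1)
    for wt, val in zip(weights, values):
        for w in range(capacity, wt - 1, -1):
            dp[w] = max(dp[w], dp[w - wt] + val)
    return dp

def hirschberg_knapsack(weights, values, capacity, offset=0):
    # Returns indices of chosen items; only 1-D dp profiles are ever
    # materialized, never the full (n, capacity) table.
    n = len(weights)
    if n == 0 or capacity <= 0:
        return []
    if n == 1:
        return [offset] if weights[0] <= capacity else []
    mid = n // 2
    left = dp_profile(weights[:mid], values[:mid], capacity)
    right = dp_profile(weights[mid:], values[mid:], capacity)
    # Pick the budget split that maximizes the combined value of both halves.
    best_split = max(range(capacity + 1), key=lambda w: left[w] + right[capacity - w])
    del left, right  # free the profiles before recursing
    return (hirschberg_knapsack(weights[:mid], values[:mid], best_split, offset)
            + hirschberg_knapsack(weights[mid:], values[mid:], capacity - best_split, offset + mid))

# Example: items as (memory, runtime) pairs with a memory budget of 10.
mems, runtimes = [3, 4, 5, 2], [2.0, 3.0, 4.0, 1.5]
print(hirschberg_knapsack(mems, runtimes, 10))  # [0, 2, 3] -> total runtime 7.5, memory 10
```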

Based on the benchmarks, please consider whether we should keep two dp knapsack implementations or just one, which one should be the default, whether we want to make it easier to use the new one, etc. In general, think about next steps.
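
If it helps with the "make it easier to use the new one" question: below is only a guess at how solver selection might look from user code, assuming the torch._functorch.config knobs I remember (activation_memory_budget, activation_memory_budget_solver); "dp_hirschberg" is purely a placeholder string - neither the knob names nor that name are confirmed by this PR.

```python
import torch
import torch._functorch.config as functorch_config

# Assumption: the memory-budget partitioner keeps choosing its 0/1 knapsack
# solver through this config knob (default "dp"); "dp_hirschberg" is only a
# placeholder for whatever name the new solver ends up registered under.
functorch_config.activation_memory_budget = 0.5       # recompute under a 50% activation budget
functorch_config.activation_memory_budget_solver = "dp_hirschberg"

model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU())
compiled = torch.compile(model)
loss = compiled(torch.randn(4, 128)).sum()
loss.backward()  # the partitioner (and hence the knapsack solver) runs when the joint graph is split
```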

Thanks in advance for all comments

cc @zou3519 @Chillee @samdow @kshitij12345

…ew alternative to 0/1 knapsack with full dp table implementation. Unit testing in progress. Benchmarks not started. The current, first implementation uses recursion and can be improved (explicit stack, vectorization and many others). That's a starting point. Some unit tests are failing and showing minor differences between new and previous implementations
@pytorch-bot (bot) commented Aug 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160914

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit 2b3706b with merge base 226850c:

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jmaczan jmaczan changed the title [WIP] Hirschberg algorithm with sliding window for 0/1 dp knapsack [WIP] Reduce memory usage in min_cut_rematerialization_partition when using dp knapsack solver Aug 18, 2025
@jmaczan (Contributor, Author) commented Aug 18, 2025

@pytorchbot label "module: functorch"

@pytorch-bot (bot) added the "module: functorch" label Aug 18, 2025
…lit: and adjust right split accordingly. Remove quantized_max_memory <= 0 condition, add a condition for 0 memory length
@pytorch-bot (bot) commented Aug 27, 2025

Didn't find following labels among repository labels: topic: user facing

@jmaczan (Contributor, Author) commented Aug 27, 2025

@pytorchbot label "release notes: torch.func"

Not sure about this one - please adjust if needed.

@pytorch-bot (bot) added the "release notes: torch.func" label Aug 27, 2025
@jmaczan (Contributor, Author) commented Aug 31, 2025

I have some happy benchmarks to share!

  • The new hirschberg dp implementation is 37.25% faster than the existing dp implementation
  • The new hirschberg dp implementation handles problem sizes more than 20x larger: on my machine the existing dp knapsack crashes at as few as 100 items, whereas the new hirschberg dp runs correctly for 2000 items (with 58.4 GB peak memory)

I ran all solvers 1000 times each on as quiet a machine as possible (Ubuntu 24.04.2, AMD Ryzen 7 9800X3D, 64 GB RAM, no background activity except a terminal running the benchmark script).

How to reproduce:
On my pytorch/benchmark fork, on the hirschberg_benchmarks branch (https://github.com/jmaczan/benchmark/tree/hirschberg_benchmarks), and assuming you have built pytorch/audio, pytorch/vision, and pytorch/pytorch from my dp_knapsack_hirschberg branch, run the following in the root directory of the benchmark repo:

python userbenchmark/functorch/model_benchmark.py && python to_csv.py

This produces a JSON and a CSV file with the results. You can control the number of runs for each test and the problem sizes you want to test (I used sizes [1, 5, 10, 20, 50, 100] and 1000 runs per size).
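
If you just want a quick standalone timing without the full torchbenchmark setup, something along these lines works (illustrative only: it reuses the hirschberg_knapsack sketch from the PR description above on random data, so the numbers will not match the spreadsheet; the capacity and run counts are placeholders):

```python
# Quick standalone timing sketch - reuses the illustrative hirschberg_knapsack
# from the PR description above; capacity and run counts are placeholders.
import random
import time

def bench(solver, n_items, capacity=2_000, runs=10, seed=0):
    rng = random.Random(seed)
    weights = [rng.randint(1, capacity // 10) for _ in range(n_items)]
    values = [rng.random() for _ in range(n_items)]
    start = time.perf_counter()
    for _ in range(runs):
        solver(weights, values, capacity)
    return (time.perf_counter() - start) / runs

for n in (1, 5, 10, 20, 50, 100):
    print(f"n={n}: {bench(hirschberg_knapsack, n):.6f} s/run")
```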

You can copy the contents of the CSV file into a copy of the spreadsheet below and split it into columns via the menu; you will get neatly organized benchmarks.

Results:
https://docs.google.com/spreadsheets/d/1lvtHXTwofGEEA2_B8FDbpAMZuAbH9l-jxb-9gQTZixA/edit?usp=sharing

The benchmark code is still somewhat messy; I'll clean it up before this PR is merged.

I think it's a good moment to move this from 'draft' to 'ready for review' @zou3519 @Chillee @samdow @kshitij12345

@jmaczan jmaczan marked this pull request as ready for review August 31, 2025 18:42
@jmaczan jmaczan changed the title [WIP] Reduce memory usage in min_cut_rematerialization_partition when using dp knapsack solver Reduce memory usage in min_cut_rematerialization_partition when using dp knapsack solver Aug 31, 2025
@albanD albanD requested review from Chillee and bdhirsh September 4, 2025 14:59
@albanD added the "triaged" label Sep 4, 2025
@jmaczan jmaczan changed the title Reduce memory usage in min_cut_rematerialization_partition when using dp knapsack solver 20x less memory use and 37.25% speedup in min_cut_rematerialization_partition when using the new dp knapsack solver, compared to existing default one (dp) Sep 10, 2025
@jmaczan (Contributor, Author) commented Sep 15, 2025

Hi @bdhirsh, a gentle ping if you are available for a review

@Chillee I read your post about stepping away from PyTorch, but perhaps you'd like to take a look at this particular change? You might be interested in this one because the optimization (Hirschberg algorithm + sliding window) was your idea. If not, sorry for bothering you!

@jmaczan (Contributor, Author) commented Oct 5, 2025

@bdhirsh @Chillee @albanD review ping 🔍

@jmaczan (Contributor, Author) commented Oct 8, 2025

Who should I ping for the review? @ezyang @janeyx99 @anijain2305

@jmaczan (Contributor, Author) commented Oct 14, 2025

This PR needs review @ezyang @janeyx99 @anijain2305 @bdhirsh @Chillee

@janeyx99 janeyx99 requested a review from zou3519 October 16, 2025 17:32
@jmaczan (Contributor, Author) commented Oct 22, 2025

@zou3519 review ping - this one is important because it yields very solid memory savings (up to 20 times less RAM used!) compared to the current default dp implementation

@ezyang (Contributor) commented Nov 8, 2025

This is in my review box, but it requires a very heavy review that I haven't had time for. It would be easier to justify reviewing this if there were a clearer description of when the CPU memory use/speed of the min cut partitioner is a limiting factor for compile time; my impression is that this is typically not the big problem.

@jmaczan (Contributor, Author) commented Nov 8, 2025

@ezyang The main reason for this change is the RAM usage of the dp solver; the CPU speedup is a nice bonus. The current dp implementation allocates a lot of RAM to build the full dp table - 100 items doesn't even fit into 64 GB of RAM. With this new implementation you can go beyond 2000 items within that memory budget, and for 100 items or fewer you get a significant speedup compared to the current full-dp-table implementation.
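
For a sense of scale, a back-of-envelope calculation (my numbers, not measurements from this PR, and assuming a float32 table):

```python
# Back-of-envelope only (assumes a float32 dp table; not measured in this PR).
# If a full (n, W) table stops fitting in 64 GB at n = 100 items, W must be
# roughly 64e9 / (100 * 4) ~ 1.6e8 quantized buckets; at that same W, the
# two-profile variant needs only 2 * W * 4 bytes ~ 1.3 GB.
n, bytes_per_elem, ram = 100, 4, 64e9
W = ram / (n * bytes_per_elem)
two_profiles = 2 * W * bytes_per_elem
print(f"W ~ {W:.1e} buckets; two-profile footprint ~ {two_profiles / 1e9:.1f} GB")
```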

@ezyang (Contributor) commented Nov 11, 2025

oh. we probably did not notice this because our machines have a ginormous amount of cpu ram lol

best_split = int(torch.argmax(left_dp_local).item())

left_capacity = best_split
right_capacity = capacity - best_split
Review comment (Contributor):

Should this be capacity - 1 - best_split?

Reply (Contributor Author):

I think it should not, because this isn't index arithmetic - I'm really computing the remaining memory (capacity) here. For example, with capacity = 10 and best_split = 4, the left half gets a budget of 4 memory units and the right half gets 10 - 4 = 6, so no off-by-one adjustment is needed.

@jmaczan (Contributor, Author) commented Nov 13, 2025

ping @ezyang up

@ezyang (Contributor) commented Nov 14, 2025

@pytorchbot merge

@pytorch-bot (bot) commented Nov 14, 2025

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@ezyang (Contributor) commented Nov 14, 2025

@pytorchbot merge

@pytorch-bot (bot) added the "ciflow/trunk" label Nov 14, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@pytorch-bot (bot) removed the "ciflow/trunk" label Nov 14, 2025
@jmaczan (Contributor, Author) commented Nov 14, 2025

I pushed a lint fix and merged main, let's try again @ezyang

@jmaczan (Contributor, Author) commented Nov 20, 2025

@ezyang Let's try the merge again, please

@ezyang (Contributor) commented Nov 21, 2025

@pytorchbot merge

@pytorch-bot (bot) added the "ciflow/trunk" label Nov 21, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@jmaczan (Contributor, Author) commented Nov 21, 2025

Seems like flaky tests @ezyang

@ezyang (Contributor) commented Nov 21, 2025

@pytorchbot merge -i

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following 3 checks: pull / linux-docs / build-docs-python-false, pull / linux-jammy-py3.10-gcc11 / test (docs_test, 1, 1, lf.linux.2xlarge), trunk / linux-jammy-py3-clang12-executorch / test (executorch, 1, 1, linux.2xlarge, unstable)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

JacobSzwejbka pushed a commit that referenced this pull request Dec 8, 2025
…artition when using the new dp knapsack solver, compared to existing default one (dp) (#160914)


Pull Request resolved: #160914
Approved by: https://github.com/ezyang

Co-authored-by: Edward Yang <ezyang@meta.com>

Labels

ciflow/trunk - Trigger trunk jobs on your pull request
Merged
module: functorch - Pertaining to torch.func or pytorch/functorch
open source
release notes: torch.func - release notes category for torch.vmap or torch.func.* APIs
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
