Conversation
Force-pushed from 7e50d90 to 3b3813d (Compare)

Thanks for your PR. However, it's being worked on in #12721.

Could we resolve conflicts so that it's a bit easier to review? Seems like there's some overlap from #12692.
Force-pushed from 6d96002 to 33d8b52 (Compare)

Done! Rebased on latest main and resolved conflicts with #12692. Should be much cleaner to review now.
```python
should_synchronize = (
    not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
)
```

Even with non_blocking=True, if a previous group onloaded this one on a side stream, we need a sync before the default stream uses the weights, or we risk reading half-copied tensors. I've limited the sync to the record_stream=False case; when record_stream=True, the tensors are tied to the consumer stream, so we can safely skip the sync.
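For context, a minimal sketch of the sync being described, assuming the forward pass runs on the default CUDA stream and self.group.stream is the side stream used for onloading; the wait_stream call is one plausible way to express the sync, not necessarily the PR's exact code:

```python
import torch

def maybe_sync(group, should_onload_next_group: bool) -> None:
    # Sync only when: the group was onloaded by a previous group's prefetch
    # (not by itself), a side stream is in use, no further prefetch is pending,
    # and record_stream is not already tracking tensor lifetimes.
    should_synchronize = (
        not group.onload_self
        and group.stream is not None
        and not should_onload_next_group
        and not group.record_stream
    )
    if should_synchronize:
        # Make the default stream wait for the side-stream copies to finish
        # before the forward pass reads the weights.
        torch.cuda.current_stream().wait_stream(group.stream)
```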
Thank you for the initial comment! We are working on a solution right now.

Force-pushed from 6f5887e to 1194a83 (Compare)
@bot /style

Style bot fixed some files and pushed the changes.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Hi @DN6 @sayakpaul, we've updated the fix according to the review. Could you take a quick look and share any feedback when you have a moment? Thank you in advance!

Hey @DN6 @sayakpaul, as mentioned above, we have fixed the comments. Could you help guide us on the next steps?
sayakpaul left a comment
Thanks for all the work on this PR. There are a couple of things that feel quite confusing to me. So, I would appreciate some explanations.
```python
# Model with only standalone computational layers at top level
class DummyModelWithStandaloneLayers(ModelMixin):
```

Why is this being deleted?

The rest of the diffs in this testing script are honestly a bit difficult to follow. Could we keep this cleaner?
Thanks for pointing this out. The class was not intentionally deleted. From the git history, it shows up as removed due to branch-history/rebase artifacts from integrating changes, rather than a deliberate change to the test itself, which makes the diff noisier than it should be. I'm cleaning this up now: I'll restore that block and reorganize the commits so the test diffs are more focused, atomic, and easier to review.
```diff
  pinned_dict = None

- def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
+ def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream=None):
```

Why do we have to set the default of default_stream?

I made it optional because the non-stream path calls _process_tensors_from_modules without a stream; there is nothing to record in that case, and record_stream is gated. None is a safety net for the record call, and it saves passing a placeholder from those call sites. If you prefer the stricter signature, I can keep it required and pass None explicitly where we don't use streams. Please correct my understanding if this needs to change.
Let's stick to the existing implementation in this case, i.e., a stricter signature.

Let's stick to a stricter signature.
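To make the agreed direction concrete, here is a minimal sketch of the stricter signature; the method body and the record_stream gating are assumptions inferred from the thread, not the PR's exact code:

```python
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
    # Copy to the onload device; non_blocking pairs with pinned host memory.
    tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
    if self.record_stream and default_stream is not None:
        # Tie the tensor's lifetime to the consumer stream instead of syncing.
        tensor.data.record_stream(default_stream)

# Non-stream call sites then pass the placeholder explicitly:
# self._transfer_tensor_to_device(tensor, source_tensor, None)
```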
```python
        _apply_group_offloading_hook(module, unmatched_group, config=config)
    else:
        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
elif config.stream is None and config.offload_to_disk_path is None:
```

This seems unnecessary. Explain?

I originally added the empty root hook to tag the top module as offloaded when everything else was matched, but it did not change behaviour: the child hooks already mark the model as group-offloaded, and the guardrails rely on those. It just added an empty group and potentially extra files, so I have removed it to simplify. Functionally, nothing depends on it.
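As an illustration of the guardrail mentioned above, a hypothetical presence check; the registry attribute and hook name are assumptions about diffusers' internals, not code from this PR:

```python
def _is_group_offload_enabled(module) -> bool:
    # Any submodule carrying a "group_offloading" hook is enough to flag the
    # whole model as group-offloaded, so an empty root hook adds nothing.
    for submodule in module.modules():
        registry = getattr(submodule, "_diffusers_hook", None)
        if registry is not None and registry.get_hook("group_offloading") is not None:
            return True
    return False
```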
Let's remove this if it isn't relevant to pinning.
```diff
  low_cpu_mem_usage=config.low_cpu_mem_usage,
  onload_self=True,
- group_id=name,
+ group_id=f"{config.module_prefix}{name}",
```

It is the same as above: we prefix group_id with the parent name to avoid id collisions when recursing into block_modules. The root prefix stays empty to preserve existing ids; the prefix only appears when descending into children.
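A rough sketch of the id scheme being described; module_prefix and the recursion shape are assumptions for illustration, and config is assumed to be a dataclass:

```python
from dataclasses import replace

def _child_group_id(config, name: str):
    # The root call uses module_prefix="", so top-level ids are unchanged,
    # e.g. "decoder". Recursing into a block module yields "decoder.mid_block".
    group_id = f"{config.module_prefix}{name}"
    child_config = replace(config, module_prefix=f"{group_id}.")
    return group_id, child_config
```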
…feature/group-offload-pinning

Force-pushed from 8403860 to 2e8f538 (Compare)
@sayakpaul thank you for all your comments, and sorry for the delay in resolving them. All have been answered; please let us know your review.
```python
if isinstance(pin_groups, str) and pin_groups in VALID_PIN_GROUPS:
    return pin_groups
raise ValueError(
    f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable."
)
```

Suggested change:

```diff
- if isinstance(pin_groups, str) and pin_groups in VALID_PIN_GROUPS:
-     return pin_groups
- raise ValueError(
-     f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable."
- )
+ elif isinstance(pin_groups, str) and pin_groups not in VALID_PIN_GROUPS:
+     raise ValueError(
+         f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable."
+     )
+ return pin_groups
```
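Putting the suggestion in context, a minimal sketch of the full validator it implies; the function name and the None/callable branches are assumptions, and only the string handling comes from the diff above:

```python
VALID_PIN_GROUPS = {"all", "first_last"}  # values documented later in this PR

def _validate_pin_groups(pin_groups):
    if pin_groups is None or callable(pin_groups):
        return pin_groups
    elif isinstance(pin_groups, str) and pin_groups not in VALID_PIN_GROUPS:
        raise ValueError(
            f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable."
        )
    # Valid strings (and, mirroring the suggestion, any other type) fall through.
    return pin_groups
```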
```python
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
    if self.group.offload_leader == module:
        # For disk offload we materialize the safetensor files upfront so callers can inspect them immediately.
```

Can you clarify this scenario in the comments as well? And provide a small example that justifies this change?
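For reference, the kind of small example being asked for, echoing the commented snippet later in this thread; the exact enable_group_offload arguments are assumptions, and model is any diffusers model supporting group offloading:

```python
import glob
import tempfile
import torch

with tempfile.TemporaryDirectory() as tmpdir:
    model.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=1,
        offload_to_disk_path=tmpdir,
    )
    # With eager materialization, the shards exist before the first forward:
    assert glob.glob(f"{tmpdir}/*.safetensors")
```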
```python
# If the current module is the onload_leader of the group, we onload the group if it is supposed
# to onload itself. In the case of using prefetching with streams, we onload the next group if
# it is not supposed to onload itself.
```

(nit): let's not get rid of the important comments.
```python
not self.group.onload_self
and self.group.stream is not None
and not should_onload_next_group
and not self.group.record_stream
```

Could I get a clarification on why this condition needs to be modified?

I did not change the condition; those same four checks were already present in both branches. I consolidated them into one place to avoid duplication. We still only sync when the group did not onload itself, we are using a stream, there is no pending prefetch, and record_stream is not handling lifetime tracking.
```python
if self.group.offload_leader == module:
    self.group.offload_()
return output
```

This part of the diff reads very confusing to me and is hence a bit hard to confidently review. It seems to me post_forward() was just moved up and _send_kwargs_to_device() was added (and I am not sure why), amongst other things. Possible to have a cleaner diff?

Thanks for the flag. I simplified the diff: I removed _send_kwargs_to_device, and the kwargs handling is back inline in pre_forward as before.
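A rough sketch of the restored inline handling, assuming accelerate's send_to_device (already used in this diff) and the hook attributes seen above; the onloading logic is elided:

```python
from accelerate.utils import send_to_device

def pre_forward(self, module, *args, **kwargs):
    # ... onloading logic elided ...
    # Inputs move to the onload device inline, with no separate
    # _send_kwargs_to_device helper.
    args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
    kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
    return args, kwargs
```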
```python
Args:
    pin_groups (`"first_last"` | `"all"` | `Callable`, *optional*):
        Optionally keep selected groups on the onload device permanently. See
        [`~hooks.group_offloading.apply_group_offloading`] for details.
```

Are we just documenting pin_groups here? If so, we should remove it from here, as apply_group_offloading() should already cover it.
```python
# keys to ignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
```

Let's also add a comment on how these modules were chosen to be included here.
@seed93 would you like to test it?

Thanks again @sayakpaul for the detailed review! We have addressed all the points.

@sayakpaul, would you help us with the next steps in this PR?

Thanks @sayakpaul for running CI. Based on the logs, the failures don't touch the group-offloading changes in this PR.
@DN6, could you please guide us on the next steps for this PR? It has been open for a while.

@sayakpaul @DN6 may we ask for a quick confirmation that the current CI failures are independent of the group_offloading changes in this PR?

Yes, those are unrelated.

Hi @sayakpaul, could you share the next steps for this PR?

Hi @bconstantine. Apologies for the delay. I was caught up with a few other tasks. I will review this tomorrow.
```python
# e.g.: model.enable_group_offload(..., offload_to_disk_path=tmpdir)
#       assert glob.glob(f"{tmpdir}/*.safetensors")
# In-memory offload stays lazy to allow adapter loading before the first forward.
if self.group.offload_to_disk_path is not None and self.group.offload_leader == module:
```

What is the test that's failing? If this change is to make a test pass and it is not related to pinning, let's revert it please. We can create a dedicated PR to address that issue.

Agreed. This eager disk offload was added to satisfy a test unrelated to pinning; I have reverted it so offload remains lazy. If we still need eager materialization, I will open a separate PR scoped to that issue.
```python
    return send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

    return args, kwargs

def _is_group_on_device(self) -> bool:
```

Instead of using this loop to check all tensors, it would be better to introduce an _is_offloaded attribute on the ModuleGroup which flips to True/False during onload/offload. Use that to check whether the group is on device or not.
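A minimal sketch of the suggested flag; the class and method names come from the surrounding diff, while the bodies are assumptions:

```python
class ModuleGroup:
    def __init__(self):
        # Flipped by onload_/offload_ so device checks are O(1).
        self._is_offloaded = True

    def onload_(self):
        # ... move parameters/buffers to the onload device ...
        self._is_offloaded = False

    def offload_(self):
        # ... move parameters/buffers back to the offload device ...
        self._is_offloaded = True


class GroupOffloadingHook:
    def __init__(self, group: ModuleGroup):
        self.group = group

    def _is_group_on_device(self) -> bool:
        # Replaces the per-tensor device loop with a single flag read.
        return not self.group._is_offloaded
```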
```diff
  for name, submodule in module.named_children():
      # Check if this is an explicitly defined block module
-     if name in block_modules:
+     if block_modules and name in block_modules:
```

? Change not needed, I think? If block_modules is an empty set, this won't evaluate to true.

I reverted this check back to `if name in block_modules`, since block_modules is already a set and the extra guard is redundant.
```python
# Ensure the top-level module also has a group_offloading hook so hook presence checks pass,
# even when it holds no parameters/buffers itself.
if config.stream is None:
```

What is the failing check that requires this change? It doesn't seem related to pinning?
```python
    group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
    group_offloading_hooks[i].next_group.onload_self = False

if self.pin_groups is not None and num_executed > 0:
```

My understanding here is that this logic is responsible for onloading/pinning the groups during the first inference pass. The first pass of offloading with streams will always be slow since we need to figure out execution order. The complexity introduced with this change is probably not worth the incremental performance benefit.

I would suggest we factor this into something simpler:

```python
def post_forward(self, module, output):
    # existing prefetch setup code unchanged
    if self.pin_groups is not None:
        self._mark_pinned_groups(group_offloading_hooks)
    return output
```

After marking, subsequent passes should just skip offloading. The group remains pinned after the second pass.

Also, what happens with pinning if we don't use streams? The prefetch hook doesn't run in that case.
I refactored the pinning into a _mark_pinned_groups helper called from post_forward after the prefetch setup, matching the simplified structure you suggested. I also added a warning and a doc note: pin_groups is only supported with use_stream=True and is ignored otherwise, since the lazy prefetch hook doesn't run without streams.
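A hypothetical sketch of what such a helper could look like; the selection rules follow the documented pin_groups values ("first_last", "all", or a callable), while the hook and group attribute names are assumptions:

```python
def _mark_pinned_groups(self, group_offloading_hooks):
    num_hooks = len(group_offloading_hooks)
    for i, hook in enumerate(group_offloading_hooks):
        if self.pin_groups == "all":
            pin = True
        elif self.pin_groups == "first_last":
            pin = i == 0 or i == num_hooks - 1
        else:
            # A user-supplied callable decides per group.
            pin = bool(self.pin_groups(hook.group))
        # Pinned groups stay on the onload device; offload_() becomes a
        # no-op for them on subsequent passes.
        hook.group.is_pinned = pin
```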
@DN6 Thanks for the review, I have addressed all your comments. Please advise on the next steps for this PR.
What does this PR do?
Fixes #11966
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul