
Conversation

@samdow (Contributor) commented Dec 8, 2022

@pytorch-bot bot commented Dec 8, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90477

Note: Links to docs will display an error until the doc builds have completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 Failures

As of commit d07d26f:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on master:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@samdow samdow added the release notes: nn and topic: deprecation labels Dec 8, 2022
@samdow samdow requested a review from zou3519 December 8, 2022 16:32
@samdow (Contributor, Author) commented Dec 8, 2022

This changes the algorithm relative to make_functional and to the original PR #87079, so here are some new numbers.

Running a script like this on resnet-18 (no tied parameters), functional_call went from 58% slower than vanilla (i.e., not using functional_call at all) to 63% slower. That is a roughly 3.5% slowdown relative to functional_call before this PR.

Raw numbers in case anyone wants to see:

| Configuration (% of params passed) | CPU Time (us) | GPU Time (us) |
|---|---|---|
| Vanilla model | 9032 | 48657 |
| Pre-PR with 100% of params changed | 14334 | 51680 |
| This PR with 100% of params changed | 14809 | 52187 |
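
For reference, a minimal sketch of how such a measurement could be set up. The exact script isn't linked in the thread, so torchvision's resnet18, the batch size, and the use of torch.profiler are all assumptions here:

```python
import torch
from torch.nn.utils.stateless import functional_call
from torchvision.models import resnet18  # assumed stand-in for the benchmarked model

model = resnet18().cuda()
x = torch.randn(8, 3, 224, 224, device="cuda")  # batch size is a guess

# "100% of params changed": pass every parameter and buffer to functional_call
params_and_buffers = {
    **dict(model.named_parameters()),
    **dict(model.named_buffers()),
}

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
) as prof:
    for _ in range(10):
        functional_call(model, params_and_buffers, (x,))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```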

@zou3519 (Contributor) commented Dec 8, 2022

If I'm understanding the numbers, they are measuring some fixed number of iterations of resnet18 using the functional_call API. Shouldn't resnet be faster on GPU than CPU?

@samdow (Contributor, Author) commented Dec 8, 2022

> If I'm understanding the numbers, they are measuring some fixed number of iterations of resnet18 using the functional_call API

Yep! The script runs it 10 times, and this is the average time as reported.

> Shouldn't resnet be faster on GPU than CPU?

Yeah, I didn't explain these tables well. From the docs, my understanding is that the profiler reports the time used by the CPU and the GPU separately. In this case, what we're worried about is the CPU time, since the code for the stateless call runs on the CPU.
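
Continuing the sketch above (again an assumption about how the numbers were collected, not the thread's actual script), the two columns correspond to the profiler's separate CPU and CUDA aggregates:

```python
# Sum the per-op averages into the two totals reported in the table above.
events = prof.key_averages()
cpu_us = sum(e.self_cpu_time_total for e in events)    # time spent on the CPU
cuda_us = sum(e.self_cuda_time_total for e in events)  # time spent on the GPU
print(f"CPU: {cpu_us:.0f} us, CUDA: {cuda_us:.0f} us")
```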

samdow pushed a commit that referenced this pull request Dec 8, 2022
ghstack-source-id: f3f275a
Pull Request resolved: #90477
@zou3519 (Contributor) left a comment

Code looks correct, some questions

```python
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
tie_weights: bool = True,
```
@zou3519 (Contributor):

This change is technically BC-breaking (does it matter?). What's our deprecation plan for nn.utils.stateless.functional_call?
The easiest thing to do seems to be:

  • nn.utils.stateless.functional_call should retain the same behavior as before (tie_weights=False?)
  • we deprecate nn.utils.stateless.functional_call in the next version of PyTorch (so, now on master)
  • we introduce a new torch.func.functional_call (probably needs a better name) to replace it in the next version of PyTorch, with our preferred default (tie_weights=True); a shim along these lines is sketched below
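
A shim along those lines might look like the following. This is purely illustrative: the warning text is invented, and torch.func.functional_call is the hypothetical new entry point from the bullet above, not code from this PR:

```python
import warnings

import torch.func

def functional_call(module, parameters_and_buffers, args, kwargs=None):
    # Deprecated entry point: keep the old default behavior (tie_weights=False)
    # but steer users toward the new API, which defaults to tie_weights=True.
    warnings.warn(
        "torch.nn.utils.stateless.functional_call is deprecated; "
        "use torch.func.functional_call instead.",
        DeprecationWarning,
    )
    return torch.func.functional_call(
        module, parameters_and_buffers, args, kwargs, tie_weights=False
    )
```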

@samdow (Contributor, Author):

> This change is technically BC-breaking (does it matter?). What's our deprecation plan for nn.utils.stateless.functional_call?

Yep, this is BC-breaking, but on a beta feature. This came out of earlier talks with @albanD, where we figured the default behavior should be what most people expect (which, judging from the make_functional requests, is for tied weights to change together).

Since we are moving it to torch.func anyway, I'm fine with changing the default back to match the old behavior for now and breaking it when we do the move, unless @albanD has other thoughts.
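
To make the behavior under discussion concrete, a minimal sketch of the tied-weights semantics, assuming the signature shown in the review above (the toy module is an illustration, not code from the PR):

```python
import torch
import torch.nn as nn
from torch.nn.utils.stateless import functional_call

class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(3, 3, bias=False)
        self.b = nn.Linear(3, 3, bias=False)
        self.b.weight = self.a.weight  # b shares (ties) a's weight

    def forward(self, x):
        return self.b(self.a(x))

m = Tied()
x = torch.randn(2, 3)

# With tie_weights=True, overriding one tied name overrides every alias:
# both a and b now use the identity matrix, so the output is just x.
out = functional_call(m, {"a.weight": torch.eye(3)}, (x,), tie_weights=True)
print(torch.allclose(out, x))  # True
```

With tie_weights=False, only a.weight would be replaced and b would keep the module's original (random) weight, so the two settings give different outputs for the same call.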

@albanD (Collaborator):

Both can work, really.
But it might be simplest to have the torch.func version be an alias of the old version.
And yes, from earlier discussions, I think this BC-breaking change is ok.

samdow pushed a commit that referenced this pull request Dec 14, 2022
ghstack-source-id: 87031f7
Pull Request resolved: #90477
@zou3519 zou3519 self-requested a review December 28, 2022 14:26
@zou3519 (Contributor) left a comment

Code LGTM, minor comments. The algorithm feels a bit weird, but I can't come up with anything simpler.

```python
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
tie_weights: bool = False,
```
@zou3519 (Contributor):

nit: "weights" can refer to parameters only, while tie_weights effectively ties both parameters and buffers. Is that a problem? If we think it is, we could rename it tie_weights_and_buffers. I kind of like the shorter name (and we can document that it applies to both parameters and buffers), but I'm open to suggestions.

Docstring for nn.utils.stateless.functional_call needs to be updated with new flag and details on what it does (unless the plan was to just update torch.func.functional_call)
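
For what it's worth, a sketch of what that docstring entry could say; the wording here is a suggestion based on the thread, not the PR's actual text:

```python
# Possible Args entry for the functional_call docstring:
#
#     tie_weights (bool, optional): If True, parameters and buffers that are
#         tied in the original module are treated as tied in the
#         reparametrized version, so passing a value for any one of the tied
#         names applies it to all of them. Note that this covers buffers as
#         well as parameters. Default: True.
```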

@samdow (Contributor, Author):

I'm somewhat attached to just "weights", since the original paper calls this weight tying (granted, there they only tie the weights/parameters). However, I hear you that it's ambiguous, so the docstring explicitly calls out that it ties both parameters and buffers.

@zou3519 (Contributor):

SGTM

```diff
-def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
+def _create_swap_params(params_and_buffers):
+    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
```
@zou3519 (Contributor):

What is tensor being used for?

@samdow (Contributor, Author) commented Jan 9, 2023

It's unused, but previously it was the original weight. Since the original weight was harder to get for tied weights, I passed None and had to update the signature to match. I can add a patch on top of this to remove that parameter, since it appears unused.
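
For context, a deliberately simplified, hypothetical sketch of the closure's shape, just to show why the trailing tensor argument can go unused; the PR's actual implementation differs:

```python
from typing import Dict, Optional

import torch.nn as nn
from torch import Tensor

def _create_swap_params(params_and_buffers: Dict[str, Tensor]):
    def _swap_parameters(module: nn.Module, tensor_name: str,
                         full_path: str, tensor: Optional[Tensor]) -> None:
        # `tensor` (formerly the original weight) is ignored: the replacement
        # is looked up by its fully qualified name instead, which also covers
        # tied weights, where the original tensor is awkward to obtain.
        if tensor_name in module._parameters or tensor_name in module._buffers:
            delattr(module, tensor_name)  # unregister so a plain Tensor can be set
        setattr(module, tensor_name, params_and_buffers[full_path])
    return _swap_parameters
```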

@zou3519 (Contributor):

SGTM

samdow pushed a commit that referenced this pull request Jan 9, 2023
ghstack-source-id: df4f19d
Pull Request resolved: #90477
samdow pushed a commit that referenced this pull request Jan 10, 2023
ghstack-source-id: c8a6c49
Pull Request resolved: #90477
@zou3519 zou3519 added the topic: bc breaking label Jan 10, 2023
@samdow samdow added the module: functorch label Jan 10, 2023
@albanD (Collaborator) left a comment

Thanks a lot!

```python
def add_to_name_map(n: str, t: torch.Tensor):
    # if the tensor hasn't been seen before, add it to the map
    if t not in weight_to_name_and_tied_names:
        weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
```
@albanD (Collaborator):

For `n in names`: I am not sure what exact object that dict_keys object is. What is the complexity of this lookup?
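
(An aside on the question, as general Python behavior rather than anything from this PR: membership tests on a dict_keys view are hash lookups against the dict's table, so `n in names` is O(1) on average, just like `n in d`.)

```python
# Minimal demonstration: dict_keys membership is a hash lookup, not a scan.
d = {"a.weight": 0, "b.weight": 1}
names = d.keys()            # dict_keys view over d's hash table
print("a.weight" in names)  # True; average O(1)
```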

@samdow (Contributor, Author):

cc @zou3519 @Chillee @soumith

samdow pushed a commit that referenced this pull request Jan 10, 2023
ghstack-source-id: 988c614
Pull Request resolved: #90477
@samdow samdow added the ciflow/trunk label Jan 10, 2023
@samdow (Contributor, Author) commented Jan 11, 2023

@pytorchbot merge -f "failures from flaky test and unrelated mps test"

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/samdow/50/head branch June 8, 2023 18:43