Conversation

@pmeier pmeier commented Oct 28, 2022

Stack from ghstack (oldest at bottom):

Redo of #86586 with all BC breaking changes granularly placed into separate commits.


Per title. Deprecation happened on Feb 25, 2022 in c6f1bbc, which made it into the 1.12 release. Since it is now 245 days later and the next release will be 1.14, the removals later in the stack comply with the BC policy.

pytorch-bot bot commented Oct 28, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87969

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 Failures

As of commit 90869bd:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the `ciflow/mps` and `release notes: distributed (ddp)` labels Oct 28, 2022
@pmeier pmeier added the `ciflow/trunk`, `ciflow/periodic`, and `topic: not user facing` labels, and removed the `release notes: distributed (ddp)` and `ciflow/mps` labels, Oct 28, 2022
@pytorch-bot pytorch-bot bot added the `ciflow/mps` label Oct 28, 2022
@pmeier pmeier left a comment

There were two things that needed manual handling:

  1. `torch.testing.make_non_contiguous` needed to be replaced with one of the suggested options, as applicable:

        @warn_deprecated(
            "Depending on the use case there are different replacement options:\n\n"
            "- If you are using `make_non_contiguous` in combination with a creation function to create a noncontiguous tensor "
            "with random values, use `torch.testing.make_tensor(..., noncontiguous=True)` instead.\n"
            "- If you are using `make_non_contiguous` with a specific tensor, you can replace this call with "
            "`torch.repeat_interleave(input, 2, dim=-1)[..., ::2]`.\n"
            "- If you are using `make_non_contiguous` in the PyTorch test suite, use "
            "`torch.testing._internal.common_utils.noncontiguous_like` instead."
        )

  2. `torch.testing.assert_allclose` couldn't always be replaced 1-to-1 with the newish (stable since Feb 2022) `torch.testing.assert_close`. See #61844 for a detailed analysis of the differences.

I've highlighted all places where I did more than a simple replace of deprecated functionality below.
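The `torch.repeat_interleave(input, 2, dim=-1)[..., ::2]` recipe from option 2 can be sanity-checked with a plain-list sketch. This is an illustration only: Python lists have no strides, so it shows just that the values round-trip unchanged, while on a real tensor the `::2` slice additionally leaves a stride-2, noncontiguous view:

```python
def interleave_then_stride(row):
    # List analogue of `torch.repeat_interleave(input, 2, dim=-1)`:
    # duplicate every element along the last dimension.
    doubled = [x for x in row for _ in (0, 1)]
    # List analogue of `[..., ::2]`: keep every second element.
    # The original values come back unchanged; on a real tensor the
    # slice is a view with stride 2, i.e. noncontiguous memory.
    return doubled[::2]

print(interleave_then_stride([1, 2, 3]))  # [1, 2, 3]
```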

Comment on lines 17 to 20
# The function below is a faithful replica of the former `torch.testing.assert_allclose`. This is only here,
# because it is used extensively throughout the tests in this package while needing one feature that
# the new `torch.testing.assert_close` does not offer: comparison between numpy arrays and torch tensors. See
# https://github.com/pytorch/pytorch/issues/61844 for the reasoning why this feature was removed.

Per comment. This only applies to the tests located in caffe2/python/operator_test/*.

dist = Poisson(rate_zero)
dist.log_prob(torch.ones_like(rate_zero)).backward()
torch.testing.assert_allclose(rate_zero.grad, torch.inf)
self.assertEqual(rate_zero.grad, torch.inf)

Comparing scalar tensors to Python scalars is not supported by torch.testing.assert_close. I've opted to use self.assertEqual here since that still includes this type wrangling for BC.

trace = torch.jit.trace(fn, args)
self.assertAllFused(trace.graph_for(*args))
torch.testing.assert_allclose(fn(*args), trace(*args))
torch.testing.assert_close(fn(*args), trace(*args), equal_nan=True)

In `torch.testing.assert_close`, `equal_nan` is `False` by default. Are NaNs actually ok here, or did this pass silently before?
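For context, the per-element closeness check behind both functions can be sketched in pure Python; the key behavioral difference is that `assert_allclose` effectively ran with `equal_nan=True` while `assert_close` defaults to `equal_nan=False`. The float32 defaults `rtol=1.3e-6`, `atol=1e-5` below are assumptions of this simplified sketch, not torch's actual implementation:

```python
import math

def elements_close(a, b, rtol=1.3e-6, atol=1e-5, equal_nan=False):
    # Simplified scalar sketch of the closeness predicate; torch applies
    # this elementwise, with broadcasting, dtype checks, etc. on top.
    if math.isnan(a) or math.isnan(b):
        # assert_close default: two NaNs do NOT match unless equal_nan=True.
        return equal_nan and math.isnan(a) and math.isnan(b)
    return abs(a - b) <= atol + rtol * abs(b)

nan = float("nan")
print(elements_close(nan, nan))                  # False (assert_close default)
print(elements_close(nan, nan, equal_nan=True))  # True  (assert_allclose behavior)
```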

py_relu_cpu = py_relu.to("cpu")

torch.testing.assert_allclose(np_relu, py_relu_cpu)
self.assertEqual(np_relu, py_relu_cpu)

Comparing numpy.ndarray's to torch.Tensor's is not supported in torch.testing.assert_close. self.assertEqual still allows it.
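If such a comparison ever does need to go through `torch.testing.assert_close`, one manual bridge is to compare plain Python values, since both `numpy.ndarray` and `torch.Tensor` expose `.tolist()`. The sketch below uses stdlib stand-ins for the two container types, so it makes no claim about the real test:

```python
import array

def as_values(buf):
    # Stand-in for calling `.tolist()` on either operand; with the real
    # objects one could instead build matching tensors, e.g. via
    # `torch.from_numpy(np_relu)`, before calling assert_close.
    return list(buf)

np_like = array.array("f", [0.0, 0.5, 1.0])  # stands in for the ndarray
tensor_like = [0.0, 0.5, 1.0]                # stands in for the cpu tensor
print(as_values(np_like) == as_values(tensor_like))  # True
```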

Comment on lines +2962 to +2964
idx = torch.testing.make_tensor(
num_src, low=0, high=num_dest, dtype=idx_dtype, device=device, noncontiguous=index_noncontig
)

If we want to create a noncontiguous tensor right away, we can use torch.testing.make_tensor directly.

src = torch.randn(num_copy, *other_sizes, device=device)
if not src_contig:
src = torch.testing.make_non_contiguous(src)
src = noncontiguous_like(src)

Not using torch.testing.make_tensor here since src samples from a normal distribution and the former would sample from a uniform one. If the intent is to just sample positive as well as negative values, LMK.
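The distribution mismatch can be illustrated with the stdlib samplers (`random.gauss` and `random.uniform` standing in for `torch.randn` and `make_tensor`; the bounds are illustrative assumptions):

```python
import random

random.seed(0)  # deterministic for illustration
# torch.randn analogue: standard normal, mean 0, unbounded support.
normal_sample = [random.gauss(0.0, 1.0) for _ in range(1000)]
# make_tensor analogue: uniform over [low, high); hard-bounded and flat,
# unlike a normal, which concentrates mass around 0.
uniform_sample = [random.uniform(-3.0, 3.0) for _ in range(1000)]

print(all(-3.0 <= x <= 3.0 for x in uniform_sample))  # True: uniform is bounded
```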

idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy)
if not index_contig:
idx = torch.testing.make_non_contiguous(idx)
idx = noncontiguous_like(idx)

Same as above but for a random permutation.

fsdp_loss = fsdp_loss.cuda()
fsdp_unsharded_params = get_full_params(fsdp_model)
torch.testing.assert_allclose(ref_loss, fsdp_loss)
torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)

Strict dtype checking is the default for torch.testing.assert_close. Are mismatching dtypes ok here?

@pmeier pmeier marked this pull request as ready for review November 2, 2022 10:27
@mruberry mruberry left a comment


LGTM!

There are several test failures, but I can't imagine they're related to this PR.

pytorchmergebot pushed a commit that referenced this pull request Nov 2, 2022
See #87969 or #86586 for the reasoning.

Pull Request resolved: #87974
Approved by: https://github.com/mruberry
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
…#87969)

Pull Request resolved: pytorch#87969
Approved by: https://github.com/mruberry
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
@kit1980 kit1980 added the Merged label Mar 24, 2023
@facebook-github-bot facebook-github-bot deleted the gh/pmeier/35/head branch June 8, 2023 18:28