# Adds 'clip' alias for clamp #42770
**Changes from all commits**
**`docs/source/torch.rst` — Pointwise Ops list**

```diff
@@ -261,6 +261,7 @@ Pointwise Ops
     bitwise_xor
     ceil
     clamp
+    clip
     conj
     cos
     cosh
```
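The docs change just registers the new name. For orientation, a minimal sketch of the intended equivalence (assuming, as the PR states, that `clip` is a pure alias for `clamp`):

```python
import torch

t = torch.tensor([-2.0, -0.5, 0.5, 2.0])

# 'clip' mirrors NumPy's spelling; both calls clamp elementwise into [-1, 1].
assert torch.equal(torch.clip(t, -1, 1), torch.clamp(t, -1, 1))
print(torch.clip(t, -1, 1))  # tensor([-1.0000, -0.5000,  0.5000,  1.0000])
```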
**`test/test_torch.py` — `test_clamp`**

```diff
@@ -7034,80 +7034,85 @@ def test_logical_and(self, device):
     def test_logical_or(self, device):
         self._test_logical(device, 'logical_or', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1])

+    # Tests clamp and its alias, clip
     def test_clamp(self, device):
-        m1 = torch.rand(100, device=device).mul(5).add(-2.5)  # uniform in [-2.5, 2.5]
-        # just in case we're extremely lucky.
-        min_val = -1
-        max_val = 1
-        m1[1] = min_val
-        m1[2] = max_val
+        op_list = ((torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_),
+                   (torch.clip, torch.Tensor.clip, torch.Tensor.clip_))
+        for op, method_op, inplace_op in op_list:

-        res1 = m1.clone()
-        res1.clamp_(min_val, max_val)
-        res2 = m1.clone()
-        for i in iter_indices(res2):
-            res2[i] = max(min_val, min(max_val, res2[i]))
-        self.assertEqual(res1, res2)
+            m1 = torch.rand(100, device=device).mul(5).add(-2.5)  # uniform in [-2.5, 2.5]
+            # just in case we're extremely lucky.
+            min_val = -1
+            max_val = 1
+            m1[1] = min_val
+            m1[2] = max_val

-        out = m1.clone()
-        torch.clamp(m1, min=min_val, max=max_val, out=out)
-        self.assertEqual(out, res1)
+            res1 = m1.clone()
+            inplace_op(res1, min_val, max_val)
+            res2 = m1.clone()
+            for i in iter_indices(res2):
```
> **Collaborator:** maybe `compare_with_numpy` instead of this loop?
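A sketch of that suggestion, assuming the test suite's `compare_with_numpy(torch_fn, np_fn, tensor_like)` helper is available on this `TestCase` (its exact signature here is an assumption):

```python
import numpy as np

# Hypothetical rewrite: let NumPy compute the clamped reference instead
# of looping over the tensor element by element.
self.compare_with_numpy(lambda t: op(t, min_val, max_val),
                        lambda a: np.clip(a, min_val, max_val),
                        m1)
```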
```diff
+                res2[i] = max(min_val, min(max_val, res2[i]))
+            self.assertEqual(res1, res2)

-        res1 = torch.clamp(m1, min=min_val)
-        res2 = m1.clone()
-        for i in iter_indices(res2):
-            res2[i] = max(min_val, res2[i])
-        self.assertEqual(res1, res2)
+            out = m1.clone()
+            op(m1, min=min_val, max=max_val, out=out)
+            self.assertEqual(out, res1)

-        torch.clamp(m1, min=min_val, out=out)
-        self.assertEqual(out, res1)
+            res1 = op(m1, min=min_val)
+            res2 = m1.clone()
+            for i in iter_indices(res2):
+                res2[i] = max(min_val, res2[i])
+            self.assertEqual(res1, res2)

-        res1 = torch.clamp(m1, max=max_val)
-        res2 = m1.clone()
-        for i in iter_indices(res2):
-            res2[i] = min(max_val, res2[i])
-        self.assertEqual(res1, res2)
+            op(m1, min=min_val, out=out)
+            self.assertEqual(out, res1)

+            res1 = op(m1, max=max_val)
+            res2 = m1.clone()
+            for i in iter_indices(res2):
+                res2[i] = min(max_val, res2[i])
+            self.assertEqual(res1, res2)

-        torch.clamp(m1, max=max_val, out=out)
-        self.assertEqual(out, res1)
-
-        # if the tensor contains nan case
-        test_tens = torch.tensor([nan], device=device)
-
-        res1 = test_tens.clone()
-        res1.clamp_(min_val, max_val)
-        res2 = test_tens.clone()
-        for i in iter_indices(res2):
-            res2[i] = max(min(res2[i], max_val), min_val)
-        self.assertEqual(torch.isnan(res1), torch.isnan(res2))
-
-        out = test_tens.clone()
-        torch.clamp(test_tens, min=min_val, max=max_val, out=out)
-        self.assertEqual(torch.isnan(out), torch.isnan(res1))
-
-        res1 = torch.clamp(test_tens, min=min_val)
-        res2 = test_tens.clone()
-        for i in iter_indices(res2):
-            res2[i] = max(res2[i], min_val)
-        self.assertEqual(torch.isnan(res1), torch.isnan(res2))
-
-        torch.clamp(test_tens, min=min_val, out=out)
-        self.assertEqual(torch.isnan(out), torch.isnan(res1))
-
-        res1 = torch.clamp(test_tens, max=max_val)
-        res2 = test_tens.clone()
-        for i in iter_indices(res2):
-            res2[i] = min(res2[i], max_val)
-        self.assertEqual(torch.isnan(res1), torch.isnan(res2))
-
-        torch.clamp(test_tens, max=max_val, out=out)
-        self.assertEqual(torch.isnan(out), torch.isnan(res1))
-
-        error_msg = 'At least one of \'min\' or \'max\' must not be None'
-        with self.assertRaisesRegex(RuntimeError, error_msg):
-            m1.clamp()
-        with self.assertRaisesRegex(RuntimeError, error_msg):
-            m1.clamp_()
+            op(m1, max=max_val, out=out)
+            self.assertEqual(out, res1)

+            # if the tensor contains nan case
+            test_tens = torch.tensor([nan], device=device)

+            res1 = test_tens.clone()
+            inplace_op(res1, min_val, max_val)
+            res2 = test_tens.clone()
+            for i in iter_indices(res2):
```
> **Collaborator:** this loop not needed, just compare with expected `[nan]` here
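A sketch of that simplification (an assumption about intent, not the author's change): clamp leaves NaN in place, so for a `[nan]` input the reference loop is unnecessary.

```python
# Hypothetical simplification: NaN is neither clamped up nor down,
# so the in-place op should leave the single element NaN.
res = test_tens.clone()
inplace_op(res, min_val, max_val)
self.assertTrue(torch.isnan(res).all())
```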
```diff
+                res2[i] = max(min(res2[i], max_val), min_val)
+            self.assertEqual(torch.isnan(res1), torch.isnan(res2))

+            out = test_tens.clone()
```
> **Collaborator:** given that you expect […comment truncated in capture]
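The comment is cut off in this capture; a plausible reading is that cloning a NaN tensor into `out` makes the subsequent NaN check vacuous. A sketch under that assumption (hypothetical, not the author's fix):

```python
# Hypothetical tightening: start from a buffer with no NaNs so the
# assertion can only pass if the op actually wrote NaN into `out`.
out = torch.zeros_like(test_tens)
op(test_tens, min=min_val, max=max_val, out=out)
self.assertTrue(torch.isnan(out).all())
```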
```diff
+            op(test_tens, min=min_val, max=max_val, out=out)
+            self.assertEqual(torch.isnan(out), torch.isnan(res1))

+            res1 = op(test_tens, min=min_val)
+            res2 = test_tens.clone()
+            for i in iter_indices(res2):
```
> **Collaborator:** loop is not needed
```diff
+                res2[i] = max(res2[i], min_val)
+            self.assertEqual(torch.isnan(res1), torch.isnan(res2))

+            op(test_tens, min=min_val, out=out)
+            self.assertEqual(torch.isnan(out), torch.isnan(res1))

+            res1 = op(test_tens, max=max_val)
+            res2 = test_tens.clone()
+            for i in iter_indices(res2):
+                res2[i] = min(res2[i], max_val)
+            self.assertEqual(torch.isnan(res1), torch.isnan(res2))

+            op(test_tens, max=max_val, out=out)
+            self.assertEqual(torch.isnan(out), torch.isnan(res1))

+            error_msg = 'At least one of \'min\' or \'max\' must not be None'
+            with self.assertRaisesRegex(RuntimeError, error_msg):
+                method_op(m1)
```
> **Collaborator:** I don't see `method_op` tested anywhere other than here where it raises an error
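A sketch of the missing positive coverage (hypothetical; it assumes the method form takes the same positional bounds as the functional form):

```python
# Hypothetical positive check for the Tensor-method form: it should
# agree with the functional form on the same input and bounds.
self.assertEqual(method_op(m1, min_val, max_val), op(m1, min_val, max_val))
```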
```diff
+            with self.assertRaisesRegex(RuntimeError, error_msg):
+                inplace_op(m1)
```
> **Collaborator:** Is nan propagation in clamp tested anywhere else? Vectorized and non-vectorized paths (could) propagate […comment truncated in capture]
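A sketch of what probing both kernels could look like (the size at which the vectorized path is selected is an assumption, so both a tiny and a large tensor are tried):

```python
# Hypothetical probe: a large contiguous tensor is likely to take the
# vectorized kernel, a single-element one the scalar fallback; NaN
# should survive clamping in both.
for n in (1, 1024):
    t = torch.full((n,), nan, device=device)
    self.assertTrue(torch.isnan(torch.clamp(t, min=-1, max=1)).all())
```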
```diff

     def test_cat_empty_legacy(self, device):
         # FIXME: this is legacy behavior and should be removed
```
**`torch/_torch_docs.py`**

```diff
@@ -1425,8 +1425,7 @@ def merge_dicts(*dicts):
            [-0.0889,  0.2122,  0.1412]])
 """)

-add_docstr(torch.clamp,
-           r"""
+add_docstr(torch.clamp, r"""
 clamp(input, min, max, out=None) -> Tensor

 Clamp all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]` and return
@@ -1497,6 +1496,12 @@ def merge_dicts(*dicts):
 tensor([ 0.5000, -0.4702, -0.4599,  0.5000])
 """.format(**common_args))

+add_docstr(torch.clip, r"""
+clip(input, min, max, *, out=None) -> Tensor
```
> **Collaborator:** why is there "*" here but not in […comment truncated in capture]
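For context (my reading, not the author's reply): in Python signature notation a bare `*` marks every following parameter as keyword-only. `out` is keyword-only for these ops in the Python API, so the two docstrings notate the same signature inconsistently rather than describing different behavior:

```python
t = torch.randn(4)
buf = torch.empty(4)

torch.clip(t, -1, 1, out=buf)   # OK: out passed by keyword
# torch.clip(t, -1, 1, buf)    # rejected: out is keyword-only
```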
```diff
+
+Alias for :func:`torch.clamp`.
+""".format(**common_args))
+
 add_docstr(torch.conj,
            r"""
 conj(input, out=None) -> Tensor
```
**`torch/testing/_internal/common_methods_invocations.py`**

```diff
@@ -276,6 +276,8 @@ def method_tests():
         ('clamp', (), (None, 0.5), 'min_scalar', (True,)),
         ('clamp', (), (0.5, None), 'max_scalar', (True,)),
         ('clamp', (S, S), (), 'max_scalar_kwarg', (True,), (), (), ident, {'max': 1}),
+        ('clip', (S, S, S), dont_convert((0, 1)), '', (False,)),
+        ('clip_', (S, S, S), dont_convert((0, 1)), '', (False,)),
```
> **Collaborator:** before your additions of […comment truncated in capture]

> **Author:** The alias test requires an inplace entry.
```diff
         ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS, '', (True,)),
         ('sqrt', uniform_scalar(5e-4, requires_grad=True), NO_ARGS, 'scalar', (True,)),
         ('sin', (S, S, S), NO_ARGS, '', (True,)),
```
> **Collaborator:** what are you trying to achieve here? Even if you are extremely lucky and your input tensor is bound by [min_val, max_val], these assignments are not going to change anything, and the output will be equal to the input.
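If the intent of those assignments was to guarantee that clamping actually changes something, one option (a sketch of a possible fix, not the author's) is to plant values strictly outside the bounds:

```python
# Hypothetical: values strictly outside [min_val, max_val] force the
# clamp to modify at least two elements, so the test cannot pass
# trivially with output == input.
m1[1] = min_val - 1  # below the lower bound -> clamped up to min_val
m1[2] = max_val + 1  # above the upper bound -> clamped down to max_val
```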