
Conversation

@peterbell10 (Collaborator)

Fixes #39639 (and a lot more)

I went through every use of TensorIterator and added the check_mem_overlap flag to all of the in-place ops and ops writing to out tensors. I only touched TH/THC where the operation was implemented with TensorIterator on the other device, for consistency. Hopefully checks for the remaining ops can be added as they are ported to ATen.
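For context, here is a minimal sketch of the class of bug these checks catch (shapes chosen purely for illustration):

import torch

x = torch.zeros(3).expand(2, 3)          # both rows alias one 3-element storage
mask = torch.tensor([[True, False, False],
                     [False, True, False]])
out = x.masked_fill(mask, 1.)            # out-of-place: rows differ, as the mask rows do
x.masked_fill_(mask, 1.)                 # in-place: every fill lands in the shared storage,
                                         # so both rows become [1., 1., 0.]

With the checks in place, the in-place call on overlapping memory raises an error instead of silently diverging from masked_fill.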

@dr-ci

dr-ci bot commented Jun 11, 2020

💊 CI failures summary and remediations

As of commit 60e0565 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



❄️ 8 failures tentatively classified as flaky, but reruns have not yet been triggered to confirm. All eight CircleCI builds failed at the "Build" step with the same Docker error, a missing image manifest for tag 8bdba785b1eac4d297d5f5930f979518012a56e0:

Error response from daemon: manifest for 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/<image>:8bdba785b1eac4d297d5f5930f979518012a56e0 not found

pytorch_linux_xenial_py3_clang5_mobile_build (1/8): pytorch-linux-xenial-py3-clang5-asan
pytorch_linux_bionic_py3_7_conda_build (2/8): pytorch-linux-bionic-py3.7-conda
pytorch_linux_xenial_py3_clang5_asan_build (3/8): pytorch-linux-xenial-py3-clang5-asan
pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build (4/8): pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7
pytorch_linux_xenial_py3_clang5_mobile_custom_build_static (5/8): pytorch-linux-xenial-py3-clang5-asan
pytorch_xla_linux_bionic_py3_6_clang9_build (6/8): pytorch-linux-bionic-py3.6-clang9
pytorch_linux_bionic_py3_6_clang9_build (7/8): pytorch-linux-bionic-py3.6-clang9
pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build (8/8): pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7


@mruberry added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Jun 12, 2020
@ezyang
Contributor

ezyang commented Jun 12, 2020

One thing I'm not sure about: (1) is it true that we should be universally checking for memory overlap in all in-place/out TensorIterator invocations, and (2) if so, why isn't this built into the API in a more structured way, rather than as an argument at each call site that you could forget to pass?

@ezyang
Contributor

ezyang commented Jun 12, 2020

@peterbell10 please add the original author of memory overlap and the reviewers of that PR as reviewers to this PR

@peterbell10 peterbell10 requested a review from pbelevich June 12, 2020 15:14
@peterbell10
Collaborator Author

cc @pbelevich who seems to have done most of the TensorIterator work in #22917 and #24058

is it true that we should be universally checking for memory overlap in all inplace/out TensorIterator invocations

Most of the time, I think so. There are some cases, like reductions, where the tensors are intentionally given zero strides, so these checks would give false positives.

why isn't this built into the API in a more structured way, rather than as an argument to each call site that you could forget.

Perhaps we should remove the defaults from the n-ary_op functions. People creating custom tensor iterators could still forget to call set_check_mem_overlap though.

@ezyang ezyang requested review from VitalyFedyunin and zou3519 June 15, 2020 15:45
@ezyang
Contributor

ezyang commented Jun 15, 2020

Removing myself as reviewer.

@peterbell10
Collaborator Author

As a side note, in numpy broadcasted arrays are marked as read-only to avoid these issues:

>>> x = np.random.rand(100)
>>> y = np.broadcast_to(x, (10, 100))
>>> y[0, 0] = 1
ValueError: assignment destination is read-only
>>> y.flags['WRITEABLE']
False

@zou3519
Contributor

zou3519 commented Jun 16, 2020

Perhaps we should remove the defaults from the n-ary_op functions.

I'm in favor of removing the check_mem_overlap argument from TensorIterator::nullary_op / TensorIterator::binary_op / TensorIterator::unary_op and just having those always check memory overlap. I don't see any cases among those where we wouldn't want to check memory overlap, but it's also been a while since I've looked at the code.

People creating custom tensor iterators could still forget to call set_check_mem_overlap though.

Something to consider is that we could just default to checking memory overlap in custom tensor iterators but selectively disable it for the cases we know won't work. It sounds like reductions are one such case?

@peterbell10 peterbell10 force-pushed the mem-overlap branch 2 times, most recently from ef62811 to 6d6f832 on June 17, 2020 at 20:32
@peterbell10
Collaborator Author

I've rebased on the latest master, changed the default in TensorIteratorConfig to always check for memory overlap and removed the check_mem_overlap parameter from the n-ary_op constructors.

@ezyang
Contributor

ezyang commented Jun 18, 2020

Thanks!

@zou3519 zou3519 left a comment (Contributor)

(not a full review)

Contributor

Why did this have to change?

Collaborator Author

This was a bug exposed by the new checks in masked_fill. It first broadcasts self.logits with value, then writes to the broadcasted (i.e. zero-strided) array in a way that overwrites entire rows instead of only where value == 0.

Contributor

Why did this have to change?

Collaborator Author

The next line writes into a broadcasted array:
expected[expected == float('inf')] = 0.

Contributor

This seems BC-breaking. Even if expected is a broadcasted array, expected[expected == float('inf')] = 0. produces a correct result because it is setting the right elements to 0

Collaborator Author

Note that this is a masked_fill_, which is exactly the operation in the original issue. To preserve correct uses like

expected[expected == float('inf')] = 0.

we could inspect the mask along the broadcasted dimensions and only allow the fill if the mask is constant along those dimensions. Maybe that behavior should be deprecated as well, though.

@zou3519
Contributor

zou3519 commented Jun 18, 2020

Hmm, there are some operations where checking for memory overlap isn't necessary and would cause backwards incompatibility. For example, checking memory overlap for something like torch.tensor(0.4).expand(3).cos_() is good, because there's a chance that the result is cos(cos(cos(0.4))).

However, we have cases like torch.tensor(0.4).expand(3).zero_() where the result of the operation is unambiguous, and checking would lead to backwards incompatibility. Maybe that's OK. I am trying to brainstorm whether there are other cases like this (the expected[expected == float('inf')] = 0 example seems like one of them).

EDIT: another way to solve the problem is to follow numpy and make broadcasted arrays read-only. That's a pretty big change, though, and I am not sure what it would entail.
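A sketch contrasting the two cases (the wrong value in the first case depends on how the kernel happens to be scheduled, so it is only a possibility):

import torch

x = torch.tensor(0.4).expand(3)   # three views of a single storage element
x.cos_()                          # each lane reads and writes that same element, so the stored
                                  # value may end up as cos(cos(cos(0.4))) instead of cos(0.4)

y = torch.tensor(0.4).expand(3)
y.zero_()                         # unambiguous: the element is 0. no matter how often it is written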

@peterbell10
Collaborator Author

peterbell10 commented Jun 18, 2020

For zero_ specifically, perhaps fill with scalar could be made an exception. However, I'm not sure what you could be planning to do with it afterwards. Presumably you'd then start writing to indexed locations and then it becomes buggy again. So, I'm unsure that an exception would really be helpful.

Perhaps fill should just raise a deprecation warning instead of an error?

@zou3519
Contributor

zou3519 commented Jul 17, 2020

@peterbell10 sorry, this fell off my radar -- I'll have more time to review this starting next week and will give this a read through then

@zou3519
Contributor

zou3519 commented Jul 24, 2020

cc @peterbell10 I thought about the problem we’re addressing here and wrote up some comments. Please let me know what your thoughts are.

Also cc @mruberry (for numpy compatibility) and @VitalyFedyunin @ngimel @ezyang (for TensorIterator). I'm curious if you folks have thoughts on the internal memory overlap problem and potential solutions.

What is the issue we’re trying to solve?

As a common pitfall, users will expand tensors and then perform in-place operations on them. We’ve mitigated this problem in the past by adding internal memory overlap checks to some of the more common operations (e.g., we check if a Tensor has stride 0 and size > 1). However, this problem still exists for a lot of other operators.

My concerns on this Pull Request

Potential for BC-breaking behavior: There are some operators where if an output has internal memory overlap, or if an output and input have partial memory overlap, the result is always correct (e.g., Tensor.zero_). There are some other more interesting cases, like

expected = torch.randn(3).expand(4, 5, 3)
expected[expected == float('inf')] = 0

where the output tensor has internal overlap, but it is OK that it does, and making a change like this would be BC-breaking (in fact, there is a test in this PR that we had to change because of this). Maybe these cases don't actually matter, but I think if we go down the route of this PR, we should take a detailed survey of the operations involved that we would need to deprecate, and have a deprecation plan.

Decision to make TensorIterator check for memory overlap by default.

This isn't a complete solution to our problem because not all PyTorch operations use TensorIterator, or will use TensorIterator in the future. Secondly, we've been doing a lot of TH->ATen migrations. If we change TensorIterator to check for memory overlap by default, then TH->ATen migrations may unintentionally introduce BC-breaking changes.

Alternatives

I think if we have to go through the trouble of identifying all the BC-breaking operators and deprecating the "expand + modify-in-place" behavior for those operators, we should take some time to explore the problem space. Assuming we have to break BC, it'll be 6-9 months (2-3 releases) to finish deprecating and updating everything.

One alternative that seems nice to me is to introduce a WRITEABLE flag to tensors, a la numpy. Broadcasted tensors (obtained from e.g. torch.expand, torch.broadcast_tensors) would have WRITEABLE=False and raise an error message on in-place operations. If it is possible to bake this change into the codebase structurally, then it sounds like a better solution: instead of having to check for internal memory overlap all over the codebase, maybe we can just error out when an operator asks for the tensor's data_ptr while WRITEABLE=False.
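For concreteness, this is the numpy behavior being proposed as the template; the error fires on the attempted write itself rather than via a per-operator overlap check:

import numpy as np

y = np.broadcast_to(np.arange(3.), (4, 3))   # read-only view: y.flags['WRITEABLE'] is False
try:
    y += 1                                   # any in-place write is rejected up front
except ValueError as err:
    print(err)                               # e.g. "output array is read-only"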

@peterbell10
Collaborator Author

One alternative that seems nice to me is to introduce a WRITEABLE flag to tensors, a la numpy. [...] instead of having to check internal memory-overlap over the codebase, maybe we can just error out when an operator tries to ask for the tensor's data_ptr if WRITEABLE=False.

To go down this route I think there would need to be const and mutable variants of all accessors like data_ptr (either with different names, or say a ConstTensor class). Then each operator would need to be updated for const-correctness. Otherwise, I think you'd end up with functions that only read from a tensor but would nonetheless require the inputs to be writeable.

A huge benefit of having TensorIterator take care of it is that it already knows which tensors are inputs and which are outputs. I suppose something similar could be done with the code generation, since modified tensors are marked with !. These checks could be enabled everywhere, but at first emit only a deprecation warning and then, down the line, make it a full error. Notably, this solution would cover all native functions and not just those using TensorIterator.

@ezyang
Contributor

ezyang commented Jul 27, 2020

I'm not exactly sure what to do here. Let me describe some high level things to think about.

  1. If we break BC, things are less bad if we have reasonable ways for people to rewrite their code in the new way. expected[expected == float('inf')] = 0 is a good case to think about: if we make this stop working, is there a way for someone to recover the old behavior? (It doesn't sound like you can conveniently do so in this example; it seems like we need some sort of de-expand operation. But it's worth pointing out that this code is doing more work than it needs to, since expected == float('inf') is going to uselessly redo the inf test for the broadcasted elements.)
  2. I don't think the test suite is necessarily a good test for what people are likely to do in the wild. We have a lot of synthetic examples to exercise particular cases. So be careful about overoptimizing for the test case.
  3. Regarding Peter Bell's suggestion, @smessmer has advocated for const and mutable variants of data_ptr in the past. This was in a different context (cc @ljk53) where we need to update version counters on storages when we do mutable operations on them. It's worth noting that we actually do have the ability to identify which operators do mutations at a higher level (e.g., by just looking at schema), as mutability is annotated in schema. So you don't have to do it data_ptr style; and it might be feasible to include a check for mutation of broadcasted tensors at the same time you check whether you need to update a version counter. However, one complicating factor is that you want to error BEFORE you actually mutate anything (whereas version counters should be updated AFTER you mutate something.)
  4. When possible, opt for something simpler that requires less work when things are confusing.
  5. I would guess there aren't too many operations which work "correctly" when you pass them in broadcasted outputs; few enough that it would be fine to just special case them specifically.

Hope that helps.

@peterbell10
Collaborator Author

I've rebased and added an exception for .fill_.

expected[expected == float('inf')] = 0 is a good case to think about: if we make this stop working, is there a way for someone to recover the old behavior?

With numpy's flags system the user can set the writable flag manually, then write to the array as normal. Though, a more idiomatic way would be to use where(expected == float('inf'), 0, expected) instead of writing to the array.
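Concretely, the rewrite would look something like this (using a tensor zero, since torch.where here takes tensor arguments for both branches):

import torch

expected = torch.randn(3).expand(4, 5, 3)
# instead of the in-place  expected[expected == float('inf')] = 0.
expected = torch.where(expected == float('inf'),
                       torch.zeros((), dtype=expected.dtype),
                       expected)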

@zou3519
Contributor

zou3519 commented Aug 11, 2020

Now that #41923 is closed I am returning to review this.

So from the tests it looks like we've added memory overlap checks for the following:

random_
uniform_
cauchy_
log_normal_
exponential_
geometric_
normal_
fmod
lerp
bernoulli_/bernoulli
index_put_
masked_fill_
masked_select_
masked_scatter_
index_select_
cat
scatter_
gather_
linspace
logspace

Did we enable memory overlap checks for any other operators? I feel like switching the default on TensorIterator to check memory overlap should have affected more operators.

@peterbell10
Collaborator Author

Most TensorIterator operators were already using the check_mem_overlap=True option. In an earlier version of this PR I had gone through all TensorIterator users and added check_mem_overlap=True wherever it was missing and those test cases matched that version.

@zou3519
Contributor

zou3519 commented Aug 18, 2020

@peterbell10, sorry for the delay in review. I had a conversation with @mruberry offline about this PR. We came to the following ideas:

  • Let's not pursue alternative solutions (like adopting a WRITABLE flag) for now. An alternative solution would require much larger changes that could happen in the future, but we would still like this PR to go in because it brings value and is not in conflict with an alternative solution.
  • We should try not to make any previously correct behavior BC-breaking. So we'll want exceptions for fill_ (already included in this PR) and other functions like fill_ (I'll check if there are any).
  • Having TensorIterator check memory overlap by default (what is currently written in this PR) is good because there are more operations that will want to check memory overlap than not.

From those points, I'll review the pull request from the following perspective:

  • I'll take an audit of all of the operators touched in the PR, and check for BC-breaking behavior.

@mruberry
Collaborator

  • We should try not to make any previously correct behavior BC-breaking. So we'll want exceptions for fill_ (already included in this PR) and other functions like fill_ (I'll check if there are any).

It'd also be nice to document when functions should and shouldn't be setting the check flag.

("addcdiv", True, True, 'cuda'),
("lerp", True, True, 'cpu'),
("lerp", False, False, 'cuda')
("lerp", True, True, 'cuda')
Contributor

I did an audit of the TensorIterator::unary_op, TensorIterator::binary_op, and TensorIterator::nullary_op call sites in the code that use the default. The default before this PR was check_mem_overlap=False; after this PR it is True. I found a few that we didn't test, so it would be nice to add tests for them (to check that we've added the checks for both CPU and CUDA); a sketch of such a test follows the list below:

For TensorIterator::unary_op:

  • elu: test in-place variant:

    Tensor & elu_(Tensor & self, Scalar alpha, Scalar scale, Scalar input_scale) {
      return at::elu_out(self, self, alpha, scale, input_scale);
    }

  • hardswish: test in-place variant:

    Tensor& hardswish_out(Tensor& result, const Tensor& self) {
      auto iter = TensorIterator::unary_op(result, self);
      hardswish_stub(iter.device_type(), iter);
      return result;
    }

  • leaky_relu: test in-place variant:

    Tensor& leaky_relu_out(Tensor& result, const Tensor& self, Scalar negval) {
      auto iter = TensorIterator::unary_op(result, self);
      leaky_relu_stub(iter.device_type(), iter, negval);
      return result;
    }

    Tensor leaky_relu(const Tensor& self, Scalar negval) {
      Tensor result;
      auto iter = TensorIterator::unary_op(result, self);
      leaky_relu_stub(iter.device_type(), iter, negval);
      return iter.output();
    }

    Tensor & leaky_relu_(Tensor & self, Scalar neg_val) {
      return at::leaky_relu_out(self, self, neg_val);
    }

For TensorIterator::binary_op:

  • index_add_ (has two paths):

    auto iter = TensorIterator::binary_op(selfSlice, selfSlice, sourceSlice);

  • atan2: test out variant and in-place variant:

    Tensor& atan2_out(Tensor& result, const Tensor& self, const Tensor& other) {
      auto iter = TensorIterator::binary_op(result, self, other);
      atan2_stub(iter.device_type(), iter);
      return result;
    }

    Tensor atan2(const Tensor& self, const Tensor& other) {
      Tensor result = at::empty({0}, self.options());
      return native::atan2_out(result, self, other);
    }

    Tensor& atan2_(Tensor& self, const Tensor& other) {
      return native::atan2_out(self, self, other);
    }

  • __ilshift__:

    Tensor& __ilshift__(Tensor& self, const Tensor& other) {
      auto iter = TensorIterator::binary_op(self, self, other);
      lshift_stub(iter.device_type(), iter);
      return self;
    }

    Tensor& __ilshift__(Tensor& self, Scalar other) {
      auto wrapper = wrapped_scalar_tensor(other).toType(self.scalar_type());
      auto iter = TensorIterator::binary_op(self, self, wrapper);
      lshift_stub(iter.device_type(), iter);
      return self;
    }

  • __irshift__:

    Tensor& __irshift__(Tensor& self, const Tensor& other) {
      auto iter = TensorIterator::binary_op(self, self, other);
      rshift_stub(iter.device_type(), iter);
      return self;
    }

    Tensor& __irshift__(Tensor& self, Scalar other) {
      auto wrapper = wrapped_scalar_tensor(other).toType(self.scalar_type());
      auto iter = TensorIterator::binary_op(self, self, wrapper);
      rshift_stub(iter.device_type(), iter);
      return self;
    }

  • hypot: test out variant and in-place variant:

    Tensor& hypot_out(Tensor& result, const Tensor& self, const Tensor& other) {
      auto iter = TensorIterator::binary_op(result, self, other);
      hypot_stub(iter.device_type(), iter);
      return result;
    }

    Tensor hypot(const Tensor& self, const Tensor& other) {
      Tensor result;
      auto iter = TensorIterator::binary_op(result, self, other);
      hypot_stub(iter.device_type(), iter);
      return iter.output();
    }

    Tensor& hypot_(Tensor& self, const Tensor& other) {
      return at::hypot_out(self, self, other);
    }

  • nextafter: test out variant and in-place variant:

    Tensor& nextafter_out(Tensor& result, const Tensor& self, const Tensor& other) {
      auto iter = TensorIterator::binary_op(result, self, other);
      nextafter_stub(iter.device_type(), iter);
      return result;
    }

    Tensor nextafter(const Tensor& self, const Tensor& other) {
      Tensor result;
      auto iter = TensorIterator::binary_op(result, self, other);
      nextafter_stub(iter.device_type(), iter);
      return iter.output();
    }

    Tensor& nextafter_(Tensor& self, const Tensor& other) {
      return at::nextafter_out(self, self, other);
    }

  • threshold: test in-place variant:

    Tensor threshold(const Tensor& self, Scalar threshold, Scalar value) {
      return threshold_out(nullopt, self, threshold, value, self);
    }

    Tensor& threshold_(Tensor& self, Scalar threshold, Scalar value) {
      threshold_out(make_optional(self), self, threshold, value, self);
      return self;
    }

    Tensor& threshold_out(Tensor& result, const Tensor& self, Scalar threshold, Scalar value) {
      threshold_out(make_optional(result), self, threshold, value, self);
      return result;
    }
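A sketch of the kind of test this audit is asking for, assuming the checks from this PR are in place (the helper and the two ops are just illustrative):

import torch

def assert_overlap_checked(inplace_fn):
    # An expanded tensor has internal overlap (stride 0, size > 1); writing to it
    # in place should now raise rather than silently write through the aliases.
    t = torch.rand(1).expand(3)
    try:
        inplace_fn(t)
    except RuntimeError:
        return
    raise AssertionError("expected a memory-overlap error")

# CPU shown; the same calls should be repeated on CUDA.
assert_overlap_checked(lambda t: t.atan2_(torch.rand(3)))
assert_overlap_checked(lambda t: torch.nn.functional.leaky_relu(t, inplace=True))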

Contributor

In retrospect, because there are so many ops, I think this PR would be easier to review if we split it up into, e.g.:

  • a PR to remove TensorIterator::unary_op's check_mem_overlap (and instead default to True)
  • a PR to remove TensorIterator::binary_op's check_mem_overlap (and instead default to True)
  • ...
  • a PR that flips the default of check_mem_overlap in TensorIterator to True.

However, I'm fine with reviewing all the changes in one single PR (this PR) because the ultimate line count will probably still be manageable.

Collaborator Author

I will try to split this up into a clean stack of PRs. The longer each individual change takes to land, the more ops will be ported to TensorIterator with the old default and will need changing. So anything to make review go smoother should be worth it IMO.

 c10::optional<DimVector> static_shape_ = c10::nullopt;
 c10::optional<std::pair<ScalarType, Device>> static_dtype_and_device_ = c10::nullopt;
-bool check_mem_overlap_ = false;
+bool check_mem_overlap_ = true;
Contributor

We should add some prescriptive documentation to

TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) {
  check_mem_overlap_ = check_mem_overlap;
  return *this;
}

describing when a user should call set_check_mem_overlap(false), and noting that the default is now true.

@zou3519 zou3519 left a comment (Contributor)

(not a full review, I'm still going through the ops)

@peterbell10
Collaborator Author

Closing as this was superseded by gh-43423

@peterbell10 peterbell10 closed this Sep 2, 2020

Labels

open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

masked_fill_ (and possibly others) produces a different output than masked_fill on cpu

4 participants