streamline stride propagation logic in TensorIterator #42922
Conversation
💊 CI failures summary and remediations (as of commit 06ae52c; more details on the Dr. CI page):
XLA failure: job pytorch_xla_linux_bionic_py3_6_clang9_test is failing.
dzhulgakov
left a comment
Epic! It would be a good idea to merge a lot of the observations from the PR description as inline comments in the right places. All this logic is way too magical and requires explanations :)
}
auto tensor_shape = invert_perm(shape_);
if (inverted) {
if (!op.tensor.defined())
nit, add {}
// can just return contiguous output
// it is faster because it avoids allocating 0 size tensor and
// resizing and restriding it
op.tensor = at::empty(tensor_shape, op.options());
so we might just delete this branch if empty_strided is faster? maybe expand the comment
There's still unpermuting of strides that needs to be done even if empty_strided is faster; depending on the perf impact we might ignore it and delete this branch.
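For readers unfamiliar with the term, here is a rough Python model of what "unpermuting of strides" means here (the helper name and the example values are illustrative only, not the actual TensorIterator code):

```python
def unpermute_strides(permuted_strides, perm):
    # perm[i] = original dimension that sits at internal position i
    # (position 0 being the fastest-varying one); map strides computed in
    # that sorted order back to the tensor's original dimension order.
    original = [0] * len(perm)
    for i, dim in enumerate(perm):
        original[dim] = permuted_strides[i]
    return original

# Example: a (4, 2, 3) view with strides (1, 12, 4) sorts its dims as
# perm = [0, 2, 1]; dense strides computed in that order are (1, 4, 12),
# and unpermuting recovers the permuted layout (1, 12, 4) for the output.
assert unpermute_strides([1, 4, 12], [0, 2, 1]) == [1, 12, 4]
```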
op.stride_bytes = compatible_stride(element_size);
// check if permutation is just an inverted order
bool inverted = true;
for (int i = 1; i <= ndim(); i++) {
very nit: write loop from 0 and <, all loops starting from 1 are suspicious :)
operands_[i].will_resize = true;
continue;
}
TORCH_CHECK(is_reduction_, "output with shape ", output.sizes(), " doesn't match the broadcast shape ",
} else {
} else if (stride0 > stride1) {
return 1;
}
so we treat equal dims as ambiguous? Shouldn't we do something like "preserve the original order", i.e. compare dim0 and dim1 themselves? Maybe that's what the insertion sort below does, but it's hard for me to mentally unravel it
This is a very good question! Added a "flies in the ointment" section to the description. When strides are equal, we don't swap the corresponding dimensions, so we tend to favor the identity (or closer-to-identity) permutation in this case.
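A concrete illustration of that tie-breaking, borrowing the size-(3, 1, 1) example from the PR description (expected strides per the rules in this PR, worth double-checking):

```python
import torch

# dims 1 and 2 both have stride 3; since equal strides never trigger a swap,
# their relative (identity) order is kept and the broadcasted binary op
# produces an output with the same (1, 3, 3) strides instead of a plain
# contiguous one.
a = torch.randn(3, 1, 1).as_strided((3, 1, 1), (1, 3, 3))
print(a.stride())                     # (1, 3, 3)
print((a + torch.randn(1)).stride())  # expected (1, 3, 3)
```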
}
}

template <int dim, MemoryFormat memory_format>
I'm impressed! A couple of questions:
}
if (stride0 == 0 || stride1 == 0) {
continue;
} else if (stride0 <= stride1) {
Should we keep stride0 <= stride1?
I think this is used to decide whether the comparison should continue in the current loop within the insertion sort. The <= logic will make the loop stop earlier, which should be more efficient.
Let's say we have dims 0 1 2 3 4 5 with strides 2 8 2 4 2 1, so perm_ will be 5 4 3 2 1 0 at the beginning.
Say we have already sorted perm_[0] to perm_[4]; now perm_ is 5 4 2 3 1 0, with strides 1 2 2 4 8 2. Then we pick up the last item, perm_[5], whose stride is 2, and compare it with all the sorted dims' strides, swapping as needed. With the <= logic the loop will break at perm_[2], while the < logic will delay the break to perm_[0] (but nothing happens from 2 to 0).
Please correct me if I am wrong.
With the <= logic the original tests were not passing: the loop was breaking too early and the full permutation was not recovered.
I see why, thanks.
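To make the < vs. <= discussion easier to follow, here is a rough Python model of the comparator and the insertion sort (simplified to a single operand with element strides; the real code compares stride_bytes across all operands and also special-cases reductions), using the example strides from this thread:

```python
def should_swap(strides, dim0, dim1):
    """1 = swap, -1 = stop scanning, 0 = ambiguous (keep scanning)."""
    s0, s1 = strides[dim0], strides[dim1]
    if s0 == 0 or s1 == 0:
        return 0    # broadcast (stride-0) dims carry no ordering information
    if s0 < s1:
        # note strict <: with <=, equal strides would also stop the scan,
        # breaking out too early, and the full permutation would not always
        # be recovered
        return -1
    if s0 > s1:
        return 1
    return 0        # equal strides: don't swap, favor the identity order

def reorder_dims(strides):
    ndim = len(strides)
    perm = list(reversed(range(ndim)))      # start from [n-1, ..., 1, 0]
    for i in range(1, ndim):                # insertion sort
        dim1 = i
        for dim0 in range(i - 1, -1, -1):
            cmp = should_swap(strides, perm[dim0], perm[dim1])
            if cmp > 0:
                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
                dim1 = dim0
            elif cmp < 0:
                break
    return perm

# dims 0..5 with strides 2 8 2 4 2 1; perm[0] is the fastest-varying dim
print(reorder_dims([2, 8, 2, 4, 2, 1]))     # [5, 4, 2, 0, 3, 1]
```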
facebook-github-bot
left a comment
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fast path benchmarks: https://www.internalfb.com/intern/aibench/details/2193418916/ - for some reason inplace ops regress ~5%, which is very weird given that the code path is literally the same. I'll do benchmarks for the slow paths.

At some point we should reorder the definitions in TensorIterators.cpp to be in the order they run in the algorithm; it would have been a lot easier to read the PR if I could read top down and know that it was chronological.

A note on which parts of the TensorIterator algorithm got modified:

I was expecting to see something in the PR description talking about "resize", but I didn't see anything anywhere. Maybe the description can be expanded to talk about the general strategy that the new algorithm takes? (Or better yet, as Dmytro suggests, put it in code.)
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
op.tensor = at::empty(shape_, op.options(), MemoryFormat::Contiguous);
op.current_dtype = op.target_dtype;
} else if (op.will_resize) {
Do we have an invariant that it will never be the case that a tensor is undefined AND will_resize is true?
Yes, will_resize can be set to true only in mark_resize_outputs, and tensor.defined() is a necessary condition to enter that branch.
op.current_dtype = op.target_dtype;
} else if (op.will_resize) {
at::native::resize_output(op.tensor, shape_);
op.tensor.as_strided_(shape_, operands_[i_defined].tensor.strides());
Trusting that these are the only three sites that need updating.
op.original_tensor.scalar_type() != op.current_dtype) {
if (op.original_tensor.sizes() != op.tensor.sizes()){
op.original_tensor.resize_as_(op.tensor).as_strided_(op.tensor.sizes(), op.tensor.strides());
}
Is this because we are lazily resizing now? I'm surprised that allocate_or_resize_outputs cannot be assumed to have handled this for you.
Yes, we used to resize before computing types, so the original output was resized to the correct size (but not necessarily the correct layout) before the intermediate output was allocated. Now, by the time we are resizing, the original output is forgotten and we resize only the intermediate.
That's a good point though, I'd better handle this in allocate_or_resize_outputs anyway.
// make sure that operand's strides are matching element size and
// dimensions permutations which are stored in _perm
op.stride_bytes = apply_perm_and_mul(op.tensor.strides(), element_size);
} else {
So sweet, so savory.
torch.randn((1, 3, 1, 1, 1), device=device).contiguous(memory_format=torch.channels_last_3d),
torch.channels_last_3d)

def test_strides_propagation(self, device):
cc @mruberry on this test
dim = x.dim()
div = x.stride(dim - 1)
for p in permutations(range(dim)):
Nice!
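For context, a standalone sketch of the property this test exercises: per this PR, elementwise ops on a dense tensor viewed under any dimension permutation should propagate the permuted strides. This is a simplification, not the actual test_strides_propagation, which also normalizes by the last-dimension stride (as the div variable above suggests):

```python
import torch
from itertools import permutations

def check_strides_propagation(x):
    # x is contiguous, so every permuted view of it is dense; elementwise ops
    # should recover the permutation rather than fall back to a plain
    # contiguous output.
    for p in permutations(range(x.dim())):
        xp = x.permute(p)
        assert xp.exp().stride() == xp.stride()
        assert (xp + xp).stride() == xp.stride()

check_strides_propagation(torch.randn(2, 3, 4))
```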
ezyang
left a comment
This is all super subtle, but it seems to make sense, and magically the code gets smaller. Nice work! Looking for more comments, but if you want to defer that to a second PR, that's fine by me too.
Slightly changed the algorithm to better propagate input strides in ambiguous cases, changed the description to reflect that, added a note in the code, and added a test with ambiguous sizes.
facebook-github-bot
left a comment
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes #41314 among other things.
This PR streamlines layout propagation logic in TensorIterator and removes almost all cases of channels-last hardcoding. The new rules and changes are as follows:

- Behavior for an undefined `output` and a defined output of the wrong (e.g. 0) size is always the same (before this PR the behavior was divergent).
- Inputs participate in sorting. Precedence is given to the first input; in case of a tie in the first input, first the corresponding dimensions are considered, and if that does not indicate that a swap is needed, strides of the same dimension in subsequent inputs are considered. See changes in `reorder_dimensions` and `compute_strides`. Note that inspecting the dimensions of the first input first allows us to better recover its permutation (and we select this behavior because it more reliably propagates channels-last strides), but in some rare cases it could result in a worse traversal order for the second tensor.

These rules are enough to recover the previously hard-coded behavior related to channels last, so all existing tests are passing.

In general, these rules produce intuitive results, and in most cases the permutation of the full-size input (in case of a broadcasted operation) will be recovered, or the permutation of the first input (in case of same-sized inputs) will be recovered, including cases with trivial (1) dimensions. As an example of the latter, the following tensor will produce output with the same strides (3, 3, 1) in binary operations with a 1d tensor. Another example is a tensor of size N1H1 that has strides `H, H, 1, 1` when contiguous and `H, 1, 1, 1` when channels-last; the output retains these strides in binary operations when another 1d tensor is broadcast against it.

Footnote: for ambiguous cases where all inputs are memory-dense and have the same physical layout that can nevertheless correspond to different permutations, such as NC11-sized physically contiguous tensors, a regular contiguous tensor is returned, and thus the permutation information of the input is lost (so an NC11 channels-last input has strides `C, 1, C, C`, but the output will have strides `C, 1, 1, 1`). This behavior is unchanged from before and consistent with numpy, but it still makes sense to change it. The blocker for doing it currently is the performance of `empty_strided`; once we make it on par with `empty` we should be able to propagate layouts in these cases too. For now, to not slow down the common contiguous case, we default to contiguous.

The table below shows how in some cases the current behavior loses permutation/stride information, whereas the new behavior propagates the permutation.
```python
a = torch.randn(2, 3, 8)[:, :, ::2].permute(2, 0, 1)
print(a.stride())
print(a.exp().stride())
print((a + a).stride())
out = torch.empty(0)
torch.add(a, a, out=out)
print(out.stride())
```

Printed strides (current behavior first, then the new behavior):

```
(6, 3, 1)
(1, 12, 4)
(6, 3, 1)
(1, 12, 4)
(1, 12, 4)
(1, 12, 4)
```

```python
a = torch.randn(3, 1, 1).as_strided((3, 1, 1), (1, 3, 3))
print(a.stride(), (a + torch.randn(1)).stride())
```

```python
a = torch.randn(2, 3, 4).permute(2, 0, 1)
print(a.stride())
print(a.exp().stride())
print((a + a).stride())
out = torch.empty(0)
torch.add(a, a, out=out)
print(out.stride())
```

Printed strides (current behavior first, then the new behavior):

```
(1, 12, 4)
(6, 3, 1)
(1, 12, 4)
(6, 3, 1)
(1, 12, 4)
(1, 12, 4)
(1, 12, 4)
(1, 12, 4)
```
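For reference, a small snippet reproducing the N1H1 example from the description above (the expected strides follow from the rules described in this PR; values worth re-verifying):

```python
import torch

# N1H1 tensor with N=2, H=5: contiguous strides are (H, H, 1, 1) and
# channels-last strides are (H, 1, 1, 1); both layouts are memory-dense.
x = torch.randn(2, 1, 5, 1)
xc = x.contiguous(memory_format=torch.channels_last)
print(x.stride())    # (5, 5, 1, 1)
print(xc.stride())   # (5, 1, 1, 1)

# Broadcasting a 1d tensor against either layout should retain its strides.
b = torch.randn(1)
print((x + b).stride())   # expected (5, 5, 1, 1)
print((xc + b).stride())  # expected (5, 1, 1, 1)
```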