
Conversation

@ngimel (Collaborator) commented Aug 12, 2020

Fixes #41314 among other things.
This PR streamlines layout propagation logic in TensorIterator and removes almost all cases of channels-last hardcoding. The new rules and changes are as follows:

  1. The behavior for an undefined output and for a defined output of the wrong (e.g. 0) size is always the same (before this PR the two cases diverged).
  2. In obvious cases (a unary operation on a memory-dense tensor, binary operations on memory-dense tensors with the same layout), strides are propagated (before this PR, propagation was inconsistent); see the footnote below.
  3. In all other cases, the output permutation is obtained as the inverse of the permutation that sorts the inputs by strides. Sorting is done with a comparator obeying the following rules: strides of broadcasted dimensions are set to 0, and 0 compares equal to anything. Strides of non-broadcasted dimensions (including dimensions of size 1) participate in sorting. Precedence is given to the first input; in case of a stride tie in the first input, the corresponding dimensions are considered first, and if that does not indicate that a swap is needed, strides of the same dimension in subsequent inputs are considered. See the changes in reorder_dimensions and compute_strides, and the sketch after this list. Note that inspecting the dimensions of the first input first lets us better recover its permutation (we select this behavior because it more reliably propagates channels-last strides), but in some rare cases it could result in a worse traversal order for the second tensor.
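
To make rule 3 concrete, here is a minimal pure-Python sketch of the sorting rule (an illustration only, not the actual C++ implementation in TensorIterator; output_strides is a hypothetical helper, and the dimension-size tie-break is one simplified reading of the rule above):

```python
from functools import cmp_to_key
import torch

def output_strides(out_shape, inputs):
    """Rule 3 sketch: sort dims by stride, then lay the output out
    contiguously in that dimension order."""
    ndim = len(out_shape)
    # Effective strides after broadcasting: broadcasted dims get stride 0;
    # dims of size 1 that are not broadcasted keep their stride.
    strides = []
    for t in inputs:
        s = [0] * ndim
        offset = ndim - t.dim()
        for d in range(t.dim()):
            broadcasted = t.size(d) == 1 and out_shape[offset + d] != 1
            s[offset + d] = 0 if broadcasted else t.stride(d)
        strides.append(s)

    def cmp(d0, d1):
        for k, s in enumerate(strides):  # precedence goes to the first input
            s0, s1 = s[d0], s[d1]
            if s0 == 0 or s1 == 0:
                continue  # stride 0 compares equal to anything
            if s0 != s1:
                return -1 if s0 > s1 else 1  # larger stride is "slower"
            if k == 0 and out_shape[d0] != out_shape[d1]:
                # stride tie in the first input: consider the dims themselves
                return -1 if out_shape[d0] > out_shape[d1] else 1
        return 0  # full tie: the stable sort keeps the original dim order

    order = sorted(range(ndim), key=cmp_to_key(cmp))  # slowest to fastest
    out, acc = [0] * ndim, 1
    for d in reversed(order):  # assign contiguous strides in sorted order
        out[d] = acc
        acc *= out_shape[d]
    return tuple(out)

a = torch.randn(2, 3, 8)[:, :, ::2].permute(2, 0, 1)  # strides (2, 24, 8)
print(output_strides(a.shape, [a, a]))                # (1, 12, 4)
```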

These rules are enough to recover previously hard-coded behavior related to channels last, so all existing tests are passing.
In general, these rules produce intuitive results: in most cases the permutation of the full-size input (for broadcasted operations) or of the first input (for same-sized inputs) is recovered, including cases with trivial (size-1) dimensions. As an example of the latter, the following tensor

```python
x = torch.randn(2, 1, 3).permute(1, 0, 2)
```

will produce an output with the same strides (3, 3, 1) in binary operations with a 1d tensor. Another example is a tensor of size N1H1, which has strides (H, H, 1, 1) when contiguous and (H, 1, 1, 1) when channels-last; the output retains these strides in binary operations where another 1d tensor is broadcast against it.
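
Concretely, for the first of these examples (the printed strides are the ones claimed above):

```python
import torch

x = torch.randn(2, 1, 3).permute(1, 0, 2)  # size (1, 2, 3), strides (3, 3, 1)
y = torch.randn(3)                         # 1d tensor, broadcast against x
print((x + y).stride())                    # (3, 3, 1): x's strides are retained
```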

Footnote: in ambiguous cases where all inputs are memory-dense and share the same physical layout that nevertheless can correspond to different permutations (e.g. NC11-sized physically contiguous tensors), a regular contiguous tensor is returned, and the permutation information of the input is lost (so an NC11 channels-last input has strides (C, 1, C, C), but the output will have strides (C, 1, 1, 1)). This behavior is unchanged from before and consistent with numpy, but it would still make sense to change it. The current blocker is the performance of empty_strided; once it is on par with empty, we should be able to propagate layouts in these cases too. For now, to avoid slowing down the common contiguous case, we default to contiguous.
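
A quick illustration of this ambiguous case (the strides are the ones stated in the footnote, here with C = 3):

```python
import torch

x = torch.randn(2, 3, 1, 1).contiguous(memory_format=torch.channels_last)
print(x.stride())        # (3, 1, 3, 3): channels-last strides, yet memory-dense
print(x.exp().stride())  # (3, 1, 1, 1): contiguous output, permutation info lost
```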
The table below shows how in some cases the old behavior loses permutation/stride information, whereas the new behavior propagates the permutation.

Strided tensors:

```python
a = torch.randn(2, 3, 8)[:, :, ::2].permute(2, 0, 1)
print(a.stride())
print(a.exp().stride())
print((a + a).stride())
out = torch.empty(0)
torch.add(a, a, out=out)
print(out.stride())
```

old:

```
(2, 24, 8)
(6, 3, 1)
(1, 12, 4)
(6, 3, 1)
```

new:

```
(2, 24, 8)
(1, 12, 4)
(1, 12, 4)
(1, 12, 4)
```

Memory-dense tensors:

```python
a = torch.randn(3, 1, 1).as_strided((3, 1, 1), (1, 3, 3))
print(a.stride(), (a + torch.randn(1)).stride())
a = torch.randn(2, 3, 4).permute(2, 0, 1)
print(a.stride())
print(a.exp().stride())
print((a + a).stride())
out = torch.empty(0)
torch.add(a, a, out=out)
print(out.stride())
```

old:

```
(1, 3, 3) (1, 1, 1)
(1, 12, 4)
(6, 3, 1)
(1, 12, 4)
(6, 3, 1)
```

new:

```
(1, 3, 3) (1, 3, 3)
(1, 12, 4)
(1, 12, 4)
(1, 12, 4)
(1, 12, 4)
```

@ngimel added the module: bc-breaking (Related to a BC-breaking change) label on Aug 12, 2020
@dr-ci (bot) commented Aug 12, 2020

💊 CI failures summary and remediations

As of commit 06ae52c (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

XLA failure

Job pytorch_xla_linux_bionic_py3_6_clang9_test is failing. Please create an issue with title prefixed by [PT_BREAK] in pytorch/xla and link to this PR. If you have questions, please reach out to @ailzhang / @dlibenzi / @JackCaoG.



@dzhulgakov (Collaborator) left a comment:

Epic! It would be a good idea to merge a lot of the observations from the PR description as inline comments in the right places. All this logic is way too magical and requires explanations :)

```cpp
}
auto tensor_shape = invert_perm(shape_);
if (inverted) {
  if (!op.tensor.defined())
```
Collaborator: nit, add {}

```cpp
// can just return contiguous output
// it is faster because it avoids allocating 0 size tensor and
// resizing and restriding it
op.tensor = at::empty(tensor_shape, op.options());
```
Collaborator: so we might just delete this branch if empty_strided is faster? Maybe expand the comment.

Collaborator Author (@ngimel): There's still unpermuting of strides that needs to be done even if empty_strided is faster; depending on the perf impact we might ignore that and delete this branch.

```cpp
op.stride_bytes = compatible_stride(element_size);
// check if permutation is just an inverted order
bool inverted = true;
for (int i = 1; i <= ndim(); i++) {
```
Collaborator: very nit: write the loop from 0 with <, all loops starting from 1 are suspicious :)

```cpp
operands_[i].will_resize = true;
continue;
}
TORCH_CHECK(is_reduction_, "output with shape ", output.sizes(), " doesn't match the broadcast shape ",
```


```diff
-} else {
+} else if (stride0 > stride1) {
   return 1;
 }
```
Collaborator: so we treat equal dims as ambiguous? Shouldn't we do something like "preserve the original order", i.e. compare dim0 and dim1 themselves? Maybe that's what the insertion sort below does, but it's hard for me to mentally unravel it.

Collaborator Author (@ngimel): This is a very good question! Added a "flies in the ointment" section to the description. When strides are equal, we don't swap the corresponding dimensions, so we tend to favor the identity (or closer-to-identity) permutation in this case.

@ngimel changed the title from "[WIP] streamline stride propagation logic in TensorIterator" to "streamline stride propagation logic in TensorIterator" on Aug 13, 2020
```cpp
  }
}

template <int dim, MemoryFormat memory_format>
```


@VitalyFedyunin (Contributor) commented:

I'm impressed!

A couple of questions:

  • Have you tried to benchmark this implementation?
  • Are we fine with dropping warnings for contig + c_l operations?

```cpp
}
if (stride0 == 0 || stride1 == 0) {
  continue;
} else if (stride0 <= stride1) {
```
Contributor: Should we keep stride0 <= stride1? I think this is used to decide whether the comparison should continue in the current loop within the insertion sort. The <= logic makes the loop stop earlier, which should be more efficient.

Let's say we have dims 0 1 2 3 4 5 with strides 2 8 2 4 2 1, so perm_ will be 5 4 3 2 1 0 at the beginning. Say we have sorted perm_[0] to perm_[4] already; now perm_ is 5 4 2 3 1 0, with strides 1 2 2 4 8 2. Then we pick up the last item perm_[5], whose stride is 2, and compare it with all the sorted dims' strides, swapping as needed. With the <= logic the loop breaks at perm_[2], while the < logic delays the break to perm_[0] (but nothing happens from 2 to 0).

Please correct me if I am wrong.

Collaborator Author (@ngimel): With the <= logic the original tests were not passing; the loop was breaking too early and the full permutation was not recovered.

Contributor: I see why, thanks.
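
For reference, a toy Python model of the insertion sort being discussed (control flow only, not the actual C++ in reorder_dimensions; should_swap stands in for the stride comparator described in the PR description):

```python
def insertion_sort_dims(perm, should_swap):
    # should_swap(d0, d1) returns >0 to swap, <0 when the pair is
    # definitely ordered, and 0 when the order is ambiguous (e.g. a
    # broadcasted stride of 0 compares equal to everything).
    for i in range(1, len(perm)):
        dim1 = i
        for dim0 in range(i - 1, -1, -1):
            c = should_swap(perm[dim0], perm[dim1])
            if c > 0:
                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
                dim1 = dim0
            elif c < 0:
                break  # definitely ordered: stop scanning left
            # c == 0: ambiguous, keep scanning further left; breaking here
            # (the effect of the <= comparator) stops the scan too early
            # and can lose the full permutation.
    return perm
```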

@facebook-github-bot (Contributor) left a comment:

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel (Collaborator Author) commented Aug 17, 2020

  • Have you tried to benchmark this implementation?
  • Are we fine with dropping warnings for contig + c_l operations?

Fast path benchmarks: https://www.internalfb.com/intern/aibench/details/2193418916/ - for some reason in-place ops regress ~5%, which is very weird given that the code path is literally the same. I'll do benchmarks for the slow paths.
I think it's OK to drop the warnings; we now have generic rules for layout propagation that are not specific to channels_last. Similarly, when there's an operation on tensors with different permutations, we don't warn.

@ezyang (Contributor) commented Aug 17, 2020

At some point we should reorder the definitions in TensorIterator.cpp to be in the order they run in the algorithm; it would have been a lot easier to read the PR if I could read top-down and know that it was chronological.

@ezyang (Contributor) commented Aug 17, 2020

A note on which parts of the TensorIterator algorithm got modified:

```cpp
// try fast setup output tensor, if failed, fallback to normal setup
if (!fast_set_up(config)) {   // MODIFIED
  // compute each tensor's stride after broadcasting
  compute_strides(config);    // MODIFIED
  // re-order dimensions to improve coalescing
  reorder_dimensions(config); // MODIFIED
  // allocate the output tensor if it's not provided
  allocate_outputs();         // RENAMED AND MODIFIED
  // coalesce adjacent dimensions when possible
  coalesce_dimensions();
}
```

@ezyang (Contributor) commented Aug 17, 2020

I was expecting to see something in the PR description talking about "resize", but I didn't see anything anywhere. Maybe the description can be expanded to talk about the general strategy that the new algorithm takes? (Or better yet, as Dmytro suggests, put it in code.)

```cpp
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
op.tensor = at::empty(shape_, op.options(), MemoryFormat::Contiguous);
op.current_dtype = op.target_dtype;
} else if (op.will_resize) {
```
Contributor: Do we have an invariant that it will never be the case that a tensor is undefined AND will_resize is true?

Collaborator Author (@ngimel): Yes, will_resize can be set to true only in mark_resize_outputs, and tensor.defined() is a necessary condition to enter that branch.

```cpp
op.current_dtype = op.target_dtype;
} else if (op.will_resize) {
  at::native::resize_output(op.tensor, shape_);
  op.tensor.as_strided_(shape_, operands_[i_defined].tensor.strides());
```
Contributor: Trusting that these are the only three sites that need updating.

```cpp
    op.original_tensor.scalar_type() != op.current_dtype) {
  if (op.original_tensor.sizes() != op.tensor.sizes()) {
    op.original_tensor.resize_as_(op.tensor).as_strided_(op.tensor.sizes(), op.tensor.strides());
  }
```
Contributor: Is this because we are lazily resizing now? I'm surprised that allocate_or_resize_outputs cannot be assumed to have handled this for you.

Collaborator Author (@ngimel): Yes, we used to resize before computing types, so the original output was resized to the correct size (but not necessarily the correct layout) before the intermediate output was allocated. Now, by the time we are resizing, the original output is forgotten and we resize the intermediate only.
That's a good point though; I'd better handle this in allocate_or_resize_outputs anyway.

```cpp
// make sure that operand's strides are matching element size and
// dimensions permutations which are stored in _perm
op.stride_bytes = apply_perm_and_mul(op.tensor.strides(), element_size);
} else {
```
Contributor: So sweet, so savory.
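
For readers following along, apply_perm_and_mul is the helper used in the hunk above; a rough Python equivalent of what it computes might look like this (the explicit perm argument is an assumption for illustration; in the C++ code the permutation is part of the iterator's state):

```python
def apply_perm_and_mul(strides, element_size, perm):
    # permute the tensor's strides into the iterator's dimension order
    # and convert element-count strides into byte strides
    return [strides[p] * element_size for p in perm]
```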

```python
    torch.randn((1, 3, 1, 1, 1), device=device).contiguous(memory_format=torch.channels_last_3d),
    torch.channels_last_3d)

def test_strides_propagation(self, device):
```
def test_strides_propagation(self, device):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @mruberry on this test


```python
dim = x.dim()
div = x.stride(dim - 1)
for p in permutations(range(dim)):
```
Contributor: Nice!

@ezyang (Contributor) left a comment:

This is all super subtle, but it seems to make sense, and magically the code gets smaller. Nice work! Looking for more comments, but if you want to defer that to a second PR, that's fine by me too.

@ngimel (Collaborator Author) commented Aug 19, 2020

Slightly changed the algorithm to better propagate input strides in ambiguous cases, changed the description to reflect that, added a note in the code, and added a test with ambiguous sizes.

@facebook-github-bot (Contributor) left a comment:

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel merged this pull request in c8bc298.

Labels: Merged, module: bc-breaking (Related to a BC-breaking change)

Closes: Input strides are not propagated for unary ops