Implement hstack, vstack, dstack #42799
Conversation
💊 CI failures summary and remediations

As of commit a03e866 (more details on the Dr. CI page):

🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
@mruberry PTAL
torch/_torch_docs.py
Nit: these examples are excellent, but maybe

```python
a = torch.tensor([[1], [2], [3]])
b = torch.tensor([[4], [5], [6]])
```

would be clearer?
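For reference, here is what the suggested inputs produce, sketched with NumPy, whose behavior these torch functions are intended to match:

```python
import numpy as np

# The suggested, clearer doc inputs, shown with NumPy for reference;
# torch.hstack/vstack/dstack are intended to match these results.
a = np.array([[1], [2], [3]])   # shape (3, 1)
b = np.array([[4], [5], [6]])   # shape (3, 1)

np.hstack((a, b))   # shape (3, 2): columns side by side
np.vstack((a, b))   # shape (6, 1): rows stacked
np.dstack((a, b))   # shape (3, 1, 2): stacked along a new third dim
```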
torch/_torch_docs.py
See numbering suggestion below.
torch/_torch_docs.py
See numbering suggestion above.
test/test_torch.py
These tests are good, but this test's case generation is limited to replicating the same tensor shape for each element of the input list. Here are some cases I was thinking about:

- op(t)
  - do we even support non-tuple arguments? If not, we should validate this throws a runtime error
  - if we support single tensor arguments, is `np.hstack(a)`'s and `np.dstack(a)`'s behavior correct? The behavior of `np.hstack(a)` is strange and `np.hstack(a) != np.hstack((a,))` (the same is true for `np.dstack`)
- op((a, b, c, ...))
  - validating that if they differ on an unexpected dim an error is thrown (maybe `_test_special_stacks` should take a dim argument corresponding to the op?)
  - validating that if they differ only on the expected dim the result is equivalent to NumPy
- are tensors with a size zero dim handled correctly? (if not that's OK, but let's assert it doesn't work)
- `np.hstack` has special handling of 1D tensors (as your implementation does); does `test_hstack` need a custom elaboration to test that behavior, in particular?
- validating that tensors with different shapes, but whose post-`atleast_Xd` shapes meet the criteria, work
For an example of the last bullet:

```python
a = np.array([[[1], [2], [3]]])
b = np.array((4, 5, 6))
np.dstack((a, b))
# array([[[1, 4],
#         [2, 5],
#         [3, 6]]])
```
This is a good number of cases but validating each one by hand shouldn't be too laborious, I hope.
What are your thoughts? Are there other cases I missed?
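To make the single-array bullet concrete, here is a small NumPy sketch of the `np.hstack(a)` vs `np.hstack((a,))` discrepancy mentioned above; a bare array argument is iterated over like a sequence of its rows, so the two calls differ:

```python
import numpy as np

# np.hstack(a) iterates over a as if it were a sequence of arrays
# (its rows), while np.hstack((a,)) treats a as the single input.
a = np.array([[1], [2], [3]])   # shape (3, 1)

np.hstack(a)      # rows [1], [2], [3] concatenated -> shape (3,)
np.hstack((a,))   # single-element tuple -> a itself, shape (3, 1)

np.dstack(a)      # shape (1, 1, 3): the same surprise for dstack
np.dstack((a,))   # shape (3, 1, 1)
```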
mruberry
left a comment
Overall looks excellent. A couple minor nits about the doc examples and questions about test coverage.
(force-pushed from 315a28d to 18e9013 - Compare)
@mruberry I have added tests that I think cover all of the cases. They cover:
For the last two, those are tested from 1 to 4 dimensions, so the special behavior for hstack is included with that. I have also added some autograd tests in a similar manner to the existing stack autograd test. Does this sound good, or do I need more tests?
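A hypothetical sketch of what shape-varying case generation could look like; the helper name and structure are assumptions for illustration, not the actual code in test/test_torch.py, and NumPy stands in for the op under test:

```python
import numpy as np

# Hypothetical case generator: yields pairs of shapes that agree on
# every dimension except stack_dim (0 for vstack, 1 for hstack,
# 2 for dstack), so only the expected dim differs.
def special_stack_cases(stack_dim, ndim=3, base=2):
    base_shape = [base] * ndim
    for extra in range(1, 4):
        shape = list(base_shape)
        shape[stack_dim] = base + extra  # differ only on the stacked dim
        yield [tuple(base_shape), tuple(shape)]

# e.g. for hstack (stack_dim=1) every generated pair should concatenate
cases = list(special_stack_cases(stack_dim=1))
out = np.hstack([np.zeros(s) for s in cases[0]])  # (2,2,2) + (2,3,2)
```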
```python
else:
    # Invalid dimensions, test for error
    with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match except in dimension"):
        torch_fn(torch_input)
```
Would you add an assert that NumPy also throws a runtime error in this case? You don't need to assert a string is thrown:

```python
with self.assertRaises(RuntimeError):
    np_fn(np_input)
```
mruberry
left a comment
Nice work, @muthuArivoli!
Would you just fix that one minor nit on the tests and we'll get this merged?
Let me know if you're interested in working on a new problem.
@mruberry I added the numpy error check. Is it OK that numpy throws a ValueError, while we throw a RuntimeError?

Yes, I'm interested in working on a new problem; do you have any recommendations?
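A quick NumPy check of the error-type difference discussed here (the shapes are illustrative): NumPy raises ValueError on mismatched shapes, where this PR raises RuntimeError.

```python
import numpy as np

# NumPy's error type for mismatched stack inputs is ValueError,
# not RuntimeError as in the torch implementation.
a = np.ones((2, 2))
b = np.ones((3, 3))
try:
    np.hstack((a, b))   # first dims 2 vs 3 cannot be concatenated
except ValueError:
    pass                # ValueError, per NumPy's conventions
```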
facebook-github-bot
left a comment
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Absolutely OK. Nice work.
For symmetry there are the split functions, hsplit, vsplit, and dsplit. A slightly more challenging binary function is divmod, because it returns two tensors. There are the "polynomial" functions, like polyadd and polyder, but I'm hoping someone will write all of them near-simultaneously because they have a lot of common structure. There are also unary functions, like nan_to_num, that would be very helpful. If you'd like something more exotic or especially numerically challenging, there are also functions like the Kaiser windowing function.
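For reference, a NumPy sketch of two of the suggested follow-ups; the torch versions would be expected to mirror these semantics:

```python
import numpy as np

# hsplit: the inverse of hstack, splitting along axis 1.
x = np.arange(6).reshape(2, 3)
parts = np.hsplit(x, 3)        # three pieces, each of shape (2, 1)

# divmod: a binary op that returns two arrays at once
# (floor quotient and remainder), hence "slightly more challenging".
q, r = np.divmod(np.array([7, 8]), 3)
```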
Two questions:
Excellent questions.
Related to #38349