-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Support N-D tensors in Bilinear #5764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@pytorchbot retest this please |
|
@pytorchbot retest this please |
| namespace at { namespace native { | ||
|
|
||
| Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const Tensor& bias) { | ||
| auto b_input1 = input1.unsqueeze(-2).unsqueeze(-2); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| Shape: | ||
| - Input: :math:`(N, \text{in1_features})`, :math:`(N, \text{in2_features})` | ||
| - Output: :math:`(N, \text{out_features})` | ||
| where :math:`*` means any number of additional dimensions. All but the last |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Linear.cpp
Outdated
| auto b_input2 = input2.unsqueeze(-2).unsqueeze(-1); | ||
|
|
||
| auto output = at::matmul(at::matmul(b_input1, weight), b_input2); | ||
| output = output.squeeze(-1).squeeze(-2).sum(-1); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Linear.cpp
Outdated
|
|
||
| Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const Tensor& bias) { | ||
| if (input1.dim() != input2.dim()) { | ||
| throw std::runtime_error("Inputs should have the same number of dimensions"); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Linear.cpp
Outdated
| } | ||
| for (int64_t i = 0; i < input1.dim() - 1; i++) { | ||
| if (input1.size(i) != input2.size(i)) { | ||
| throw std::runtime_error("Batch dimensions of inputs do not match"); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Linear.cpp
Outdated
| if (input1.dim() != input2.dim()) { | ||
| throw std::runtime_error("Inputs should have the same number of dimensions"); | ||
| } | ||
| AT_ASSERT(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got %d and %d", |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Linear.cpp
Outdated
| throw std::runtime_error("Bias sizes does not match weight size"); | ||
| AT_ASSERT(input1.size(i) == input2.size(i), | ||
| "bilinear(): input batch dimensions do not match at dim %d: got %d and %d", | ||
| i, input1.size(i), input2.size(i)); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@pytorchbot retest this please |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🥇 🥇 🥇 🥇 🥇 🥇 🥇 🥇 🥇 🥇🥇 🥇 🥇
🥇 🥇 🥇 🥇 🥇 🥇
🥇 🥇 🥇 🥇 🥇 🥇 🥇 🥇
🥇 🥇 🥇 🥇 🥇 🥇 🥇
🥇 🥇 🥇 🥇 🥇 🥇 🥇 🥇 🥇 🥇 🥇
|
I like richard's enthusiasm :D |
Closes #5601.