Commit 1dcad08

li-roy authored and soumith committed
Support N-D tensors in Bilinear (#5764)
* support n-d inputs in bilinear and move to aten
* add asserts to bilinear inputs
* address comments
* cast int64_t in asserts
1 parent 04edb89 commit 1dcad08
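
For context, `bilinear` computes one bilinear form per output feature, :math:`y_k = x_1^\top W_k x_2 + b_k`, where the weight has shape `(out_features, in1_features, in2_features)`; this commit generalizes the leading batch dimensions from a single :math:`N` to any number of dimensions.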

File tree: 6 files changed, +52 -65 lines

aten/src/ATen/native/Linear.cpp

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"


namespace at { namespace native {

Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const Tensor& bias) {
  AT_ASSERT(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got %lld and %lld",
            (long long)input1.dim(), (long long)input2.dim());
  for (int64_t i = 0; i < input1.dim() - 1; i++) {
    AT_ASSERT(input1.size(i) == input2.size(i),
              "bilinear(): input batch dimensions do not match at dim %lld: got %lld and %lld",
              (long long)i, (long long)input1.size(i), (long long)input2.size(i));
  }
  AT_ASSERT(input1.size(input1.dim() - 1) == weight.size(1),
            "bilinear(): input1 size does not match weight size: got %lld but expected %lld",
            (long long)input1.size(input1.dim() - 1), (long long)weight.size(1));
  AT_ASSERT(input2.size(input2.dim() - 1) == weight.size(2),
            "bilinear(): input2 size does not match weight size: got %lld but expected %lld",
            (long long)input2.size(input2.dim() - 1), (long long)weight.size(2));
  // bias is optional (Tensor? in native_functions.yaml), so only check its size when it
  // is defined; requiring bias.defined() here would contradict the guard below.
  AT_ASSERT(!bias.defined() || bias.size(0) == weight.size(0),
            "bilinear(): bias size does not match weight size: got %lld but expected %lld",
            (long long)bias.size(0), (long long)weight.size(0));

  // Broadcasted batched matmul: (..., 1, 1, in1) x (out, in1, in2) -> (..., out, 1, in2),
  // then (..., out, 1, in2) x (..., 1, in2, 1) -> (..., out, 1, 1).
  auto b_input1 = input1.unsqueeze(-2).unsqueeze(-2);
  auto b_input2 = input2.unsqueeze(-2).unsqueeze(-1);

  auto output = at::matmul(at::matmul(b_input1, weight), b_input2);
  output = output.squeeze(-1).squeeze(-1);
  if (bias.defined()) {
    output = output + bias;
  }
  return output;
}

}} // namespace at::native
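
For intuition, here is a minimal Python sketch of the same shape algebra (an illustration of the broadcasting trick, not the shipped kernel; `bilinear_reference` is a hypothetical name):

import torch

def bilinear_reference(input1, input2, weight, bias=None):
    # input1: (..., in1), input2: (..., in2), weight: (out, in1, in2)
    # (..., 1, 1, in1) @ (out, in1, in2) -> (..., out, 1, in2)
    out = torch.matmul(input1.unsqueeze(-2).unsqueeze(-2), weight)
    # (..., out, 1, in2) @ (..., 1, in2, 1) -> (..., out, 1, 1)
    out = torch.matmul(out, input2.unsqueeze(-2).unsqueeze(-1))
    out = out.squeeze(-1).squeeze(-1)  # -> (..., out)
    if bias is not None:
        out = out + bias
    return out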

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
@@ -71,6 +71,9 @@
 - func: bernoulli_(Tensor self, double p=0.5, Generator* generator=nullptr) -> Tensor

+- func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor
+  variants: function

 - func: cat(TensorList tensors, int64_t dim=0) -> Tensor
   variants: function
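
Declaring the function here lets the code generator produce the `torch._C._VariableFunctions.bilinear` binding that the Python wrapper below calls. A quick smoke test of the generated dispatch (sizes are illustrative):

import torch
import torch.nn.functional as F

x1 = torch.randn(4, 5)        # (N, in1_features)
x2 = torch.randn(4, 6)        # (N, in2_features)
w = torch.randn(8, 5, 6)      # (out_features, in1_features, in2_features)
b = torch.randn(8)            # (out_features,)

y = F.bilinear(x1, x2, w, b)  # routes through the new ATen kernel
assert y.shape == (4, 8)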

test/test_nn.py

Lines changed: 7 additions & 0 deletions
@@ -4169,6 +4169,13 @@ def test_bilinear(self):
         _assertGradAndGradgradChecks(self, lambda x1, x2: F.bilinear(x1, x2, module.weight, module.bias),
                                      (input1_1, input2_1))

+    def test_bilinear_broadcasting(self):
+        m = nn.Bilinear(5, 6, 8)
+        input1 = torch.randn(2, 3, 5)
+        input2 = torch.randn(2, 3, 6)
+        expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8)
+        self.assertEqual(expected, m(input1, input2))
+
     def test_conv_tbc(self):
         inp = Variable(torch.randn(9, 4, 5), requires_grad=True)
         weight = Variable(torch.randn(3, 5, 6), requires_grad=True)

torch/nn/_functions/linear.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

torch/nn/functional.py

Lines changed: 1 addition & 5 deletions
@@ -9,7 +9,6 @@
 from torch._C import _infer_size, _add_docstr
 from . import _functions
 from .modules import utils
-from ._functions.linear import Bilinear
 from ._functions.padding import ConstantPadNd
 from ._functions import vision
 from ._functions.thnn.fold import Col2Im, Im2Col

@@ -1001,10 +1000,7 @@ def linear(input, weight, bias=None):

 def bilinear(input1, input2, weight, bias=None):
-    if bias is None:
-        return Bilinear.apply(input1, input2, weight)
-    else:
-        return Bilinear.apply(input1, input2, weight, bias)
+    return torch._C._VariableFunctions.bilinear(input1, input2, weight, bias)


 def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
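
The old autograd-Function wrapper needed two call paths because `Bilinear.apply` could not accept `None`; with bias declared as `Tensor?` in ATen, a single call covers both cases. A sketch of the bias-free path (sizes are illustrative):

import torch
import torch.nn.functional as F

x1 = torch.randn(2, 3, 5)
x2 = torch.randn(2, 3, 6)
w = torch.randn(8, 5, 6)

y = F.bilinear(x1, x2, w)   # bias=None becomes an undefined tensor in ATen
assert y.shape == (2, 3, 8)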

torch/nn/modules/linear.py

Lines changed: 5 additions & 2 deletions
@@ -73,8 +73,11 @@ class Bilinear(Module):
         Default: ``True``

     Shape:
-        - Input: :math:`(N, \text{in1_features})`, :math:`(N, \text{in2_features})`
-        - Output: :math:`(N, \text{out_features})`
+        - Input: :math:`(N, *, \text{in1_features})`, :math:`(N, *, \text{in2_features})`
+          where :math:`*` means any number of additional dimensions. All but the last
+          dimension of the inputs should be the same.
+        - Output: :math:`(N, *, \text{out_features})` where all but the last dimension
+          are the same shape as the input.

     Attributes:
         weight: the learnable weights of the module of shape
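
A short usage sketch of the updated shape contract (sizes are illustrative):

import torch
import torch.nn as nn

m = nn.Bilinear(in1_features=20, in2_features=30, out_features=40)
x1 = torch.randn(128, 10, 20)   # (N, *, in1_features)
x2 = torch.randn(128, 10, 30)   # (N, *, in2_features)
y = m(x1, x2)
print(y.shape)                  # torch.Size([128, 10, 40])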
