36 changes: 36 additions & 0 deletions aten/src/ATen/native/Linear.cpp
@@ -0,0 +1,36 @@
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"


namespace at { namespace native {

Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const Tensor& bias) {
  AT_ASSERT(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got %lld and %lld",
            (long long)input1.dim(), (long long)input2.dim());
  for (int64_t i = 0; i < input1.dim() - 1; i++) {
    AT_ASSERT(input1.size(i) == input2.size(i),
              "bilinear(): input batch dimensions do not match at dim %lld: got %lld and %lld",
              (long long)i, (long long)input1.size(i), (long long)input2.size(i));
  }
  AT_ASSERT(input1.size(input1.dim() - 1) == weight.size(1),
            "bilinear(): input1 size does not match weight size: got %lld but expected %lld",
            (long long)input1.size(input1.dim() - 1), (long long)weight.size(1));
  AT_ASSERT(input2.size(input2.dim() - 1) == weight.size(2),
            "bilinear(): input2 size does not match weight size: got %lld but expected %lld",
            (long long)input2.size(input2.dim() - 1), (long long)weight.size(2));
  // The bias is optional (Tensor? in native_functions.yaml), so only check its
  // size when it is defined.
  AT_ASSERT(!bias.defined() || bias.size(0) == weight.size(0),
            "bilinear(): bias size does not match weight size: got %lld but expected %lld",
            (long long)bias.size(0), (long long)weight.size(0));

  // View input1 as (..., 1, 1, n1) row vectors and input2 as (..., 1, n2, 1)
  // column vectors, so that matmul broadcasts over both the batch dimensions
  // and the leading (out_features) dimension of weight.
  auto b_input1 = input1.unsqueeze(-2).unsqueeze(-2);
  auto b_input2 = input2.unsqueeze(-2).unsqueeze(-1);

  auto output = at::matmul(at::matmul(b_input1, weight), b_input2);
  output = output.squeeze(-1).squeeze(-1);
  if (bias.defined()) {
    output = output + bias;
  }
  return output;
}

}} // namespace at::native
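
The unsqueeze/matmul trick above is what makes the extra batch dimensions work: input1 becomes a stack of 1 x n1 row vectors, input2 a stack of n2 x 1 column vectors, and the 3-d weight broadcasts across both. A minimal Python sketch of the same contraction, checked against einsum on a current PyTorch build (shapes and names here are illustrative, not from the patch):

import torch

input1 = torch.randn(2, 3, 5)   # (..., n1)
input2 = torch.randn(2, 3, 6)   # (..., n2)
weight = torch.randn(8, 5, 6)   # (out, n1, n2)
bias = torch.randn(8)

# The kernel's broadcasting trick, written out in Python.
b1 = input1.unsqueeze(-2).unsqueeze(-2)            # (2, 3, 1, 1, 5)
b2 = input2.unsqueeze(-2).unsqueeze(-1)            # (2, 3, 1, 6, 1)
out = torch.matmul(torch.matmul(b1, weight), b2)   # (2, 3, 8, 1, 1)
out = out.squeeze(-1).squeeze(-1) + bias           # (2, 3, 8)

# Same contraction in one einsum: out[b,k,o] = x1[b,k,i] W[o,i,j] x2[b,k,j].
ref = torch.einsum('bki,oij,bkj->bko', input1, weight, input2) + bias
print(torch.allclose(out, ref, atol=1e-6))         # True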
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -71,6 +71,9 @@

- func: bernoulli_(Tensor self, double p=0.5, Generator* generator=nullptr) -> Tensor

- func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor
  variants: function

- func: cat(TensorList tensors, int64_t dim=0) -> Tensor
  variants: function

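In the schema above, `Tensor? bias` declares the bias as optional: `None` from Python arrives in C++ as an undefined tensor, which the `bias.defined()` checks in Linear.cpp guard against. A short sketch of both call forms through the public wrapper, using the `torch.nn.functional` entry point of a current build:

import torch
import torch.nn.functional as F

x1, x2 = torch.randn(4, 5), torch.randn(4, 6)
w, b = torch.randn(8, 5, 6), torch.randn(8)

with_bias = F.bilinear(x1, x2, w, b)   # (4, 8)
no_bias = F.bilinear(x1, x2, w)        # bias=None -> undefined tensor in C++
print(torch.allclose(with_bias - b, no_bias, atol=1e-6))  # True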
7 changes: 7 additions & 0 deletions test/test_nn.py
@@ -4145,6 +4145,13 @@ def test_bilinear(self):
        _assertGradAndGradgradChecks(self, lambda x1, x2: F.bilinear(x1, x2, module.weight, module.bias),
                                     (input1_1, input2_1))

    def test_bilinear_broadcasting(self):
        m = nn.Bilinear(5, 6, 8)
        input1 = torch.randn(2, 3, 5)
        input2 = torch.randn(2, 3, 6)
        expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8)
        self.assertEqual(expected, m(input1, input2))

    def test_conv_tbc(self):
        inp = Variable(torch.randn(9, 4, 5), requires_grad=True)
        weight = Variable(torch.randn(3, 5, 6), requires_grad=True)
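The new test validates broadcasting by flattening the leading dimensions and reshaping back. Another way to check the same property, sketched below under the same shapes, is that every position in the extra dimensions receives an independent application of the bilinear map:

import torch
import torch.nn as nn

m = nn.Bilinear(5, 6, 8)
input1 = torch.randn(2, 3, 5)
input2 = torch.randn(2, 3, 6)
out = m(input1, input2)   # (2, 3, 8)

for i in range(2):
    for j in range(3):
        row = m(input1[i, j].unsqueeze(0), input2[i, j].unsqueeze(0)).squeeze(0)
        assert torch.allclose(out[i, j], row, atol=1e-6)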
58 changes: 0 additions & 58 deletions torch/nn/_functions/linear.py

This file was deleted.

6 changes: 1 addition & 5 deletions torch/nn/functional.py
@@ -9,7 +9,6 @@
from torch._C import _infer_size, _add_docstr
from . import _functions
from .modules import utils
-from ._functions.linear import Bilinear
from ._functions.padding import ConstantPadNd
from ._functions import vision
from ._functions.thnn.fold import Col2Im, Im2Col
@@ -1001,10 +1000,7 @@ def linear(input, weight, bias=None):


def bilinear(input1, input2, weight, bias=None):
-    if bias is None:
-        return Bilinear.apply(input1, input2, weight)
-    else:
-        return Bilinear.apply(input1, input2, weight, bias)
+    return torch._C._VariableFunctions.bilinear(input1, input2, weight, bias)


def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
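With the hand-written Python autograd Function deleted, backward passes now flow through the autograd-tracked native ops (matmul, squeeze, add) that compose `bilinear`. A double-precision gradcheck sketch, in the spirit of the `_assertGradAndGradgradChecks` call in the test suite, written against a current PyTorch build:

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

x1 = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
x2 = torch.randn(3, 6, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 5, 6, dtype=torch.double, requires_grad=True)
b = torch.randn(4, dtype=torch.double, requires_grad=True)

# gradcheck compares analytic gradients against finite differences.
assert gradcheck(F.bilinear, (x1, x2, w, b))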
7 changes: 5 additions & 2 deletions torch/nn/modules/linear.py
@@ -73,8 +73,11 @@ class Bilinear(Module):
            Default: ``True``

    Shape:
-        - Input: :math:`(N, \text{in1_features})`, :math:`(N, \text{in2_features})`
-        - Output: :math:`(N, \text{out_features})`
+        - Input: :math:`(N, *, \text{in1_features})`, :math:`(N, *, \text{in2_features})`
+          where :math:`*` means any number of additional dimensions. All but the last
+          dimension of the inputs should be the same.
+        - Output: :math:`(N, *, \text{out_features})` where all but the last dimension
+          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
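
The updated Shape section can be exercised directly; a short usage sketch consistent with the docstring (dimensions chosen arbitrarily):

import torch
import torch.nn as nn

m = nn.Bilinear(in1_features=5, in2_features=6, out_features=8)
x1 = torch.randn(7, 2, 3, 5)   # any number of leading dimensions...
x2 = torch.randn(7, 2, 3, 6)   # ...matching between the two inputs
print(m(x1, x2).shape)         # torch.Size([7, 2, 3, 8])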