16 changes: 16 additions & 0 deletions test/test_autograd.py
@@ -1584,6 +1584,22 @@ def test_cat_negdim_2(self):
lambda a, b, c, dim: torch.cat((a, b, c), dim),
True, f_args_variable, f_args_tensor)

@skipIfNoLapack
def test_trtrs(self):
    def _test_with_size(N, C):
        A = Variable(torch.rand(N, N), requires_grad=True)
        b = Variable(torch.rand(N, C), requires_grad=True)

        for upper, transpose, unitriangular in product((True, False), repeat=3):
            def func(A, b):
                return torch.trtrs(b, A, upper, transpose, unitriangular)

            gradcheck(func, [A, b])
            gradgradcheck(func, [A, b])

    _test_with_size(S, S + 1)
    _test_with_size(S, S - 1)

def test_variable_traverse(self):
    def get_out_and_unrefed_cycle():
        inp = Variable(torch.randn(10), requires_grad=True)
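For context: torch.trtrs(b, A, upper, transpose, unitriangular) solves the triangular system op(A) x = b and returns (x, A); the test above runs gradcheck and gradgradcheck over every flag combination to exercise the backward added by this PR. A minimal usage sketch follows (the sizes and the older Variable API are illustrative assumptions; later PyTorch releases replace trtrs with triangular_solve):

import torch
from torch.autograd import Variable

A = Variable(torch.rand(4, 4), requires_grad=True)   # sizes chosen only for illustration
b = Variable(torch.rand(4, 3), requires_grad=True)

# Solve A x = b using only the upper triangle of A
# (upper=True, transpose=False, unitriangular=False).
x, _ = torch.trtrs(b, A, True, False, False)

# With the derivative added in this PR, backward now reaches both inputs.
x.sum().backward()
print(A.grad.size(), b.grad.size())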
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -588,7 +588,7 @@
  self: grad.triu(diagonal)

- name: trtrs(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular)
  self: not_implemented("trtrs")
  self, A: trtrs_backward(grads[0], grads[1], self, A, res1, upper, transpose, unitriangular, grad_input_mask)

- name: trunc(Tensor self)
  self: zeros_like(grad)
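The new entry routes the gradients of both trtrs outputs (grads[0] for the solution res1, grads[1] for the returned copy of A) into trtrs_backward, which produces gradients for self (the right-hand side b) and A. Following the Giles reference cited in Functions.cpp below, for x = op(A)^{-1} b the gradients are grad_b = op(A)^{-T} grad_x (another triangular solve with the transpose flag flipped) and grad_A = -grad_b x^T (or -x grad_b^T when transpose is set), masked to the triangle that was actually read. A rough numerical check of these formulas against autograd, with illustrative sizes and one fixed flag setting assumed:

import torch
from torch.autograd import Variable, grad

upper, transpose, unitriangular = True, False, False
A = Variable(torch.rand(5, 5), requires_grad=True)
b = Variable(torch.rand(5, 2), requires_grad=True)

x, _ = torch.trtrs(b, A, upper, transpose, unitriangular)
grad_x = Variable(torch.rand(5, 2))  # arbitrary upstream gradient

# Gradients produced through the new backward.
grad_b_auto, grad_A_auto = grad([x], [b, A], grad_outputs=[grad_x])

# Hand-computed gradients following the formulas above.
grad_b_ref = torch.trtrs(grad_x, A, upper, not transpose, unitriangular)[0]
grad_A_ref = -x.mm(grad_b_ref.t()) if transpose else -grad_b_ref.mm(x.t())
grad_A_ref = grad_A_ref.triu(int(unitriangular)) if upper else grad_A_ref.tril(-int(unitriangular))

print((grad_b_auto - grad_b_ref).abs().max())  # should be ~0
print((grad_A_auto - grad_A_ref).abs().max())  # should be ~0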
32 changes: 32 additions & 0 deletions tools/autograd/templates/Functions.cpp
@@ -596,6 +596,38 @@ Tensor _det_with_svd_backward(const std::vector<torch::autograd::Variable> &grad
  return svd_term + u.mm(sigma.pow(-1).mul_(det.mul(det_grad)).diag()).mm(v.transpose(0, 1));
}

// Reference:
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Sec. 2.3.1 Matrix inverse product
std::tuple<Tensor, Tensor> trtrs_backward(
    const Tensor & grad_x, const Tensor & grad_m,
    const Tensor & b, const Tensor & a, const Tensor & x,
    const bool upper, const bool transpose, const bool unitriangular,
    std::array<bool, 2> output_mask) {
  Tensor grad_b, grad_a;
  if (grad_x.defined()) {
    // Gradient w.r.t. b: solve the same triangular system with op(A) transposed.
    grad_b = std::get<0>(grad_x.trtrs(a, upper, !transpose, unitriangular));
    if (output_mask[1]) {
      // Gradient w.r.t. A, masked to the triangle that was actually used
      // (excluding the diagonal when unitriangular is set).
      grad_a = transpose ? -x.mm(grad_b.t()) : -grad_b.mm(x.t());
      if (upper) {
        grad_a = grad_a.triu((int) unitriangular);
      } else {
        grad_a = grad_a.tril(-((int) unitriangular));
      }
    }
  }
  if (!grad_a.defined()) {
    grad_a = a.type().zeros({1}).expand_as(a);
  }
  if (!grad_b.defined()) {
    grad_b = b.type().zeros({1}).expand_as(b);
  }
  if (output_mask[1] && grad_m.defined()) {
    // grad_m is the gradient w.r.t. the second trtrs output, a copy of A,
    // so it is added to grad_a directly.
    grad_a = grad_a.add(grad_m);
  }
  return std::tuple<Tensor, Tensor>{grad_b, grad_a};
}

}

${autograd_function_definitions}
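One more note on the output_mask / grad_input_mask plumbing above: trtrs_backward only forms grad_a when the gradient for A was actually requested, so the extra matrix products are skipped when A does not require grad. A small sketch of that case (sizes assumed for illustration):

import torch
from torch.autograd import Variable

A = Variable(torch.rand(4, 4))                       # requires_grad defaults to False
b = Variable(torch.rand(4, 3), requires_grad=True)

x, _ = torch.trtrs(b, A, True, False, False)
x.sum().backward()

print(b.grad.size())  # gradient for the right-hand side is produced
print(A.grad)         # None: no gradient was computed for A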