
Commit fc62b95

svd bwd
1 parent c903f37 commit fc62b95

File tree: 8 files changed, +123 −28 lines

aten/src/ATen/native/NativeFunctions.cpp

Lines changed: 1 addition & 1 deletion
@@ -286,7 +286,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _det_with_svd(const Tensor& self) {
   // check symmetric
   bool symmetric = self.equal(self.transpose(0, 1));
 
-  auto svd = self.svd(false);
+  auto svd = self.svd(true);
   auto sigma = std::get<1>(svd);
   auto u = std::get<0>(svd);
   auto v = std::get<2>(svd);

test/test_autograd.py

Lines changed: 25 additions & 7 deletions
@@ -1915,6 +1915,13 @@ def random_symmetric_matrix(l):
     return A.mm(A.transpose(0, 1))
 
 
+def random_fullrank_matrix_distinct_singular_value(l):
+    A = torch.randn(l, l)
+    u, _, v = A.svd()
+    s = torch.arange(1, l + 1).mul_(1.0 / (l + 1))
+    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
+
+
 class dont_convert(tuple):
     pass
 
@@ -2187,6 +2194,8 @@ class dont_convert(tuple):
     ('det', lambda: random_square_matrix_of_rank(S, S - 2), (), 'dim2_null', (), [skipIfNoLapack]),
     ('det', lambda: random_square_matrix_of_rank(S, 1), (), 'rank1', (), [skipIfNoLapack]),
     ('det', lambda: random_square_matrix_of_rank(S, 2), (), 'rank2', (), [skipIfNoLapack]),
+    ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), (), 'distinct_postive_s', (), [skipIfNoLapack]),
+    ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), (), '', (), [skipIfNoLapack]),
     ('gesv', (S, S), ((S, S),), '', (), [skipIfNoLapack]),
     ('potrf', _make_cov(S), (True,), '', (), [skipIfNoLapack]),
     ('eq', (S, S, S), ((S, S, S),)),
@@ -2363,7 +2372,17 @@ def maybe_non_contig(tensor):
     'potrf'
 }
 EXCLUDE_GRADGRADCHECK = {
-    'det'
+    'svd'
+}
+EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = {
+    # Some of the following det ones pass because random matrix has full rank
+    # with high probability. But we can't rely on this. So only test gradgrad on
+    # test_det_distinct_postive_s.
+    'test_det',
+    'test_det_symmetric',
+    'test_det_dim2_null',
+    'test_det_rank1',
+    'test_det_rank2'
 }
 
 
@@ -2417,10 +2436,10 @@ def gradgradcheck_method_precision_override(test_name):
     return override
 
 
-def run_grad_and_gradgrad_checks(test_case, test_name, apply_method, output_variable,
+def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable,
                                  input_variables, run_gradgradcheck=True):
     test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION))
-    if not run_gradgradcheck:
+    if name in EXCLUDE_GRADGRADCHECK or test_name in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME:
         return
     grad_y = generate_gradoutput(output_variable, non_contiguous=True)
     gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name)
@@ -2442,7 +2461,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
     test_case.assertEqual(unpack_variables(output_variable), output_tensor)
 
     if run_grad_checks:
-        run_grad_and_gradgrad_checks(test_case, test_name, apply_fn,
+        run_grad_and_gradgrad_checks(test_case, name, test_name, apply_fn,
                                      output_variable, f_args_variable)
 
     self_variable = f_args_variable[0]
@@ -2486,10 +2505,9 @@ def check(name):
         # TODO: check that both have changed after adding all inplace ops
 
         if not is_inplace and name not in EXCLUDE_GRADCHECK:
-            run_grad_and_gradgrad_checks(self, test_name,
+            run_grad_and_gradgrad_checks(self, name, test_name,
                                          lambda *inputs: getattr(inputs[0], name)(*inputs[1:]),
-                                         output_variable, (self_variable,) + args_variable,
-                                         name not in EXCLUDE_GRADGRADCHECK)
+                                         output_variable, (self_variable,) + args_variable)
 
         # functional interface tests
         if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
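
For readers tracing why gradgradcheck is restricted here, a quick illustrative check (not part of the commit) of what the helper added at the top of this file's diff produces: it rebuilds a matrix from random singular vectors with prescribed singular values i/(l+1), which are distinct and strictly positive, exactly the regime in which the new SVD backward is well defined.

import torch

# Illustrative sketch only, not part of the commit: reconstruct a matrix with
# prescribed singular values and confirm they come back distinct and positive.
l = 5
A = torch.randn(l, l)
u, _, v = A.svd()
s = torch.arange(1, l + 1).mul_(1.0 / (l + 1))   # 1/6, 2/6, ..., 5/6
M = u.mm(torch.diag(s)).mm(v.transpose(0, 1))
print(M.svd()[1])   # roughly [0.83, 0.67, 0.50, 0.33, 0.17], all distinct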

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
@@ -553,7 +553,7 @@
   self: sum_backward(grad, self.sizes(), dim, keepdim)
 
 - name: svd(Tensor self, bool some)
-  self: not_implemented("svd")
+  self: svd_backward(grads, self, some, res1, res2, res3)
 
 - name: symeig(Tensor self, bool eigenvectors, bool upper)
   self: not_implemented("symeig")

tools/autograd/gen_python_functions.py

Lines changed: 3 additions & 1 deletion
@@ -54,7 +54,9 @@
 # to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
 SUPPORTED_RETURN_TYPES = {
     'Tensor', 'std::tuple<Tensor,Tensor>',
-    'std::tuple<Tensor,Tensor,Tensor>', 'std::vector<Tensor>',
+    'std::tuple<Tensor,Tensor,Tensor>',
+    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
+    'std::vector<Tensor>',
     'Scalar', 'bool', 'int64_t', 'void*'
 }
 

tools/autograd/templates/Functions.cpp

Lines changed: 71 additions & 13 deletions
@@ -1,5 +1,6 @@
 #include "Functions.h"
 #include <ATen/WrapDimUtils.h>
+#include <iostream>
 
 // define constants like M_PI and C keywords for MSVC
 #ifdef _MSC_VER
@@ -502,28 +503,85 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
   }
 }
 
+// https://j-towns.github.io/papers/svd-derivative.pdf
+Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
+                    bool some, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
+  auto m = self.size(0);
+  auto n = self.size(1);
+  auto k = sigma.size(0);
+
+  Tensor u, v;
+  if (!some) {
+    // ignore the free subspace
+    u = raw_u.narrow(1, 0, k);
+    v = raw_v.narrow(1, 0, k);
+  } else {
+    u = raw_u;
+    v = raw_v;
+  }
+
+  auto gu = grads[0];
+  auto gsigma = grads[1];
+  auto gv = grads[2];
+  auto im = self.type().eye(m);
+  auto in = self.type().eye(n);
+  auto ut = u.t();
+  auto vt = v.t();
+  auto sigma_mat = sigma.diag();
+  auto sigma_mat_inv = sigma.pow(-1).diag();
+  auto sigma_expanded_sq = sigma.pow(2).expand_as(sigma_mat);
+  auto F = (sigma_expanded_sq - sigma_expanded_sq.t()).pow(-1);
+  auto& long_type = sigma.type().toScalarType(at::kLong);
+  auto diag_indices = long_type.arange(0, F.numel(), k + 1);
+  F.view({-1}).index_fill_(0, diag_indices, 0);
+
+  Tensor u_term, sigma_term, v_term;
+
+  if (gu.defined()) {
+    u_term = u.mm(F.mul(ut.mm(gu) - gu.t().mm(u))).mm(sigma_mat);
+    if (m > k) {
+      u_term = u_term + (im - u.mm(ut)).mm(gu).mm(sigma_mat_inv);
+    }
+    u_term = u_term.mm(vt);
+  } else {
+    u_term = self.type().zeros({1}).expand_as(self);
+  }
+
+  if (gsigma.defined()) {
+    sigma_term = u.mm(gsigma.diag()).mm(vt);
+  } else {
+    sigma_term = self.type().zeros({1}).expand_as(self);
+  }
+
+  if (gv.defined()) {
+    auto gvt = gv.t();
+    v_term = sigma_mat.mm(F.mul(vt.mm(gv) - gvt.mm(v))).mm(vt);
+    if (n > k) {
+      v_term = v_term + sigma_mat_inv.mm(gvt.mm(in - v.mm(vt)));
+    }
+    v_term = u.mm(v_term);
+  } else {
+    v_term = self.type().zeros({1}).expand_as(self);
+  }
+
+  return u_term + sigma_term + v_term;
+}
+
 // Formula:
 // d det / d A_ij = \sum_k (\prod_{l neq k} Sigma_l) U_ik V_jk
 // that is, if det != 0
 // d det / d A = U * (Sigma / det) * V^T
 Tensor _det_with_svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                               const Tensor& det, const Tensor& u, const Tensor& sigma, const Tensor& v) {
+  std::vector<torch::autograd::Variable> svd_grads(grads.begin() + 1, grads.end());
+  auto svd_term = svd_backward(svd_grads, self, true, u, sigma, v);
+
   auto det_grad = grads[0];
-  // If any gradient is defined on svd, then it must be in a double backward
-  // because the svd results are not exposed to users. That is, it can only come
-  // from auto-differentiating this method:
-  //   dA = _det_with_svd_backward(d det, A, [det, u, s, v]=_det_with_svd(A)),
-  //   getting ddu, dds, ddv, and calling this method again to accumulate ddA.
-  for (size_t i = 1; i < 4; i++) {
-    if (grads[i].defined()) {
-      throw std::runtime_error("Double backward through det is not supported.");
-    }
-  }
   auto size = self.size(0);
   auto null_dim = size - sigma.nonzero().size(0);
   if (null_dim >= 2) {
     // \prod_{l neq k} Sigma_l is zero every where
-    return zeros_like(self);
+    return svd_term;
   }
   if (null_dim == 1) {
     // only last sigma is 0
@@ -532,10 +590,10 @@ Tensor _det_with_svd_backward(const std::vector<torch::autograd::Variable> &grads,
     auto scale = sigma.narrow(0, 0, size - 1).prod();
     auto last_u = u.narrow(1, size - 1, 1);
     auto last_v = v.narrow(1, size - 1, 1);
-    return last_u.mm(last_v.transpose(0, 1)).mul_(scale.mul_(det_grad));
+    return svd_term + last_u.mm(last_v.transpose(0, 1)).mul_(scale.mul_(det_grad));
   }
   // no zero singular values
-  return u.mm(sigma.pow(-1).mul_(det.mul(det_grad)).diag()).mm(v.transpose(0, 1));
+  return svd_term + u.mm(sigma.pow(-1).mul_(det.mul(det_grad)).diag()).mm(v.transpose(0, 1));
 }
 
 }
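
For reference, the quantity svd_backward assembles above can be written as a single expression. The following is a sketch in the notation of Townsend's note linked in the code, assuming a thin SVD A = U Σ Vᵀ with k = min(m, n) distinct, nonzero singular values, and writing Ū, Σ̄, V̄ for the incoming gradients gu, gsigma, gv:

% Sketch of the gradient assembled by svd_backward (assumptions as stated above).
F_{ij} =
  \begin{cases}
    \dfrac{1}{\sigma_j^2 - \sigma_i^2} & i \neq j \\
    0 & i = j
  \end{cases}

\bar{A} = U \Big[ \big(F \circ (U^\top \bar{U} - \bar{U}^\top U)\big)\,\Sigma
                + \operatorname{diag}(\bar{\Sigma})
                + \Sigma\,\big(F \circ (V^\top \bar{V} - \bar{V}^\top V)\big) \Big] V^\top
        + (I_m - U U^\top)\,\bar{U}\,\Sigma^{-1} V^\top
        + U\,\Sigma^{-1} \bar{V}^\top (I_n - V V^\top)

The two trailing terms correspond to the m > k and n > k branches in the code and vanish for square inputs; undefined incoming gradients are treated as zero contributions.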

tools/jit/templates/aten_dispatch.cpp

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ void pack_list(list_of_retainable & outputs, std::tuple<Tensor, Tensor, Tensor> v) {
   outputs.push_back(toRetainableSteal(std::move(std::get<1>(v))));
   outputs.push_back(toRetainableSteal(std::move(std::get<2>(v))));
 }
-void pack_list(std::vector<Tensor> & outputs, std::tuple<Tensor, Tensor, Tensor, Tensor> v) {
+void pack_list(list_of_retainable & outputs, std::tuple<Tensor, Tensor, Tensor, Tensor> v) {
   outputs.push_back(toRetainableSteal(std::move(std::get<0>(v))));
   outputs.push_back(toRetainableSteal(std::move(std::get<1>(v))));
   outputs.push_back(toRetainableSteal(std::move(std::get<2>(v))));

torch/_torch_docs.py

Lines changed: 14 additions & 3 deletions
@@ -4276,18 +4276,29 @@
 `U, S, V = torch.svd(A)` returns the singular value decomposition of a
 real matrix `A` of size `(n x m)` such that :math:`A = USV'*`.
 
-`U` is of shape `n x n`
+`U` is of shape `n x min(n, m)`
 
-`S` is of shape `n x m`
+`S` is a diagonal square matrix of shape `min(n, m) x min(n, m)`, represented as
+a vector of shape `(min(n, m),)` containing its diagonal entries.
 
-`V` is of shape `m x m`.
+`V` is of shape `m x min(n, m)`.
 
 :attr:`some` represents the number of singular values to be computed.
 If `some=True`, it computes some and `some=False` computes all.
 
 .. note:: Irrespective of the original strides, the returned matrix `U`
           will be transposed, i.e. with strides `(1, n)` instead of `(n, 1)`.
 
+.. note:: Extra care needs to be taken when backward through `U` and `V`
+          outputs. Such operation is really only stable when :attr:`input` is
+          full rank with all distinct singular values. Otherwise, `NaN` can
+          appear as the gradients are not properly defined. Also, when
+          :attr:`some` = `False`, the gradients on `U[:, min(n, m):]` and
+          `V[:, min(n, m):]` will be ignored as those vectors can be arbitrary
+          bases of the subspaces.
+
+.. note:: Double backward through :meth:`~torch.svd` is not supported currently.
+
 Args:
     input (Tensor): the input 2D Tensor
     some (bool, optional): controls the number of singular values to be computed
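
A minimal usage sketch of the behavior the updated docstring describes, assuming this patched build (Variable-based autograd with a differentiable svd method); the shapes and the loss are illustrative only:

import torch
from torch.autograd import Variable

# Illustrative sketch only: backward through svd in the regime the note calls
# stable, i.e. a full-rank input with distinct singular values.
A = Variable(torch.randn(5, 3), requires_grad=True)
u, s, v = A.svd()                         # u: 5 x 3, s: 3, v: 3 x 3 (some=True default)
(u.sum() + s.sum() + v.sum()).backward()
print(A.grad.size())                      # torch.Size([5, 3])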

torch/functional.py

Lines changed: 7 additions & 1 deletion
@@ -248,7 +248,13 @@ def maybeSqueeze(tensor):
 
 
 def det(var):
-    """Calculates determinant of a 2D square Variable
+    """Calculates determinant of a 2D square Variable.
+
+    .. note::
+        Backward through `det` internally uses SVD results. So double backward
+        through `det` will need to backward through :meth:`~Tensor.svd`. This
+        can be unstable in certain cases. Please see :meth:`~torch.svd` for
+        details.
 
     Arguments:
         var (Variable): The input 2D square Variable.
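
To make the note concrete, here is a hedged sketch of double backward through det, assuming torch.det and torch.autograd.grad with create_graph are available in this build, and using a full-rank input with distinct singular values (the regime the new test case exercises):

import torch
from torch.autograd import Variable, grad

# Illustrative sketch only: the second backward goes through svd_backward,
# so the input should be full rank with distinct singular values.
A = Variable(torch.randn(4, 4), requires_grad=True)
d = torch.det(A)
dA, = grad(d, A, create_graph=True)   # first derivative, d det / d A
ddA, = grad(dA.sum(), A)              # double backward via the SVD results
print(ddA.size())                     # torch.Size([4, 4])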
