Commit c681b03

ssnl authored and soumith committed

Add determinant function on variable; Add backward on svd (#3816)

* determinant on variable
* svd bwd

1 parent 80c8635; commit c681b03

File tree

20 files changed: +424 -20 lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ If you are working on the CUDA code, here are some useful CUDA debugging tips:
 1. `CUDA_DEBUG=1` will enable CUDA debugging symbols (-g -G). This is particularly
    helpful in debugging device code. However, it will slow down the build process,
    so use wisely.
-2. `cuda-gdb` and `cuda-memcheck` are your best CUDA debuging friends. Unlike`gdb`,
+2. `cuda-gdb` and `cuda-memcheck` are your best CUDA debugging friends. Unlike`gdb`,
    `cuda-gdb` can display actual values in a CUDA tensor (rather than all zeros).

aten/src/ATen/Declarations.cwrap

Lines changed: 1 addition & 0 deletions
@@ -3504,6 +3504,7 @@
     - Double
   backends:
     - CPU
+    - CUDA
   variants:
     - method
     - function

aten/src/ATen/native/NativeFunctions.cpp

Lines changed: 46 additions & 0 deletions
@@ -270,6 +270,52 @@ Tensor & unsqueeze_(Tensor& self, int64_t dim) {
   return self.as_strided_(std::get<0>(g), std::get<1>(g));
 }
 
+// For backward, we save svd.
+// http://www.ics.forth.gr/cvrl/publications/conferences/2000_eccv_SVD_jacobian.pdf
+// But instead of the gesvd SVD A = U(A) Sig(A) V(A)^T, which doesn't specify the
+// signs of the determinants of U and V, we consider det(A) = \prod Sig_(A), where
+// 1. A = U_(A) Sig_(A) V(A)^T
+// 2. Sig_(A) and U_(A) can differ in sign in their first row/col from their
+//    counterparts so that U_(A) * V_(A) has +1 determinant
+std::tuple<Tensor, Tensor, Tensor, Tensor> _det_with_svd(const Tensor& self) {
+  if (!at::isFloatingType(self.type().scalarType()) ||
+      self.dim() != 2 || self.size(0) != self.size(1)) {
+    std::ostringstream ss;
+    ss << "det(" << self.type() << "{" << self.sizes() << "}): expected a 2D "
+       << "square tensor of floating types";
+    throw std::runtime_error(ss.str());
+  }
+  // check symmetric
+  bool symmetric = self.equal(self.transpose(0, 1));
+
+  auto svd = self.svd(true);
+  auto sigma = std::get<1>(svd);
+  auto u = std::get<0>(svd);
+  auto v = std::get<2>(svd);
+  auto det = sigma.prod();
+  if (!symmetric) {
+    auto qr = self.geqrf();
+    auto a = std::get<0>(qr);
+    auto tau = std::get<1>(qr);
+    // non-zero values in tau represent Householder reflectors, each of which has det -1
+    int64_t num_reflectors = tau.nonzero().size(0);
+    auto qr_det = a.diag().prod();
+    if (num_reflectors % 2 == 1) {
+      qr_det = -qr_det;
+    }
+    det = qr_det;  // QR is more stable than svd, so use it anyway
+    if ((qr_det < 0).any() ^ (det < 0).any()) {  // if different sign
+      u.narrow(1, 0, 1).mul_(-1);
+      sigma.narrow(0, 0, 1).mul_(-1);
+    }
+  }
+  return std::make_tuple(det, u, sigma, v);
+}
+
+Tensor det(const Tensor& self) {
+  return std::get<0>(self._det_with_svd());
+}
+
 Tensor stack(TensorList tensors, int64_t dim) {
   if (tensors.size() == 0) {
     throw std::runtime_error("stack expects a non-empty TensorList");
aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,10 @@
 
 - func: unsqueeze_(Tensor self, int64_t dim) -> Tensor
 
+- func: _det_with_svd(Tensor self) -> (Tensor, Tensor, Tensor, Tensor)
+
+- func: det(Tensor self) -> Tensor
+
 - func: stack(TensorList tensors, int64_t dim=0) -> Tensor
   variants: function

aten/src/THC/generic/THCTensorMathMagma.cu

Lines changed: 50 additions & 0 deletions
@@ -584,6 +584,51 @@ THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, TH
 #endif
 }
 
+THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_)
+{
+#ifdef USE_MAGMA
+  THArgCheck(a_->nDimension == 2, 2, "A should be 2 dimensional");
+
+  THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_);
+  int64_t m = a->size[0];
+  int64_t n = a->size[1];
+  int64_t k = (m < n ? m : n);
+
+#ifdef MAGMA_V2
+#if defined(THC_REAL_IS_FLOAT)
+  int64_t nb = magma_get_sgeqrf_nb(m, n);
+#else
+  int64_t nb = magma_get_dgeqrf_nb(m, n);
+#endif
+#else
+#if defined(THC_REAL_IS_FLOAT)
+  int64_t nb = magma_get_sgeqrf_nb(m);
+#else
+  int64_t nb = magma_get_dgeqrf_nb(m);
+#endif
+#endif
+
+  real *rtau_data = th_magma_malloc_pinned<real>(k);
+  real *a_data = THCTensor_(data)(state, a);
+
+  int info;
+#if defined(THC_REAL_IS_FLOAT)
+  magma_sgeqrf2_gpu(m, n, a_data, m, rtau_data, &info);
+#else
+  magma_dgeqrf2_gpu(m, n, a_data, m, rtau_data, &info);
+#endif
+
+  if (info != 0)
+    THError("MAGMA geqrf2 : Argument %d : illegal value.", -info);
+
+  THCTensor_(freeCopyTo)(state, a, ra_);
+  THCTensor_(copyArray1d)(state, rtau_, rtau_data, k);
+  magma_free_pinned(rtau_data);
+#else
+  THError(NoMagma(geqrf));
+#endif
+}
+
 THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a_)
 {
 #ifdef USE_MAGMA

@@ -614,6 +659,11 @@ THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THC
   real *work_data = THCTensor_(data)(state, work);
 
   int info;
+  // We need to call two different versions of ?geqrf:
+  //   ?geqrf_gpu allows fast computation of Q via ?orgqr_gpu, but doesn't give
+  //   R properly. Note that the MAGMA documentation for this method is wrong.
+  //   http://icl.cs.utk.edu/magma/forum/viewtopic.php?f=2&t=1015&p=2800&hilit=geqrf_gpu#p2800
+  //   ?geqrf2_gpu gives correct R, but doesn't allow computation of Q via ?orgqr_gpu
 #if defined(THC_REAL_IS_FLOAT)
   magma_sgeqrf2_gpu(m, n, a_data, m, tau_data, &info);
 #else
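
As a usage sketch (not part of this diff), the newly exposed CUDA geqrf can be driven from Python like the CPU version, assuming a MAGMA-enabled build; R lives in the upper triangle of the first output, and det(Q) follows from the number of non-zero Householder scalars in tau, the same fact _det_with_svd relies on.

import torch

A = torch.randn(5, 5).cuda()     # requires a MAGMA-enabled PyTorch build
a, tau = torch.geqrf(A)          # R sits in the upper triangle of `a`,
                                 # Householder vectors are stored below the diagonal
R = a.triu()                     # the R factor of A = QR
sign_q = (-1) ** tau.nonzero().size(0)   # det(Q) = (-1)^(number of reflectors)
det_A = sign_q * R.diag().prod()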

aten/src/THC/generic/THCTensorMathMagma.h

Lines changed: 1 addition & 1 deletion
@@ -15,9 +15,9 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a);
 THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
 THC_API void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
 THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *a, THCTensor *b, const char *uplo);
+THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_);
 THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a);
 
-
 #endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
 
 #endif

docs/source/torch.rst

Lines changed: 1 addition & 0 deletions
@@ -188,6 +188,7 @@ BLAS and LAPACK Operations
 .. autofunction:: ger
 .. autofunction:: gesv
 .. autofunction:: inverse
+.. autofunction:: det
 .. autofunction:: matmul
 .. autofunction:: mm
 .. autofunction:: mv
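
A brief usage sketch tied to the new docs entry, assuming det is available both as torch.det and as a Variable method, per this commit's title; the gradient identity in the comment is Jacobi's formula, stated for reference rather than computed by the snippet.

import torch
from torch.autograd import Variable

x = Variable(torch.randn(3, 3), requires_grad=True)
d = x.det()      # 1-element Variable holding det(x)
d.backward()     # for invertible x, x.grad equals det(x) * inverse(x).t()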

test/test_autograd.py

Lines changed: 54 additions & 6 deletions
@@ -1898,6 +1898,30 @@ def _make_cov(S):
     return torch.mm(L, L.t())
 
 
+def random_square_matrix_of_rank(l, rank):
+    assert rank <= l
+    A = torch.randn(l, l)
+    u, s, v = A.svd()
+    for i in range(l):
+        if i >= rank:
+            s[i] = 0
+        elif s[i] == 0:
+            s[i] = 1
+    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
+
+
+def random_symmetric_matrix(l):
+    A = torch.randn(l, l)
+    return A.mm(A.transpose(0, 1))
+
+
+def random_fullrank_matrix_distinct_singular_value(l):
+    A = torch.randn(l, l)
+    u, _, v = A.svd()
+    s = torch.arange(1, l + 1).mul_(1.0 / (l + 1))
+    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
+
+
 class dont_convert(tuple):
     pass
 
@@ -1906,7 +1930,6 @@ class dont_convert(tuple):
 M = 10
 S = 5
 
-
 # (name, size, args...)
 method_tests = [
     ('add', (S, S, S), ((S, S, S),)),
@@ -2166,6 +2189,13 @@ class dont_convert(tuple):
     ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', [0]),
     ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]),
     ('inverse', (S, S), (), '', (), [skipIfNoLapack]),
+    ('det', (S, S), (), '', (), [skipIfNoLapack]),
+    ('det', lambda: random_symmetric_matrix(S), (), 'symmetric', (), [skipIfNoLapack]),
+    ('det', lambda: random_square_matrix_of_rank(S, S - 2), (), 'dim2_null', (), [skipIfNoLapack]),
+    ('det', lambda: random_square_matrix_of_rank(S, 1), (), 'rank1', (), [skipIfNoLapack]),
+    ('det', lambda: random_square_matrix_of_rank(S, 2), (), 'rank2', (), [skipIfNoLapack]),
+    ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), (), 'distinct_postive_s', (), [skipIfNoLapack]),
+    ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), (), '', (), [skipIfNoLapack]),
     ('gesv', (S, S), ((S, S),), '', (), [skipIfNoLapack]),
     ('potrf', _make_cov(S), (True,), '', (), [skipIfNoLapack]),
     ('eq', (S, S, S), ((S, S, S),)),
@@ -2303,6 +2333,8 @@ def maybe_non_contig(tensor):
             return Variable(maybe_non_contig(arg), requires_grad=requires_grad)
         elif isinstance(arg, Variable) and non_contiguous:
            return Variable(maybe_non_contig(arg.data), requires_grad=arg.requires_grad)
+        elif callable(arg):
+            return map_arg(arg())
         else:
             return arg
     return tuple(map_arg(arg) for arg in call_args)
@@ -2339,6 +2371,19 @@ def maybe_non_contig(tensor):
 EXCLUDE_GRADCHECK = {
     'potrf'
 }
+EXCLUDE_GRADGRADCHECK = {
+    'svd'
+}
+EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = {
+    # Some of the following det ones pass because random matrix has full rank
+    # with high probability. But we can't rely on this. So only test gradgrad on
+    # test_det_distinct_postive_s.
+    'test_det',
+    'test_det_symmetric',
+    'test_det_dim2_null',
+    'test_det_rank1',
+    'test_det_rank2'
+}
 
 
 def exclude_tensor_method(name, test_name):
@@ -2359,6 +2404,7 @@ def exclude_tensor_method(name, test_name):
         'resize_as',
         'scatter',
         'scatter_add',
+        'det',
     }
     if test_name in exclude_all_tensor_method_by_test_name:
         return True
@@ -2390,17 +2436,19 @@ def gradgradcheck_method_precision_override(test_name):
     return override
 
 
-def run_grad_and_gradgrad_checks(test_case, test_name, apply_method, output_variable, input_variables):
+def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable,
+                                 input_variables, run_gradgradcheck=True):
     test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION))
-
+    if name in EXCLUDE_GRADGRADCHECK or test_name in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME:
+        return
     grad_y = generate_gradoutput(output_variable, non_contiguous=True)
     gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name)
     if gradgradcheck_precision_override is not None:
         atol = gradgradcheck_precision_override['atol']
         rtol = gradgradcheck_precision_override['rtol']
         test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y, atol=atol, rtol=rtol))
     else:
-        test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y,))
+        test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y))
 
 
 def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
@@ -2413,7 +2461,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
     test_case.assertEqual(unpack_variables(output_variable), output_tensor)
 
     if run_grad_checks:
-        run_grad_and_gradgrad_checks(test_case, test_name, apply_fn,
+        run_grad_and_gradgrad_checks(test_case, name, test_name, apply_fn,
                                      output_variable, f_args_variable)
 
     self_variable = f_args_variable[0]
@@ -2457,7 +2505,7 @@ def check(name):
             # TODO: check that both have changed after adding all inplace ops
 
             if not is_inplace and name not in EXCLUDE_GRADCHECK:
-                run_grad_and_gradgrad_checks(self, test_name,
+                run_grad_and_gradgrad_checks(self, name, test_name,
                                              lambda *inputs: getattr(inputs[0], name)(*inputs[1:]),
                                              output_variable, (self_variable,) + args_variable)
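
For context, a standalone sketch of what the new 'det' entries exercise; it assumes the random_fullrank_matrix_distinct_singular_value helper added above and torch.autograd.gradcheck, and is not code from this diff. Distinct non-zero singular values keep the SVD-based backward well conditioned, which is why only the distinct_postive_s variant is also exercised under gradgradcheck.

import torch
from torch.autograd import Variable
from torch.autograd.gradcheck import gradcheck

A = Variable(random_fullrank_matrix_distinct_singular_value(5), requires_grad=True)
# numerically compares analytic d det(A) / dA against finite differences
assert gradcheck(lambda x: x.det(), (A,), eps=1e-6, atol=1e-4)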

test/test_cuda.py

Lines changed: 6 additions & 1 deletion
@@ -316,7 +316,8 @@ def tmp(t):
     ('qr', small_2d_lapack_fat, lambda t: [], 'fat', float_types),
     ('qr', large_2d_lapack, lambda t: [], 'big', float_types),
     ('inverse', new_t(20, 20), lambda t: [], None, float_types),
-
+    ('geqrf', new_t(20, 20), lambda t: [], None, float_types),
+    # TODO: add det to here once Variable and Tensor are the same thing
 ]
 
 # TODO: random functions, cat, gather, scatter, index*, masked*,

@@ -938,6 +939,10 @@ def test_caching_pinned_memory_multi_gpu(self):
     def _select_broadcastable_dims(dims_full=None):
         return TestTorch._select_broadcastable_dims(dims_full)
 
+    @unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
+    def test_det(self):
+        TestTorch._test_det(self, lambda t: t.cuda())
+
     def test_broadcast(self):
         TestTorch._test_broadcast(self, lambda t: t.cuda())
