
Commit 5d130e4

vishwakftw authored and facebook-github-bot committed
Allowing batching for det/logdet/slogdet operations (#22909)
Summary:
Changelog:
- Add batching for det / logdet / slogdet operations
- Update derivative computation to support batched inputs (and consequently batched outputs)
- Update docs

Pull Request resolved: #22909

Test Plan:
- Add a `test_det_logdet_slogdet_batched` method in `test_torch.py` to test `torch.det`, `torch.logdet` and `torch.slogdet` on batched inputs. This relies on the correctness of `torch.det` on single matrices (tested by `test_det_logdet_slogdet`). A port of this test is added to `test_cuda.py`.
- Add autograd tests for batched inputs

Differential Revision: D16580988

Pulled By: ezyang

fbshipit-source-id: b76c87212fbe621f42a847e3b809b5e60cfcdb7a
1 parent 5b66062 commit 5d130e4

8 files changed, +315 -95 lines changed

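Before the per-file diffs, a minimal usage sketch (not part of the commit; the tensor name, shapes, and values below are made up for illustration) of the batched behavior described in the changelog: det, logdet and slogdet now accept inputs of shape (*, n, n), return one result per matrix, and the batched outputs support autograd.

import torch

# Illustrative only: two batch dimensions of 3 x 3 matrices.
A = torch.randn(2, 4, 3, 3, dtype=torch.double, requires_grad=True)

det_vals = torch.det(A)             # shape (2, 4): one determinant per matrix
logdet_vals = torch.logdet(A)       # shape (2, 4): nan where det < 0, -inf where det == 0
sign, logabsdet = torch.slogdet(A)  # both of shape (2, 4)

# The updated derivative computation handles batched inputs, so reductions of
# the batched outputs can be backpropagated through.
det_vals.sum().backward()
print(A.grad.shape)  # torch.Size([2, 4, 3, 3])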

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 44 additions & 61 deletions
@@ -19,77 +19,60 @@ namespace native {
 // det(P) = \pm 1, this method returns a 3-tuple:
 //   (det(P), diag(U), info),
 // where info helps us identify singular matrices.
-static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor& self) {
-  Tensor p, lu, info;
-  std::tie(lu, p, info) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
-  int int_info = info.item<int32_t>();
-  TORCH_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info);
-  auto n = self.size(0);
-  auto num_exchanges = (at::arange(1, n + 1, p.options()) != p).nonzero().size(0);
-  if (num_exchanges % 2 == 1) {
-    return std::make_tuple(-1., lu.diag(), int_info);
-  } else {
-    return std::make_tuple(1., lu.diag(), int_info);
-  }
+static inline std::tuple<Tensor, Tensor> _lu_det_P_diag_U(const Tensor& self) {
+  Tensor pivs, lu, infos;
+  std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
+  TORCH_CHECK(infos.ge(0).all().item<uint8_t>(), "Invalid argument passed to lu");
+  auto n = self.size(-1);
+  auto num_exchanges = (at::arange(1, n + 1, pivs.options()) != pivs).sum(-1, /*keepdim=*/false, /*dtype=*/self.scalar_type()).fmod_(2);
+  return std::tuple<Tensor, Tensor>(num_exchanges.mul_(-2).add_(1),
+                                    lu.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1));
 }
 
 Tensor det(const Tensor& self) {
-  TORCH_CHECK(at::isFloatingType(self.scalar_type()) &&
-              self.dim() == 2 && self.size(0) == self.size(1),
-              "det(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
-              "of floating types");
-  double det_P;
-  Tensor diag_U;
-  int info;
-  std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
-  if (info > 0) {
-    return at::zeros({}, self.options());
-  } else {
-    return diag_U.prod().mul_(det_P);
-  }
+  squareCheckInputs(self);
+  TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Expected a floating point tensor as input");
+
+  Tensor det_P, diag_U;
+  std::tie(det_P, diag_U) = _lu_det_P_diag_U(self);
+  // complete_det is 0 when U is singular (U(i, i) = 0 for some i in [1, self.size(-1)]).
+  // The product accumulation takes care of this case, and hence no special case handling is required.
+  auto complete_det = diag_U.prod(-1).mul_(det_P);
+  return complete_det;
 }
 
 Tensor logdet(const Tensor& self) {
-  TORCH_CHECK(at::isFloatingType(self.scalar_type()) &&
-              self.dim() == 2 && self.size(0) == self.size(1),
-              "logdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
-              "of floating types");
-  double det_P;
-  Tensor diag_U;
-  int info;
-  std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
-  if (info > 0) {
-    return at::full({}, -std::numeric_limits<double>::infinity(), self.options());
-  }
-  // `det_sign` is the sign of the determinant. We work on `diag_U.sign()` for
-  // numerical stability when diag_U has a lot small values.
-  auto det_sign = diag_U.sign().prod().mul_(det_P);
-  // This synchronizes on GPU, but `_lu_det_P_diag_U_info` above already synchronizes
-  if (det_sign.item<double>() <= 0) {
-    return det_sign.log_(); // get proper nan (det<0) or -inf (det=0)
-  } else {
-    return diag_U.abs_().log_().sum();
+  squareCheckInputs(self);
+  TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Expected a floating point tensor as input");
+
+  Tensor det_P, diag_U;
+  std::tie(det_P, diag_U) = _lu_det_P_diag_U(self);
+  Tensor det_sign = diag_U.sign().prod(-1).mul_(det_P);
+
+  // If det_sign > 0, diag_U.abs_().log_().sum(-1) gives logdet (this means U is not singular).
+  // If det_sign <= 0, then we get proper nan (when det < 0, i.e., det_sign) or -inf (when det = 0, i.e., U is singular).
+  // U is singular when U(i, i) = 0 for some i in [1, self.size(-1)].
+  Tensor logdet_vals = diag_U.abs_().log_().sum(-1);
+  if (self.dim() > 2) {
+    logdet_vals.index_put_((det_sign < 0).nonzero_numpy(), at::full({}, NAN, self.options()));
+  } else if (det_sign.item<double>() < 0) {
+    logdet_vals.fill_(NAN);
   }
+  return logdet_vals;
 }
 
 std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
-  TORCH_CHECK(at::isFloatingType(self.scalar_type()) &&
-              self.dim() == 2 && self.size(0) == self.size(1),
-              "slogdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
-              "of floating types");
-  double det_P;
-  Tensor diag_U;
-  int info;
-  std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
-  if (info > 0) {
-    return std::make_tuple(at::zeros({}, self.options()),
-                           at::full({}, -std::numeric_limits<double>::infinity(), self.options()));
-  } else {
-    // `det_sign` is the sign of the determinant. We work on `diag_U.sign()` for
-    // numerical stability when diag_U has a lot small values.
-    auto det_sign = diag_U.sign().prod().mul_(det_P);
-    return std::make_tuple(det_sign, diag_U.abs_().log_().sum());
-  }
+  squareCheckInputs(self);
+  TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Expected a floating point tensor as input");
+
+  Tensor det_P, diag_U;
+  std::tie(det_P, diag_U) = _lu_det_P_diag_U(self);
+  auto det_sign = diag_U.sign().prod(-1).mul_(det_P);
+  // abslogdet_val is -inf if U is singular, in which case diag_U.abs_().log_().sum(-1) will return -inf.
+  // U is singular when U(i, i) = 0 for some i in [1, self.size(-1)].
+  // Since abslogdet_val cannot take nan, no special case handling is required.
+  auto abslogdet_val = diag_U.abs_().log_().sum(-1);
+  return std::make_tuple(det_sign, abslogdet_val);
 }
 
 Tensor pinverse(const Tensor& self, double rcond) {
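The helper above follows the standard LU route for the determinant: with A = P L U and det(L) = 1, det(A) = det(P) * prod(diag(U)), where det(P) is +1 or -1 according to the parity of the row exchanges recorded in the pivots. The following is a rough Python sketch of the same computation over a batch; it is an illustration only, not the ATen code path, the function names are made up, and it assumes torch.lu is available for the batched factorization.

import torch

def lu_det_p_diag_u(a):
    # a: (..., n, n). Batched LU with partial pivoting, mirroring _lu_det_P_diag_U above.
    lu, pivots = torch.lu(a)                    # pivots are 1-based, shape (..., n)
    n = a.size(-1)
    # Every position where pivots[i] != i + 1 records one row exchange;
    # the parity of that count gives det(P) = +1 (even) or -1 (odd).
    mismatches = torch.arange(1, n + 1, dtype=pivots.dtype) != pivots
    det_p = mismatches.sum(-1).fmod(2).to(a.dtype).mul_(-2).add_(1)
    diag_u = lu.diagonal(offset=0, dim1=-2, dim2=-1)
    return det_p, diag_u

def batched_det(a):
    det_p, diag_u = lu_det_p_diag_u(a)
    # det(A) = det(P) * prod(diag(U)); a zero on U's diagonal makes the product
    # zero, so singular matrices need no special casing.
    return diag_u.prod(-1) * det_p

a = torch.randn(3, 4, 4, dtype=torch.double)
print(torch.allclose(batched_det(a), torch.det(a)))  # expected: True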

test/common_methods_invocations.py

Lines changed: 34 additions & 0 deletions
@@ -637,6 +637,13 @@ def method_tests():
         ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', (), NO_ARGS, [skipIfNoLapack]),
         ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
          'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', (3, 3, S, S), NO_ARGS, 'batched', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', (3, 3, 1, 1), NO_ARGS, 'batched_1x1', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_symmetric_matrix(S, 3), NO_ARGS, 'batched_symmetric', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_symmetric_psd_matrix(S, 3), NO_ARGS, 'batched_symmetric_psd', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_symmetric_pd_matrix(S, 3), NO_ARGS, 'batched_symmetric_pd', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_fullrank_matrix_distinct_singular_value(S, 3, 3), NO_ARGS,
+         'batched_distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]),
         # For `logdet` and `slogdet`, the function at det=0 is not smooth.
         # We need to exclude tests with det=0 (e.g. dim2_null, rank1, rank2) and use
         # `make_nonzero_det` to make the random matrices have nonzero det. For
@@ -650,6 +657,14 @@ def method_tests():
          'symmetric_pd', (), NO_ARGS, [skipIfNoLapack]),
         ('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS,
          'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]),
+        ('logdet', lambda: make_nonzero_det(torch.randn(3, 3, S, S), 1), NO_ARGS, 'batched', (), NO_ARGS, [skipIfNoLapack]),
+        ('logdet', lambda: make_nonzero_det(torch.randn(3, 3, 1, 1), 1), NO_ARGS, 'batched_1x1', (), NO_ARGS, [skipIfNoLapack]),
+        ('logdet', lambda: make_nonzero_det(random_symmetric_matrix(S, 3), 1), NO_ARGS,
+         'batched_symmetric', (), NO_ARGS, [skipIfNoLapack]),
+        ('logdet', lambda: make_nonzero_det(random_symmetric_pd_matrix(S, 3), 1), NO_ARGS,
+         'batched_symmetric_pd', (), NO_ARGS, [skipIfNoLapack]),
+        ('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S, 3), 1, 0), NO_ARGS,
+         'batched_distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]),
         ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS,
          '1x1_pos_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
         ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), -1), NO_ARGS,
@@ -664,6 +679,16 @@ def method_tests():
          'symmetric_pd', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
         ('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
          'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+        ('slogdet', lambda: make_nonzero_det(torch.randn(3, 3, 1, 1), -1), NO_ARGS,
+         'batched_1x1_neg_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+        ('slogdet', lambda: make_nonzero_det(torch.randn(3, 3, S, S), 1), NO_ARGS,
+         'batched_pos_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+        ('slogdet', lambda: make_nonzero_det(random_symmetric_matrix(S, 3)), NO_ARGS,
+         'batched_symmetric', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+        ('slogdet', lambda: random_symmetric_pd_matrix(S, 3), NO_ARGS,
+         'batched_symmetric_pd', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+        ('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S, 3), NO_ARGS,
+         'batched_distinct_singular_values', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
         ('symeig', lambda: random_symmetric_matrix(S), (True, False), 'lower', (), NO_ARGS, [skipIfNoLapack]),
         ('symeig', lambda: random_symmetric_matrix(S), (True, True), 'upper', (), NO_ARGS, [skipIfNoLapack]),
         ('symeig', lambda: random_symmetric_matrix(M), (True, True), 'large', (), NO_ARGS, [skipIfNoLapack]),
@@ -1082,14 +1107,23 @@ def unpack_variables(args):
     'test_det_dim2_null',
     'test_det_rank1',
     'test_det_rank2',
+    'test_det_batched',
+    'test_det_batched_1x1',
+    'test_det_batched_symmetric',
+    'test_det_batched_symmetric_psd',
     # `other` expand_as(self, other) is not used in autograd.
     'test_expand_as',
     'test_logdet',
     'test_logdet_1x1',
     'test_logdet_symmetric',
+    'test_logdet_batched',
+    'test_logdet_batched_1x1',
+    'test_logdet_batched_symmetric',
     'test_slogdet_1x1_neg_det',
     'test_slogdet_neg_det',
     'test_slogdet_symmetric',
+    'test_slogdet_batched_1x1_neg_det',
+    'test_slogdet_batched_symmetric',
     'test_cdist',
 }

test/common_utils.py

Lines changed: 17 additions & 12 deletions
@@ -949,30 +949,35 @@ def random_square_matrix_of_rank(l, rank):
 
 def random_symmetric_matrix(l, *batches):
     A = torch.randn(*(batches + (l, l)))
-    for i in range(l):
-        for j in range(i):
-            A[..., i, j] = A[..., j, i]
+    A = (A + A.transpose(-2, -1)).div_(2)
     return A
 
 
-def random_symmetric_psd_matrix(l):
-    A = torch.randn(l, l)
-    return A.mm(A.transpose(0, 1))
+def random_symmetric_psd_matrix(l, *batches):
+    A = torch.randn(*(batches + (l, l)))
+    return torch.matmul(A, A.transpose(-2, -1))
 
 
 def random_symmetric_pd_matrix(l, *batches):
     A = torch.randn(*(batches + (l, l)))
-    return A.matmul(A.transpose(-2, -1)) + torch.eye(l) * 1e-5
+    return torch.matmul(A, A.transpose(-2, -1)) + torch.eye(l) * 1e-5
 
 
 def make_nonzero_det(A, sign=None, min_singular_value=0.1):
     u, s, v = A.svd()
-    s[s < min_singular_value] = min_singular_value
-    A = u.mm(torch.diag(s)).mm(v.t())
-    det = A.det().item()
+    s.clamp_(min=min_singular_value)
+    A = torch.matmul(u, torch.matmul(torch.diag_embed(s), v.transpose(-2, -1)))
+    det = A.det()
     if sign is not None:
-        if (det < 0) ^ (sign < 0):
-            A[0, :].neg_()
+        if A.dim() == 2:
+            det = det.item()
+            if (det < 0) ^ (sign < 0):
+                A[0, :].neg_()
+        else:
+            cond = ((det < 0) ^ (sign < 0)).nonzero()
+            if cond.size(0) > 0:
+                for i in range(cond.size(0)):
+                    A[list(cond[i])][0, :].neg_()
     return A
 
 
test/test_cuda.py

Lines changed: 4 additions & 0 deletions
@@ -2220,6 +2220,10 @@ def test_chain_matmul(self):
     def test_det_logdet_slogdet(self):
         _TestTorchMixin._test_det_logdet_slogdet(self, 'cuda')
 
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_det_logdet_slogdet_batched(self):
+        _TestTorchMixin._test_det_logdet_slogdet_batched(self, 'cuda')
+
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_solve(self):
         _TestTorchMixin._test_solve(self, lambda t: t.cuda())

test/test_jit.py

Lines changed: 5 additions & 0 deletions
@@ -14532,6 +14532,11 @@ def forward(self, x, y):
     'test_slogdet_pos_det',
     'test_slogdet_symmetric',
     'test_slogdet_symmetric_pd',
+    'test_slogdet_batched_1x1_neg_det',
+    'test_slogdet_batched_pos_det',
+    'test_slogdet_batched_symmetric',
+    'test_slogdet_batched_symmetric_pd',
+    'test_slogdet_batched_distinct_singular_values'
 }
 
 # known to be failing in script

test/test_torch.py

Lines changed: 53 additions & 0 deletions
@@ -6496,6 +6496,59 @@ def get_random_mat_scale(n):
     def test_det_logdet_slogdet(self):
         self._test_det_logdet_slogdet(self, 'cpu')
 
+    @staticmethod
+    def _test_det_logdet_slogdet_batched(self, device):
+        from common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix,
+                                  random_symmetric_pd_matrix, random_square_matrix_of_rank)
+
+        # mat_chars denotes matrix characteristics
+        # possible values are: sym, sym_psd, sym_pd, sing, non_sym
+        def run_test(matsize, batchdims, mat_chars):
+            num_matrices = reduce(lambda x, y: x * y, batchdims, 1)
+            list_of_matrices = []
+
+            for idx in range(num_matrices):
+                mat_type = idx % len(mat_chars)
+                if mat_chars[mat_type] == 'sym':
+                    list_of_matrices.append(random_symmetric_matrix(matsize).to(device=device))
+                elif mat_chars[mat_type] == 'sym_psd':
+                    list_of_matrices.append(random_symmetric_psd_matrix(matsize).to(device=device))
+                elif mat_chars[mat_type] == 'sym_pd':
+                    list_of_matrices.append(random_symmetric_pd_matrix(matsize).to(device=device))
+                elif mat_chars[mat_type] == 'sing':
+                    list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize // 2).to(device=device))
+                elif mat_chars[mat_type] == 'non_sing':
+                    list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize).to(device=device))
+            full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
+            # Scaling adapted from `get_random_mat_scale` in _test_det_logdet_slogdet
+            full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize)))
+
+            for fn in [torch.det, torch.logdet, torch.slogdet]:
+                expected_value = []
+                actual_value = fn(full_tensor)
+                for full_idx in product(*map(lambda x: list(range(x)), batchdims)):
+                    expected_value.append(fn(full_tensor[full_idx]))
+
+                if fn == torch.slogdet:
+                    sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims)
+                    expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims)
+                    self.assertEqual(sign_value, actual_value[0], allow_inf=True)
+                    self.assertEqual(expected_value, actual_value[1], allow_inf=True)
+                else:
+                    expected_value = torch.stack(expected_value, dim=0).reshape(batchdims)
+                    self.assertEqual(actual_value, expected_value, allow_inf=True)
+
+        for matsize, batchdims in product([3, 5], [(3,), (5, 3)]):
+            run_test(matsize, batchdims, mat_chars=['sym_pd'])
+            run_test(matsize, batchdims, mat_chars=['sing'])
+            run_test(matsize, batchdims, mat_chars=['non_sing'])
+            run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd'])
+            run_test(matsize, batchdims, mat_chars=['sing', 'non_sing'])
+
+    @skipIfNoLapack
+    def test_det_logdet_slogdet_batched(self):
+        self._test_det_logdet_slogdet_batched(self, 'cpu')
+
     @staticmethod
     def _test_fft_ifft_rfft_irfft(self, device='cpu'):
         def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
