
Commit 551360a

Complete tests for potrs

1 parent: 3fe5df3

5 files changed: +58 -6 lines


aten/src/ATen/Declarations.cwrap

Lines changed: 1 addition & 1 deletion
@@ -2738,7 +2738,7 @@
         default: U
 ]]
 [[
-  name: _potrs_single
+  name: _th_potrs_single
   cname: potrs
   types:
     - Float

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 0 additions & 3 deletions
@@ -81,8 +81,6 @@ _(aten, _floor) \
 _(aten, _fused_dropout) \
 _(aten, _ger) \
 _(aten, _gesv_helper) \
-_(aten, _gesv_single) \
-_(aten, _getri_single) \
 _(aten, _indexCopy) \
 _(aten, _indices) \
 _(aten, _inverse_helper) \
@@ -104,7 +102,6 @@ _(aten, _pad_packed_sequence) \
 _(aten, _pdist_backward) \
 _(aten, _pdist_forward) \
 _(aten, _potrs_helper) \
-_(aten, _potrs_single) \
 _(aten, _prod) \
 _(aten, _prodall) \
 _(aten, _range) \

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 2 additions & 2 deletions
@@ -269,7 +269,7 @@ Tensor _potrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
 // Supports arbitrary batch dimensions for self and A
 Tensor potrs(const Tensor& self, const Tensor& A, bool upper) {
   if (self.dim() <= 2 && A.dim() <= 2) {
-    return at::_potrs_single(self, A, upper);
+    return at::_th_potrs_single(self, A, upper);
   }

   Tensor self_broadcasted, A_broadcasted;
@@ -281,7 +281,7 @@ Tensor& potrs_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
   AT_CHECK(self.dim() == 2 && A.dim() == 2,
            "torch.potrs() with the `out` keyword does not support batching. "
            "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
-  return at::_potrs_single_out(result, self, A, upper);
+  return at::_th_potrs_single_out(result, self, A, upper);
 }

 }} // namespace at::native

test/test_cuda.py

Lines changed: 4 additions & 0 deletions
@@ -1557,6 +1557,10 @@ def test_potrs(self):
     def test_potrs_batched(self):
         _TestTorchMixin._test_potrs_batched(self, lambda t: t.cuda())

+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_potrs_batched_dims(self):
+        _TestTorchMixin._test_potrs_batched_dims(self, lambda t: t.cuda())
+
     def test_view(self):
         _TestTorchMixin._test_view(self, lambda t: t.cuda())
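Assuming the standard test layout of this era (CUDA tests live on a TestCuda class that delegates to the shared _TestTorchMixin implementations), the new test can be invoked directly; it is skipped unless the build links against MAGMA, which supplies the batched linear-algebra kernels on the GPU:

python test/test_cuda.py TestCuda.test_potrs_batched_dims
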
test/test_torch.py

Lines changed: 51 additions & 0 deletions
@@ -5404,6 +5404,57 @@ def get_cholesky(bmat, upper):
     def test_potrs_batched(self):
         self._test_potrs_batched(self, lambda t: t)

+    @staticmethod
+    def _test_potrs_batched_dims(self, cast):
+        if not TEST_NUMPY:
+            return
+
+        from numpy.linalg import solve
+        from common_utils import random_symmetric_pd_matrix
+
+        for upper in [True, False]:
+            # TODO: This function should be replaced after batch potrf is ready
+            def get_cholesky(bmat, upper):
+                n = bmat.size(-1)
+                cholesky = torch.stack([m.cholesky(upper) for m in bmat.reshape(-1, n, n)])
+                return cholesky.reshape_as(bmat)
+
+            # test against numpy.linalg.solve
+            A = cast(random_symmetric_pd_matrix(4, 2, 1, 3))
+            b = cast(torch.randn(2, 1, 3, 4, 6))
+            L = get_cholesky(A, upper)
+            x = torch.potrs(b, L, upper=upper)
+            x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
+            self.assertEqual(x.data, cast(x_exp))
+
+            # broadcasting b
+            A = cast(random_symmetric_pd_matrix(4, 2, 1, 3))
+            b = cast(torch.randn(4, 6))
+            L = get_cholesky(A, upper)
+            x = torch.potrs(b, L, upper=upper)
+            x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
+            self.assertEqual(x.data, cast(x_exp))
+
+            # broadcasting A
+            A = cast(random_symmetric_pd_matrix(4))
+            b = cast(torch.randn(2, 1, 3, 4, 2))
+            L = get_cholesky(A, upper)
+            x = torch.potrs(b, L, upper=upper)
+            x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
+            self.assertEqual(x.data, cast(x_exp))
+
+            # broadcasting both A & b
+            A = cast(random_symmetric_pd_matrix(4, 1, 3, 1))
+            b = cast(torch.randn(2, 1, 3, 4, 5))
+            L = get_cholesky(A, upper)
+            x = torch.potrs(b, L, upper=upper)
+            x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
+            self.assertEqual(x.data, cast(x_exp))
+
+    @skipIfNoLapack
+    def test_potrs_batched_dims(self):
+        self._test_potrs_batched_dims(self, lambda t: t)
+
     @skipIfNoLapack
     def test_potri(self):
         a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
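
The test imports random_symmetric_pd_matrix from common_utils and checks torch.potrs(b, L) against numpy.linalg.solve(A, b): solving with the Cholesky factor of A must agree with solving against A directly, and NumPy's solve already broadcasts batch dimensions, which makes it a convenient reference. A hypothetical sketch of such a helper, assuming the signature random_symmetric_pd_matrix(matrix_size, *batch_dims) suggested by the call sites (the actual implementation lives in test/common_utils.py):

import torch

def random_symmetric_pd_matrix(n, *batch_dims):
    # A A^T is symmetric positive semi-definite almost surely; a small
    # diagonal shift makes it positive definite, so Cholesky succeeds.
    A = torch.randn(*(batch_dims + (n, n)))
    return A.matmul(A.transpose(-2, -1)) + torch.eye(n) * 1e-3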
