@@ -5788,28 +5788,35 @@ def run_test(A_dims, b_dims, cast, upper, transpose, unitriangular):
57885788 def test_triangular_solve_batched_dims(self):
57895789 self._test_triangular_solve_batched_dims(self, lambda t: t)
57905790
5791- @skipIfNoLapack
5792- def test_gels(self):
5791+ @staticmethod
5792+ def _test_lstsq(self, device):
5793+ def cast_fn(tensor):
5794+ return tensor.to(device=device)
5795+
57935796 def _test_underdetermined(a, b, expectedNorm):
5797+ # underdetermined systems are not supported on the GPU
5798+ if 'cuda' in device:
5799+ return
5800+
57945801 m = a.size()[0]
57955802 n = a.size()[1]
57965803 assert(m <= n)
57975804
57985805 a_copy = a.clone()
57995806 b_copy = b.clone()
5800- res1 = torch.gels (b, a)[0]
5807+ res1 = torch.lstsq (b, a)[0]
58015808 self.assertEqual(a, a_copy, 0)
58025809 self.assertEqual(b, b_copy, 0)
58035810 self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8)
58045811
5805- ta = torch.Tensor()
5806- tb = torch.Tensor()
5807- res2 = torch.gels (b, a, out=(tb, ta))[0]
5812+ ta = cast_fn( torch.Tensor() )
5813+ tb = cast_fn( torch.Tensor() )
5814+ res2 = torch.lstsq (b, a, out=(tb, ta))[0]
58085815 self.assertEqual(a, a_copy, 0)
58095816 self.assertEqual(b, b_copy, 0)
58105817 self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8)
58115818
5812- res3 = torch.gels (b, a, out=(b, a))[0]
5819+ res3 = torch.lstsq (b, a, out=(b, a))[0]
58135820 self.assertEqual((torch.mm(a_copy, b) - b_copy).norm(), expectedNorm, 1e-8)
58145821 self.assertEqual(res1, tb, 0)
58155822 self.assertEqual(res1, b, 0)
@@ -5823,7 +5830,6 @@ def _test_overdetermined(a, b, expectedNorm):
58235830
58245831 def check_norm(a, b, expected_norm, gels_result):
58255832 # Checks |ax - b| and the residual info from the result
5826- n = a.size()[1]
58275833
58285834 # The first n rows is the least square solution.
58295835 # Rows n to m-1 contain residual information.
@@ -5836,19 +5842,19 @@ def check_norm(a, b, expected_norm, gels_result):
58365842
58375843 a_copy = a.clone()
58385844 b_copy = b.clone()
5839- res1 = torch.gels (b, a)[0]
5845+ res1 = torch.lstsq (b, a)[0]
58405846 self.assertEqual(a, a_copy, 0)
58415847 self.assertEqual(b, b_copy, 0)
58425848 check_norm(a, b, expectedNorm, res1)
58435849
5844- ta = torch.Tensor()
5845- tb = torch.Tensor()
5846- res2 = torch.gels (b, a, out=(tb, ta))[0]
5850+ ta = cast_fn( torch.Tensor() )
5851+ tb = cast_fn( torch.Tensor() )
5852+ res2 = torch.lstsq (b, a, out=(tb, ta))[0]
58475853 self.assertEqual(a, a_copy, 0)
58485854 self.assertEqual(b, b_copy, 0)
58495855 check_norm(a, b, expectedNorm, res2)
58505856
5851- res3 = torch.gels (b, a, out=(b, a))[0]
5857+ res3 = torch.lstsq (b, a, out=(b, a))[0]
58525858 check_norm(a_copy, b_copy, expectedNorm, res3)
58535859
58545860 self.assertEqual(res1, tb, 0)
@@ -5858,51 +5864,55 @@ def check_norm(a, b, expected_norm, gels_result):
58585864
58595865 # basic test
58605866 expectedNorm = 0
5861- a = torch.Tensor(((1.44, -9.96, -7.55, 8.34),
5862- (-7.84, -0.28, 3.24, 8.09),
5863- (-4.39, -3.24, 6.27, 5.28),
5864- (4.53, 3.83, -6.64, 2.06))).t()
5865- b = torch.Tensor(((8.58, 8.26, 8.48, -5.28),
5866- (9.35, -4.43, -0.70, -0.26))).t()
5867+ a = cast_fn( torch.Tensor(((1.44, -9.96, -7.55, 8.34),
5868+ (-7.84, -0.28, 3.24, 8.09),
5869+ (-4.39, -3.24, 6.27, 5.28),
5870+ (4.53, 3.83, -6.64, 2.06) ))).t()
5871+ b = cast_fn( torch.Tensor(((8.58, 8.26, 8.48, -5.28),
5872+ (9.35, -4.43, -0.70, -0.26) ))).t()
58675873 _test_underdetermined(a, b, expectedNorm)
58685874
5869- # test overderemined
5875+ # test overdetermined
58705876 expectedNorm = 17.390200628863
5871- a = torch.Tensor(((1.44, -9.96, -7.55, 8.34, 7.08, -5.45),
5872- (-7.84, -0.28, 3.24, 8.09, 2.52, -5.70),
5873- (-4.39, -3.24, 6.27, 5.28, 0.74, -1.19),
5874- (4.53, 3.83, -6.64, 2.06, -2.47, 4.70))).t()
5875- b = torch.Tensor(((8.58, 8.26, 8.48, -5.28, 5.72, 8.93),
5876- (9.35, -4.43, -0.70, -0.26, -7.36, -2.52))).t()
5877+ a = cast_fn( torch.Tensor(((1.44, -9.96, -7.55, 8.34, 7.08, -5.45),
5878+ (-7.84, -0.28, 3.24, 8.09, 2.52, -5.70),
5879+ (-4.39, -3.24, 6.27, 5.28, 0.74, -1.19),
5880+ (4.53, 3.83, -6.64, 2.06, -2.47, 4.70) ))).t()
5881+ b = cast_fn( torch.Tensor(((8.58, 8.26, 8.48, -5.28, 5.72, 8.93),
5882+ (9.35, -4.43, -0.70, -0.26, -7.36, -2.52) ))).t()
58775883 _test_overdetermined(a, b, expectedNorm)
58785884
58795885 # test underdetermined
58805886 expectedNorm = 0
5881- a = torch.Tensor(((1.44, -9.96, -7.55),
5882- (-7.84, -0.28, 3.24),
5883- (-4.39, -3.24, 6.27),
5884- (4.53, 3.83, -6.64))).t()
5885- b = torch.Tensor(((8.58, 8.26, 8.48),
5886- (9.35, -4.43, -0.70))).t()
5887+ a = cast_fn( torch.Tensor(((1.44, -9.96, -7.55),
5888+ (-7.84, -0.28, 3.24),
5889+ (-4.39, -3.24, 6.27),
5890+ (4.53, 3.83, -6.64) ))).t()
5891+ b = cast_fn( torch.Tensor(((8.58, 8.26, 8.48),
5892+ (9.35, -4.43, -0.70) ))).t()
58875893 _test_underdetermined(a, b, expectedNorm)
58885894
58895895 # test reuse
58905896 expectedNorm = 0
5891- a = torch.Tensor(((1.44, -9.96, -7.55, 8.34),
5892- (-7.84, -0.28, 3.24, 8.09),
5893- (-4.39, -3.24, 6.27, 5.28),
5894- (4.53, 3.83, -6.64, 2.06))).t()
5895- b = torch.Tensor(((8.58, 8.26, 8.48, -5.28),
5896- (9.35, -4.43, -0.70, -0.26))).t()
5897- ta = torch.Tensor()
5898- tb = torch.Tensor()
5899- torch.gels (b, a, out=(tb, ta))
5897+ a = cast_fn( torch.Tensor(((1.44, -9.96, -7.55, 8.34),
5898+ (-7.84, -0.28, 3.24, 8.09),
5899+ (-4.39, -3.24, 6.27, 5.28),
5900+ (4.53, 3.83, -6.64, 2.06) ))).t()
5901+ b = cast_fn( torch.Tensor(((8.58, 8.26, 8.48, -5.28),
5902+ (9.35, -4.43, -0.70, -0.26) ))).t()
5903+ ta = cast_fn( torch.Tensor() )
5904+ tb = cast_fn( torch.Tensor() )
5905+ torch.lstsq (b, a, out=(tb, ta))
59005906 self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
5901- torch.gels (b, a, out=(tb, ta))
5907+ torch.lstsq (b, a, out=(tb, ta))
59025908 self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
5903- torch.gels (b, a, out=(tb, ta))
5909+ torch.lstsq (b, a, out=(tb, ta))
59045910 self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
59055911
5912+ @skipIfNoLapack
5913+ def test_lstsq(self):
5914+ self._test_lstsq(self, 'cpu')
5915+
59065916 @skipIfNoLapack
59075917 def test_eig(self):
59085918 a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00),
@@ -8923,9 +8933,9 @@ def fn(torchfn, *args):
89238933 q, r = fn(torch.qr, (3, 0), False)
89248934 self.assertEqual([(3, 3), (3, 0)], [q.shape, r.shape])
89258935
8926- # gels
8927- self.assertRaises(RuntimeError, lambda: torch.gels (torch.randn(0, 0), torch.randn(0, 0)))
8928- self.assertRaises(RuntimeError, lambda: torch.gels (torch.randn(0,), torch.randn(0, 0)))
8936+ # lstsq
8937+ self.assertRaises(RuntimeError, lambda: torch.lstsq (torch.randn(0, 0), torch.randn(0, 0)))
8938+ self.assertRaises(RuntimeError, lambda: torch.lstsq (torch.randn(0,), torch.randn(0, 0)))
89298939
89308940 def test_expand(self):
89318941 tensor = torch.rand(1, 8, 1)
0 commit comments