Skip to content

Commit 179807a

Browse files
ssnlfacebook-github-bot
authored andcommitted
Fix MAGMA svd and eig (#9082)
Summary: Fixes #9079 There is room for speed-up for both functions (see #9083), but let's get this in to unblock #9052 . Closes #9082 Reviewed By: ezyang Differential Revision: D8711687 Pulled By: SsnL fbshipit-source-id: f043a9bf55cb6aec5126c3331d35761f7aa3f8e3
1 parent 474fdd7 commit 179807a

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

aten/src/THC/generic/THCTensorMathMagma.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ static void THCTensor_(copyTensor2d)(THCState *state, real *dst, THCTensor *self
3030
size_t len = THCTensor_(nElement)(state, self)*sizeof(real);
3131
THCTensor *temp = THCTensor_(newTranspose)(state, self, 0, 1);
3232
THCTensor *selfc = THCTensor_(newContiguous)(state, temp);
33-
THCudaCheck(cudaMemcpy(dst, THCStorage_(data)(state, self->storage) + selfc->storageOffset, len, cudaMemcpyDeviceToHost));
33+
THCudaCheck(cudaMemcpy(dst, THCStorage_(data)(state, selfc->storage) + selfc->storageOffset, len, cudaMemcpyDeviceToHost));
3434
THCTensor_(free)(state, temp);
3535
THCTensor_(free)(state, selfc);
3636
}

test/test_cuda.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,23 @@ def tmp(t):
435435
('inverse', new_t(20, 20), lambda t: [], None, float_types, False),
436436
('geqrf', new_t(20, 20), lambda t: [], None, float_types, False,
437437
unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")),
438+
('svd', new_t(10, 10), lambda t: [], 'square', float_types_no_half, False,
439+
unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")),
440+
('svd', lambda t: new_t(10, 10)(t).t(), lambda t: [True], 'square_col_maj',
441+
float_types_no_half, False,
442+
unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")),
443+
('svd', new_t(20, 5), lambda t: [True], 'tall_some', float_types_no_half, False,
444+
unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")),
445+
('svd', new_t(20, 5), lambda t: [False], 'tall_all', float_types_no_half, False,
446+
unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")),
447+
('svd', lambda t: new_t(5, 20)(t).t(), lambda t: [True],
448+
'tall_some_col_maj', float_types_no_half, False,
449+
unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")),
450+
('svd', lambda t: new_t(5, 20)(t).t(), lambda t: [False],
451+
'tall_all_col_maj', float_types_no_half, False,
452+
unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")),
453+
('eig', new_t(10, 10), lambda t: [True], 'with_eigvec', float_types_no_half, False,
454+
unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")),
438455
]
439456

440457
# TODO: random functions, cat, gather, scatter, index*, masked*,

0 commit comments

Comments
 (0)