Skip to content

Conversation

@kshitij12345
Copy link
Collaborator

Fixes #41817

@mruberry mruberry added the module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul label Aug 26, 2020
@mruberry mruberry self-requested a review August 26, 2020 04:24
@mruberry mruberry added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 26, 2020
@mruberry
Copy link
Collaborator

Hey @kshitij12345! Would you add a test for this behavior to ensure it doesn't regress?

run_test([10, 20, 30, 5])
run_test([15, 5, 10, 20, 25])

with self.assertRaisesRegex(RuntimeError, "chain_matmul: Expected one or more matrices"):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have updated the existing test here. To verify the behaviour.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytorch/test/test_torch.py

Lines 7224 to 7241 in 5fb5244

@dtypes(torch.double)
def test_chain_matmul(self, device, dtype):
def product(matrices):
for mat in matrices[1:]:
matrices[0] = matrices[0].mm(mat)
return matrices[0]
def run_test(p):
matrices = []
for (pi, pi_1) in zip(p[:-1], p[1:]):
matrices.append(torch.randn(pi, pi_1, dtype=dtype, device=device))
self.assertEqual(torch.chain_matmul(*matrices), product(matrices))
run_test([10, 20, 30, 5])
run_test([15, 5, 10, 20, 25])
with self.assertRaisesRegex(RuntimeError, "chain_matmul: Expected one or more matrices"):
torch.chain_matmul()

Copy link
Collaborator

@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! Thanks @kshitij12345!

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@mruberry merged this pull request in 01b5c06.

@kshitij12345 kshitij12345 deleted the fix/chain-matmul/empty-args branch August 28, 2020 16:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Segfault when passing an empty input to torch.chain_matmul

4 participants