
Conversation

@IvanYashchuk (Collaborator)

Fixes #42418.

The problem was that non-contiguous batched matrices were passed directly to gemmStridedBatched.

The following code fails on master and works with the proposed patch:

import torch

x = torch.tensor([[1., 2, 3], [4., 5, 6]], device='cuda:0')
# Non-contiguous batched view: strides [3, 1, 1] instead of the contiguous [4, 2, 1]
c = torch.as_strided(x, size=[2, 2, 2], stride=[3, 1, 1])
torch.einsum('...ab,...bc->...ac', c, c)  # batched matmul: fails on master, works with this patch
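For context, here is a minimal sketch (using only the public PyTorch API) of the workaround this patch makes unnecessary: materializing the strided view with .contiguous() so the batched GEMM receives a plainly laid-out input.

import torch

x = torch.tensor([[1., 2, 3], [4., 5, 6]], device='cuda:0')
c = torch.as_strided(x, size=[2, 2, 2], stride=[3, 1, 1])

# Workaround before this fix: copy the overlapping strided view into a
# contiguous buffer, which gemmStridedBatched handles correctly.
c_contig = c.contiguous()
out = torch.einsum('...ab,...bc->...ac', c_contig, c_contig)
print(out)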

@dr-ci (bot) commented Aug 2, 2020

💊 CI failures summary and remediations

As of commit cfff0b4 (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_doc_test (1/1), step "Doc test":

Aug 04 16:49:42 caused by: Connection refused (os error 111)
Aug 04 16:49:41 +++++ extract_trap_cmd 
Aug 04 16:49:41 +++++ printf '%s\n' '' 
Aug 04 16:49:41 ++++ printf '%s\n' cleanup 
Aug 04 16:49:41 +++ trap -- ' 
Aug 04 16:49:41 cleanup' EXIT 
Aug 04 16:49:41 +++ [[ pytorch-doc-test != *pytorch-win-* ]] 
Aug 04 16:49:41 +++ which sccache 
Aug 04 16:49:41 +++ sccache --stop-server 
Aug 04 16:49:42 Stopping sccache server... 
Aug 04 16:49:42 error: couldn't connect to server 
Aug 04 16:49:42 caused by: Connection refused (os error 111) 
Aug 04 16:49:42 +++ true 
Aug 04 16:49:42 +++ rm /var/lib/jenkins/sccache_error.log 
Aug 04 16:49:42 +++ SCCACHE_ERROR_LOG=/var/lib/jenkins/sccache_error.log 
Aug 04 16:49:42 +++ SCCACHE_IDLE_TIMEOUT=1200 
Aug 04 16:49:42 +++ RUST_LOG=sccache::server=error 
Aug 04 16:49:42 +++ sccache --start-server 
Aug 04 16:49:42 Starting sccache server... 
Aug 04 16:49:42 +++ sccache --zero-stats 
Aug 04 16:49:42 Compile requests                 0 
Aug 04 16:49:42 Compile requests executed        0 

1 failure not recognized by patterns:

Job: CircleCI pytorch_python_doc_build, step "Doc Build and Push"


@ngimel (Collaborator) commented Aug 3, 2020

Thank you, the fix looks good! Please add the test to the test suite. There's test_baddbmm in test_torch.py, but for some reason it's enabled only on the CPU. Can you try enabling it on CUDA and adding a test case for the behavior you enabled there?

@IvanYashchuk (Collaborator, Author)

Sure, I will do that.

@IvanYashchuk (Collaborator, Author)

While writing tests, I found that torch.mm also doesn't work with this kind of strided input:

import torch

x = torch.tensor([[1., 2, 3], [4., 5, 6]], device='cuda:0')
c = torch.as_strided(x, size=[2, 2, 2], stride=[3, 1, 1])  # non-contiguous view
torch.mm(c[0], c[0])  # fails on master

ATen/native/cuda/LinearAlgebra.cu:prepare_matrix_for_cublas had the same problem as THCTensor_(baddbmm): the input was not transformed into a contiguous array. I have fixed that as well.
I've added tests for both torch.mm and torch.bmm. Correctness is checked by comparing against the NumPy result.
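For illustration, a standalone sketch of that kind of NumPy comparison (not the actual code added to test_torch.py):

import numpy as np
import torch

x = torch.tensor([[1., 2, 3], [4., 5, 6]], device='cuda:0')
c = torch.as_strided(x, size=[2, 2, 2], stride=[3, 1, 1])

# Single non-contiguous matrix: torch.mm vs. NumPy
expected_mm = np.matmul(c[0].cpu().numpy(), c[0].cpu().numpy())
np.testing.assert_allclose(torch.mm(c[0], c[0]).cpu().numpy(), expected_mm)

# Whole non-contiguous batch: torch.bmm vs. NumPy
expected_bmm = np.matmul(c.cpu().numpy(), c.cpu().numpy())
np.testing.assert_allclose(torch.bmm(c, c).cpu().numpy(), expected_bmm)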

@IvanYashchuk changed the title from "Fix the bug in THCTensor_(baddbmm) for strided views input" to "Fix the bug in THCTensor_(baddbmm) and ATen's addmm_cuda for strided views input" on Aug 3, 2020.
@ngimel (Collaborator) left a comment

Thank you, this looks great! I have a small suggestion about the test.

@ngimel (Collaborator) left a comment

Awesome, thanks!

@ngimel (Collaborator) commented Aug 4, 2020

The flake8 error is real.

@IvanYashchuk (Collaborator, Author)

> The flake8 error is real.

Is it okay to ignore E731? It is triggered because the test assigns lambdas to names instead of using def.

@ngimel (Collaborator) commented Aug 4, 2020

Yeah, that's fine; you could also pass the lambdas directly as arguments, but it does not matter. Sorry to ask, but can you please rebase? We had an issue with Docker images today, so CI on this PR is failing because it is against a bad base commit and can't find the Docker images.
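For illustration, a small sketch (with made-up names, not the actual test code) of the two flake8-clean alternatives mentioned here: a def instead of an assigned lambda, or a lambda passed directly as an argument.

import torch

# E731 flags assigning a lambda to a name:
#     fn = lambda a, b: torch.mm(a, b)   # flake8 E731
# Alternative 1: use a def.
def mm_fn(a, b):
    return torch.mm(a, b)

# Alternative 2: pass the lambda directly as an argument.
def run_case(op, a, b):
    return op(a, b)

a = torch.eye(2)
print(run_case(mm_fn, a, a))
print(run_case(lambda x, y: torch.mm(x, y), a, a))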

@IvanYashchuk (Collaborator, Author) commented Aug 4, 2020

Alright, I did something wrong here. I rebased onto master and pushed to the branch, and the PR was automatically closed. I'll try to fix that.

@ngimel (Collaborator) commented Aug 4, 2020

It somehow became a 0-commit, 0-line pull request; that's probably why it was closed. Did something go wrong with the rebase?

@IvanYashchuk (Collaborator, Author) commented Aug 4, 2020

I've recovered the branch. I am sorry for the inconvenience. I've re-opened the PR.

@IvanYashchuk reopened this on Aug 4, 2020.
@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel (Collaborator) commented Aug 4, 2020

Thank you!

@facebook-github-bot (Contributor)

@ngimel merged this pull request in b9e68e0.

@IvanYashchuk deleted the fix-issue-42418 branch on August 8, 2020.


Successfully merging this pull request may close these issues.

einsum fails in THCudaBlas_DgemmStridedBatched only when running on GPU (#42418)
