
Batched symeig and qr are very slow on GPU #22573

@calincru

Description

🐛 Bug

The recently added batched eigenvalue decomposition via torch.symeig is very slow on GPU (PR #21858, issue #7500).

To Reproduce

import torch
a = torch.rand(500, 2, 2)
a = 0.5 * (a + a.transpose(1, 2))
w, _ = torch.symeig(a)  # fast (~0.0006s)
a = a.cuda()
w, _ = torch.symeig(a)  # slow (~0.9s)

Expected behavior

The GPU variant should be at least as fast as the CPU one. This is an elementary matrix operation, and GPUs should be fast at it.

Environment

PyTorch version: 1.2.0.dev20190707
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX TITAN X
GPU 1: GeForce GTX TITAN X
GPU 2: GeForce GTX TITAN X
GPU 3: GeForce GTX TITAN X

Nvidia driver version: 418.67
cuDNN version: Probably one of the following:
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.5
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7.2.1

Versions of relevant libraries:
[pip3] numpy==1.15.1
[conda] mkl 2019.4 243
[conda] pytorch-nightly 1.2.0.dev20190707 py3.7_cuda9.0.176_cudnn7.5.1_0 pytorch

Additional context

I assume this is not surprising given the following comment in the implementation (CC @vishwakftw):

// We create temporary tensors on the CPU, because tensors on the GPU
// cause segfault when passed to magmaSymeig. The data is later
// moved to the appropriate device.
// In the case where self.numel() == 0, we just return an empty tensor of
// dimensions on the CUDA (to avoid the unnecessary "to(at::kCUDA)")
but it should nonetheless be fixed. It's not clear to me whether this implies there's a bug in MAGMA, or whether anything is being done about it.
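If the quoted comment is accurate, every CUDA call pays two device↔host transfers plus a CPU-side allocation on top of the factorization itself. A rough sketch of that pattern (illustrative only; the function name is hypothetical and torch.linalg.eigh stands in for the MAGMA kernel, this is not the actual binding):

```python
import torch

def symeig_cuda_roundtrip(a_gpu):
    # Hypothetical illustration of the staging pattern the comment
    # describes: the input is copied to host memory before the solver
    # runs, and the results are copied back to the original device.
    a_cpu = a_gpu.cpu()              # device -> host copy
    w, v = torch.linalg.eigh(a_cpu)  # stand-in for the MAGMA call
    # host -> device copies of both outputs
    return w.to(a_gpu.device), v.to(a_gpu.device)
```

For a batch of many tiny matrices, these transfers and the per-call allocation can easily dominate the runtime, which would explain the ~0.9s GPU figure above.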

cc @ngimel @vincentqb @vishwakftw @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @VitalyFedyunin


    Labels

    module: cuda — Related to torch.cuda, and CUDA support in general
    module: linear algebra — Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul
    module: performance — Issues related to performance, either of kernel code or framework glue
    triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
