Issue description
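I tried computing a batchwise dot product across channels, i.e. the pairwise similarity between all pairs of features of two sets of feature maps, using torch.einsum, but it fails. Note that `a` has a batch size of 1 while `b` has a batch size of 5, so the call relies on the size-1 batch dimension being broadcast, as NumPy does. Is this a bug, or is it expected?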
To Reproduce
>>> import torch
>>> a = torch.randn(1, 3, 24, 20)
>>> b = torch.randn(5, 3, 24, 20)
>>> torch.einsum("bijk, bilm -> bjklm", a, b).shape
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/dist-packages/torch/functional.py", line 245, in einsum
return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: size of dimension does not match previous size, operand 1, dim 0
Expected behavior
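Assuming the size-1 batch dimension of `a` should broadcast against `b` the same way NumPy handles it, the expected result, writing `a` as $A$, `b` as $B$ and the output as $C$ (1-based indices), is:

$$
C_{b j k l m} = \sum_{i=1}^{3} A_{1 i j k}\, B_{b i l m}, \qquad b = 1, \dots, 5,
$$

so $C$ has shape (5, 24, 20, 24, 20): for each batch element of `b`, every position $(j, k)$ of `a` is paired with every position $(l, m)$ of `b` through a dot product over the 3 channels.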
In NumPy I get:
>>> import numpy as np
>>> a = np.random.rand(1, 3, 24, 20)
>>> b = np.random.rand(5, 3, 24, 20)
>>> np.einsum("bijk, bilm -> bjklm", a, b).shape
(5, 24, 20, 24, 20)
Environment
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: Could not collect
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 950M
Nvidia driver version: 396.44
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4
/usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
/usr/local/cuda-9.2/lib64/libcudnn.so
/usr/local/cuda-9.2/lib64/libcudnn.so.7
/usr/local/cuda-9.2/lib64/libcudnn.so.7.1.4
/usr/local/cuda-9.2/lib64/libcudnn_static.a
Versions of relevant libraries:
[pip3] numpy (1.15.4)
[pip3] torch (1.0.0)
[pip3] torchvision (0.2.1)
[conda] Could not collect
cc @mruberry @rgommers @vincentqb @vishwakftw @jianyuh @nikitaved @pearu
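A possible workaround, sketched here as an assumption rather than a confirmed fix, is to expand the size-1 batch dimension of `a` to the batch size of `b` before calling `torch.einsum`, so the shared subscript `b` has matching sizes and no broadcasting is needed. `Tensor.expand` returns a view, so no data is copied. The names `a_expanded` and `ref` are illustrative only; `ref` recomputes the same result with plain broadcasting as a cross-check.

```python
import torch

a = torch.randn(1, 3, 24, 20)
b = torch.randn(5, 3, 24, 20)

# Expand a's size-1 batch dimension to b's batch size (a view, no copy),
# so both operands have the same size for the shared subscript "b".
a_expanded = a.expand(b.shape[0], -1, -1, -1)
out = torch.einsum("bijk,bilm->bjklm", a_expanded, b)

# Cross-check with explicit broadcasting: insert singleton axes so the
# product has shape (5, 3, 24, 20, 24, 20), then sum over the channel axis i.
ref = (a[:, :, :, :, None, None] * b[:, :, None, None, :, :]).sum(dim=1)

print(out.shape)  # torch.Size([5, 24, 20, 24, 20])
print(torch.allclose(out, ref, atol=1e-5))
```

The explicit-broadcasting form materializes a (5, 3, 24, 20, 24, 20) intermediate, so the expanded einsum call is likely the cheaper of the two for larger inputs.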