Description
🐛 Bug
I implemented a Cholesky update in PyTorch and hoped to speed it up with the @torch.jit.script decorator. Unfortunately, with the decorator the result of the Cholesky update is no longer correct. While debugging, I noticed that I get the correct result when I add a print statement inside the update function (cholupdate).
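For reference, this is the update the script computes. The notation is my own and follows the rotgen/rotapp code below rather than Stewart's text. Given an upper-triangular R with A = R^T R, the goal is an upper-triangular R' with R'^T R' = A + v v^T. Each step applies a Givens rotation to row k of R and to v^T:

```latex
\begin{aligned}
&\text{for } k = 1,\dots,p: \\
&\quad r_k = \sqrt{R_{kk}^2 + v_k^2}, \qquad c = R_{kk}/r_k, \qquad s = v_k/r_k, \\
&\quad R_{kk} \leftarrow r_k, \qquad v_k \leftarrow 0, \\
&\quad \begin{pmatrix} R_{kj} \\ v_j \end{pmatrix} \leftarrow
  \begin{pmatrix} c & s \\ -s & c \end{pmatrix}
  \begin{pmatrix} R_{kj} \\ v_j \end{pmatrix}, \qquad j = k+1,\dots,p.
\end{aligned}
```

The rotations are orthogonal, so the stacked matrix [R; v^T] keeps its Gram matrix R^T R + v v^T. Once every v_k has been zeroed, the rotated R is the updated factor. The rotgen code computes r_k via the scaling r = |a| + |b| to avoid overflow, but the value is the same. cholxdate first scales v by |w|^{1/4}, so the update it performs is A + sqrt(|w|) v v^T, which matches M_up0 in the script.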
To Reproduce
#!/usr/bin/env python3
import torch

# All functions based on 1998__Stewart__Matrix_Algorithms-Basic_Decompositions

@torch.jit.script
def rotgen(a_in, b_in):
    r = abs(a_in) + abs(b_in)
    if bool(r == 0.0):
        a_out = a_in.clone()
        b_out = b_in.clone()
        c_out = torch.ones_like(a_out, dtype=torch.double)
        s_out = torch.zeros_like(a_out, dtype=torch.double)
    else:
        a_out = r * ((a_in / r)**2 + (b_in / r)**2)**0.5
        b_out = torch.zeros_like(a_out, dtype=torch.double)
        c_out = a_in / a_out
        s_out = b_in / a_out
    return a_out, b_out, c_out, s_out


@torch.jit.script
def rotapp(x_in, y_in,
           c, s):
    x_out = c * x_in + s * y_in
    y_out = c * y_in - s * x_in
    return x_out, y_out


@torch.jit.script
def cholupdate(R_in, v_in):
    R_out = R_in.clone()
    v_out = v_in.clone()
    p = len(v_in)
    for k in range(p):
        R_out[k, k], v_out[k], c, s = rotgen(R_out[k, k], v_out[k])
        R_out[k, k+1:p], v_out[k+1:p] = rotapp(R_out[k, k+1:p], v_out[k+1:p],
                                               c, s)
        #print(v_out)
    return R_out


@torch.jit.script
def cholxdate(R, v, w):
    R_out = R.transpose(1, 0)
    v_use = v * float(abs(w)**0.25)
    R_out = cholupdate(R_out, v_use)
    R_out = R_out.transpose(1, 0)
    return R_out


if __name__ == '__main__':
    import numpy as np
    dim = 2
    # manually set the variables
    w = np.array([1.0], dtype=np.float64)
    v = np.full(dim, np.sqrt(1/3), dtype=np.float64)
    R = np.zeros((dim, dim), dtype=np.float64)
    R[np.tril_indices(dim)] = np.arange(1, np.cumsum(range(dim+1))[-1]+1, dtype=np.float64)
    M = np.dot(R, R.T)
    # make the variables torch tensors
    M = torch.from_numpy(M)
    R = torch.from_numpy(R)
    v = torch.from_numpy(v)
    w = torch.from_numpy(w)
    # cholesky update
    M_up0 = M + torch.sqrt(w) * torch.ger(v, v)
    R_up = cholxdate(R, v, w)
    M_up = torch.mm(R_up, R_up.transpose(1, 0))
    R_up0 = torch.cholesky(M_up0)
    print('cholesky update:')
    print('M_up')
    print(M_up)
    print('M_up0')
    print(M_up0)
    print('R_up')
    print(R_up)
    print('R_up0')
    print(R_up0)
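To check the expected numbers without TorchScript, here is a minimal eager sketch of the same rotgen/rotapp update in plain Python. It uses nested lists instead of tensors. The function names are reused from the script for readability, but this is not the scripted code. Like cholxdate, it works on the upper-triangular factor, that is, the script's R transposed.

```python
import math


def rotgen(a, b):
    """Givens rotation that zeroes b against a (eager mirror of the script's rotgen)."""
    r = abs(a) + abs(b)
    if r == 0.0:
        return a, b, 1.0, 0.0
    h = r * math.sqrt((a / r) ** 2 + (b / r) ** 2)  # sqrt(a^2 + b^2) without overflow
    return h, 0.0, a / h, b / h


def cholupdate(R, v):
    """Rank-one update of an upper-triangular factor R (nested lists) by vector v."""
    R = [row[:] for row in R]  # work on copies, like .clone() in the script
    v = list(v)
    p = len(v)
    for k in range(p):
        R[k][k], v[k], c, s = rotgen(R[k][k], v[k])
        for j in range(k + 1, p):
            x, y = R[k][j], v[j]
            R[k][j] = c * x + s * y  # rotapp: x_out
            v[j] = c * y - s * x     # rotapp: y_out, must be written back to v
    return R


U = [[1.0, 2.0], [0.0, 3.0]]   # the script's R = [[1, 0], [2, 3]], transposed
v = [math.sqrt(1 / 3)] * 2     # w = 1, so v * |w|**0.25 == v
U_up = cholupdate(U, v)
print(U_up)
```

Transposing U_up gives approximately [[1.1547, 0], [2.0207, 3.0414]], which matches R_up0 (torch.cholesky) in the outputs.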
Running the code with the print statement in cholupdate commented out gives the following wrong result:
cholesky update:
M_up
tensor([[ 1.3333, 2.3333],
[ 2.3333, 13.4167]], dtype=torch.float64, grad_fn=<MmBackward>)
M_up0
tensor([[ 1.3333, 2.3333],
[ 2.3333, 13.3333]], dtype=torch.float64)
R_up
tensor([[1.1547, 0.0000],
[2.0207, 3.0551]], dtype=torch.float64, grad_fn=<TransposeBackward0>)
R_up0
tensor([[1.1547, 0.0000],
[2.0207, 3.0414]], dtype=torch.float64)
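A hand check of this example suggests where the wrong value comes from. This is my own arithmetic, not part of the original report. R_up[1, 0] = 2.0207 is correct in both runs, and only the last diagonal entry differs. That pattern is consistent with the rotated v[1] from step k = 0 never reaching step k = 1:

```python
import math

# Step k = 0 in the example: R[0,0] = 1, R[0,1] = 2, v = sqrt(1/3) * [1, 1]
c, s = math.sqrt(3) / 2, 0.5                    # rotation from rotgen(1, sqrt(1/3))
r01 = c * 2.0 + s * math.sqrt(1 / 3)            # rotated R[0,1] (about 2.0207)
v1_rotated = c * math.sqrt(1 / 3) - s * 2.0     # value rotapp should write to v[1]: -0.5
v1_stale = math.sqrt(1 / 3)                     # v[1] if that write-back were lost

# Step k = 1: the new diagonal is sqrt(R[1,1]^2 + v[1]^2) with R[1,1] = 3
diag_correct = math.sqrt(3.0 ** 2 + v1_rotated ** 2)
diag_stale = math.sqrt(3.0 ** 2 + v1_stale ** 2)
m11_stale = r01 ** 2 + diag_stale ** 2          # resulting M_up[1,1]

print(round(diag_correct, 4), round(diag_stale, 4), round(m11_stale, 4))
```

The stale value reproduces both the wrong R_up entry and the wrong M_up[1, 1] (13.4167). This only narrows down which value goes missing. It does not explain why the scripted code loses it.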
Uncommenting the print statement makes the problem disappear, and the output is correct:
cholesky update:
M_up
tensor([[ 1.3333, 2.3333],
[ 2.3333, 13.3333]], dtype=torch.float64, grad_fn=<MmBackward>)
M_up0
tensor([[ 1.3333, 2.3333],
[ 2.3333, 13.3333]], dtype=torch.float64)
R_up
tensor([[1.1547, 0.0000],
[2.0207, 3.0414]], dtype=torch.float64, grad_fn=<TransposeBackward0>)
R_up0
tensor([[1.1547, 0.0000],
[2.0207, 3.0414]], dtype=torch.float64)
Expected behavior
The print statement should have no effect on the output of cholupdate, which should produce the correct result either way.
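The expected behavior can be stated as a property check that does not depend on print statements: whatever cholxdate returns should satisfy R_up R_up^T = M + sqrt(w) v v^T. Below is a minimal plain-Python sketch of that check, applied to the 4-decimal factors printed above. The helper name reconstruction_error and the tolerance are my own choices.

```python
def reconstruction_error(L, M):
    """Largest |(L L^T - M)_ij| for small dense matrices given as nested lists."""
    n = len(M)
    return max(
        abs(sum(L[i][k] * L[j][k] for k in range(n)) - M[i][j])
        for i in range(n)
        for j in range(n)
    )


# Target from the issue: M = [[1, 2], [2, 13]], w = 1, v = sqrt(1/3) * [1, 1]
M_up0 = [[4 / 3, 7 / 3], [7 / 3, 40 / 3]]
R_wrong = [[1.1547, 0.0], [2.0207, 3.0551]]    # scripted result, print commented out
R_correct = [[1.1547, 0.0], [2.0207, 3.0414]]  # scripted result, print enabled

TOL = 1e-3  # loose, because the printed factors are rounded to 4 decimals
print(reconstruction_error(R_correct, M_up0) < TOL)
print(reconstruction_error(R_wrong, M_up0) < TOL)
```

Within the script itself, the same check could be written as torch.allclose(torch.mm(R_up, R_up.transpose(1, 0)), M_up0). With full-precision tensors, that comparison can use the default tolerances.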
Environment
Collecting environment information...
PyTorch version: 1.0.1
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: Could not collect
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 390.116
cuDNN version: Could not collect
Versions of relevant libraries:
[pip3] numpy==1.16.2
[pip3] numpydoc==0.8.0
[pip3] torch==1.0.1
[pip3] torchvision==0.2.1
[conda] blas 1.0 mkl
[conda] mkl 2019.3 199
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.10 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.0.1 cuda90py37h8b0c50b_0
[conda] torchvision 0.2.1 py_2 pytorch