
10% difference noticed between jit and python model #20814

@zhangguanheng66

Description

🐛 Bug

We recently developed a JIT test for the torch.nn.functional.multi_head_attention_forward function. The test fails due to a numerical discrepancy (about 10% relative error) between the JIT version and the Python version.

To Reproduce

Steps to reproduce the behavior:

  1. Check out the branch in PR #20653 (Remove functionality unsupported by the JIT from multi_head_attention_forward)
  2. Run the unit test "python test/test_jit.py TestScript.test_torchscript_multi_head_attn"
  3. (Optional) Check the relative error via printout
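The comparison in step 3 can be sketched roughly as follows. This is not the PR's actual unit test; it is a minimal, hypothetical repro that scripts nn.MultiheadAttention (which wraps multi_head_attention_forward internally) and compares eager vs. scripted outputs. The shapes, seed, and tolerances are arbitrary assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary small dimensions for illustration.
embed_dim, num_heads, seq_len, batch = 8, 2, 4, 3

# eval() disables dropout so both paths are deterministic.
mha = nn.MultiheadAttention(embed_dim, num_heads)
mha.eval()

# Inputs are (seq_len, batch, embed_dim) in this (non-batch-first) layout.
q = torch.randn(seq_len, batch, embed_dim)
k = torch.randn(seq_len, batch, embed_dim)
v = torch.randn(seq_len, batch, embed_dim)

# Eager (Python) path.
eager_out, _ = mha(q, k, v)

# Scripted (JIT) path over the same module and inputs.
scripted = torch.jit.script(mha)
jit_out, _ = scripted(q, k, v)

# Maximum element-wise relative error between the two paths.
rel_err = ((eager_out - jit_out).abs()
           / eager_out.abs().clamp_min(1e-12)).max()
print(rel_err.item())
```

With the fix in place one would expect rel_err to be near machine epsilon rather than on the order of 10%.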

Expected behavior

The JIT version and the Python version are expected to produce very close results.

Environment

PyTorch version: 1.1.0a0+8f9f7ed
Is debug build: No
CUDA used to build PyTorch: 9.2.88

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.12.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.2.88
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] numpydoc==0.8.0
[pip] torch==1.1.0a0+8f9f7ed
[conda] blas 1.0 mkl
[conda] magma-cuda90 2.5.0 1 pytorch
[conda] mkl 2019.1 144
[conda] mkl-include 2019.3 199
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.6 py37hd81dba3_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] torch 1.1.0a0+8f9f7ed dev_0

Labels

oncall: jit