|
| 1 | +.. _complex_numbers-doc: |
| 2 | + |
| 3 | +Complex Numbers |
| 4 | +=============== |
| 5 | + |
| 6 | +Complex numbers are numbers that can be expressed in the form :math:`a + bj`, where a and b are real numbers, |
| 7 | +and *j* is a solution of the equation :math:`x^2 = −1`. Complex numbers frequently occur in mathematics and |
| 8 | +engineering, especially in signal processing. Tensors of complex dtypes provide a more natural user experience |
| 9 | +for users and libraries (eg. TorchAudio) that previously worked around the lack of complex tensors by using |
| 10 | +float tensors with shape :math:`(..., 2)` where the last dimension contained the real and imaginary values. |
| 11 | + |
| 12 | +Operations on complex tensors (eg :func:`torch.mv`, :func:`torch.matmul`) are likely to be faster and more |
| 13 | +memory efficient than operations on float tensors mimicking them. Operations involving complex numbers in |
| 14 | +PyTorch are optimized to use vectorized assembly instructions and specialized kernels (e.g. LAPACK, CuBlas). |
| 15 | +Thus using functions for complex tensors will provide performance benefits as opposed to users defining |
| 16 | +their own functions. |
| 17 | + |
| 18 | +.. note:: |
| 19 | + Spectral Ops currently don't use complex tensors but the API would be soon updated to use complex tensors. |
| 20 | + |
| 21 | +.. warning :: |
| 22 | + Complex Tensors is a beta feature and subject to change. |
| 23 | +
|
| 24 | +Creating Complex Tensors |
| 25 | +------------------------ |
| 26 | + |
| 27 | +We support two complex dtypes: `torch.cfloat` and `torch.cdouble` |
| 28 | + |
| 29 | +:: |
| 30 | + |
| 31 | + >>> x = torch.randn(2,2, dtype=torch.cfloat) |
| 32 | + >>> x |
| 33 | + tensor([[-0.4621-0.0303j, -0.2438-0.5874j], |
| 34 | + [ 0.7706+0.1421j, 1.2110+0.1918j]]) |
| 35 | + |
| 36 | +.. note:: |
| 37 | + |
| 38 | + The default dtype for complex tensors is determined by the default floating point dtype. |
| 39 | + If the default floating point dtype is torch.float64 then complex numbers are inferred to |
| 40 | + have a dtype of torch.complex128, otherwise they are assumed to have a dtype of torch.complex64. |
| 41 | + |
| 42 | +All factory functions apart from :func:`torch.linspace`, :func:`torch.logspace`, and :func:`torch.arange` are |
| 43 | +supported for complex tensors. |
| 44 | + |
| 45 | +Transition from the old representation |
| 46 | +-------------------------------------- |
| 47 | + |
| 48 | +Users who currently worked around the lack of complex tensors with real tensors of shape `(..., 2)` |
| 49 | +can easily to switch using the complex tensors in their code using :func:`torch.view_as_complex` and |
| 50 | +- :func:`torch.view_as_real`. Note that these functions don't perform any copy and |
| 51 | +return a view of the input Tensor. |
| 52 | + |
| 53 | +:: |
| 54 | + |
| 55 | + >>> x = torch.randn(3, 2) |
| 56 | + >>> x |
| 57 | + tensor([[ 0.6125, -0.1681], |
| 58 | + [-0.3773, 1.3487], |
| 59 | + [-0.0861, -0.7981]]) |
| 60 | + >>> y = torch.view_as_complex(x) |
| 61 | + >>> y |
| 62 | + tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j]) |
| 63 | + >>> torch.view_as_real(y) |
| 64 | + tensor([[ 0.6125, -0.1681], |
| 65 | + [-0.3773, 1.3487], |
| 66 | + [-0.0861, -0.7981]]) |
| 67 | + |
| 68 | +Accessing real and imag |
| 69 | +----------------------- |
| 70 | + |
| 71 | +The real and imaginary values of a complex tensor can be accessed using the :attr:`real` and |
| 72 | +:attr:`imag` views. |
| 73 | + |
| 74 | +:: |
| 75 | + |
| 76 | + >>> y.real |
| 77 | + tensor([ 0.6125, -0.3773, -0.0861]) |
| 78 | + >>> y.imag |
| 79 | + tensor([-0.1681, 1.3487, -0.7981]) |
| 80 | + |
| 81 | +Angle and abs |
| 82 | +------------- |
| 83 | + |
| 84 | +The angle and absolute values of a complex tensor can be accessed using :func:`torch.angle` and |
| 85 | +`torch.abs`. |
| 86 | + |
| 87 | +:: |
| 88 | + |
| 89 | + >>> x1=torch.tensor([3j, 4+4j]) |
| 90 | + >>> x1.abs() |
| 91 | + tensor([3.0000, 5.6569]) |
| 92 | + >>> x1.angle() |
| 93 | + tensor([1.5708, 0.7854]) |
| 94 | + |
| 95 | +Linear Algebra |
| 96 | +-------------- |
| 97 | + |
| 98 | +Currently, there is very minimal linear algebra operation support for complex tensors. |
| 99 | +We currently support :func:`torch.mv`, :func:`torch.svd`, :func:`torch.qr`, and :func:`torch.inverse` |
| 100 | +(the latter three are only supported on CPU). However we are working to add support for more |
| 101 | +functions soon: :func:`torch.matmul`, :func:`torch.solve`, :func:`torch.eig`, :func:`torch.eig`, |
| 102 | +:func:`torch.symeig`. If any of these would help your use case, please |
| 103 | +`search <https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+complex>`_ |
| 104 | +if an issue has already been filed and if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_. |
| 105 | + |
| 106 | + |
| 107 | +Serialization |
| 108 | +------------- |
| 109 | + |
| 110 | +Complex Tensors can be serialized, allowing data to be saved as complex values. |
| 111 | + |
| 112 | +:: |
| 113 | + |
| 114 | + >>> torch.save(y, 'complex_tensor.pt') |
| 115 | + >>> torch.load('complex_tensor.pt') |
| 116 | + tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j]) |
| 117 | + |
| 118 | + |
| 119 | +Autograd |
| 120 | +-------- |
| 121 | + |
| 122 | +PyTorch supports Autograd for Complex Tensors. The autograd APIs can be |
| 123 | +used for both holomorphic and non-holomorphic functions. For non-holomorphic |
| 124 | +functions, the gradient is evaluated as if it were holomorphic. For more details, |
| 125 | +check out the note :ref:`complex_autograd-doc`. |
| 126 | + |
| 127 | +Gradient calculation can also be easily done for functions not supported for complex tensors |
| 128 | +yet by enclosing the unsupported operations between :func:`torch.view_as_real` and |
| 129 | +:func:`torch.view_as_complex` functions. The example shown below computes the dot product |
| 130 | +of two complex tensors, by performing operations on complex tensors viewed as real tensors. |
| 131 | +As shown below, the gradients computed have the same value as you would get if you were to perform |
| 132 | +the operations on complex tensors. |
| 133 | + |
| 134 | +:: |
| 135 | + >>> # computes the complex dot product for complex vectors |
| 136 | + >>> # represented as float vectors |
| 137 | + >>> # math: for complex numbers a and b vdot(a, b) = a.conj() * b |
| 138 | + >>> def vdot(x, y): |
| 139 | + >>> z = torch.empty_like(x) |
| 140 | + >>> z[:, 0] = x[:, 0] * y[:, 0] + x[:, 1] * y[:, 1] |
| 141 | + >>> z[:, 1] = x[:, 0] * y[:, 1] - x[:, 1] * y[:, 0] |
| 142 | + >>> return z |
| 143 | + |
| 144 | + >>> x = torch.randn(2, dtype=torch.cfloat, requires_grad=True) |
| 145 | + >>> y = torch.randn(2, dtype=torch.cfloat, requires_grad=True) |
| 146 | + |
| 147 | + >>> x1 = torch.view_as_real(x.clone()) |
| 148 | + >>> y1 = torch.view_as_real(y.clone()) |
| 149 | + >>> z = torch.view_as_complex(vdot(x1, y1)) |
| 150 | + >>> z.sum().backward() |
| 151 | + |
| 152 | + >>> x.grad # equals y.conj() |
| 153 | + tensor([0.5560+0.2285j, 1.5326-0.4576j]) |
| 154 | + >>> y |
| 155 | + tensor([0.5560-0.2285j, 1.5326+0.4576j], requires_grad=True) |
| 156 | + >>> y.grad # equals x.conj() |
| 157 | + tensor([ 0.0766-1.0273j, -0.4325+0.2226j]) |
| 158 | + >>> x |
| 159 | + tensor([ 0.0766+1.0273j, -0.4325-0.2226j], requires_grad=True) |
| 160 | + |
| 161 | +We do not support the following subsystems: |
| 162 | + |
| 163 | +* Quantization |
| 164 | + |
| 165 | +* JIT |
| 166 | + |
| 167 | +* Sparse Tensors |
| 168 | + |
| 169 | +* Distributed |
| 170 | + |
| 171 | +If any of these would help your use case, please `search <https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+complex>`_ |
| 172 | +if an issue has already been filed and if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_. |
0 commit comments