
Commit abab0a5

Doc note for complex

ghstack-source-id: be61901
Pull Request resolved: #41252

Doc note for Complex Numbers

ghstack-source-id: be61901
Pull Request resolved: #41253

1 parent 75b6dd3 commit abab0a5

File tree

3 files changed: +175 −0 lines changed

docs/source/complex_numbers.rst

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
.. _complex_numbers-doc:

Complex Numbers
===============

Complex numbers are numbers that can be expressed in the form :math:`a + bj`, where :math:`a` and :math:`b` are real numbers,
and *j* is a solution of the equation :math:`x^2 = -1`. Complex numbers frequently occur in mathematics and
engineering, especially in signal processing. Tensors of complex dtypes provide a more natural user experience
for users and libraries (e.g., TorchAudio) that previously worked around the lack of complex tensors by using
float tensors of shape :math:`(..., 2)`, where the last dimension held the real and imaginary values.

Operations on complex tensors (e.g., :func:`torch.mv`, :func:`torch.matmul`) are likely to be faster and more
memory efficient than operations on float tensors mimicking them. Operations involving complex numbers in
PyTorch are optimized to use vectorized assembly instructions and specialized kernels (e.g., LAPACK, cuBLAS),
so using the built-in functions on complex tensors is generally faster than hand-rolling the equivalent
arithmetic on pairs of floats.
17+
.. note::
    Spectral ops currently don't use complex tensors, but the API will soon be updated to use them.

.. warning::
    Complex tensors are a beta feature and subject to change.
Creating Complex Tensors
------------------------

We support two complex dtypes: `torch.cfloat` and `torch.cdouble`.

::

    >>> x = torch.randn(2, 2, dtype=torch.cfloat)
    >>> x
    tensor([[-0.4621-0.0303j, -0.2438-0.5874j],
            [ 0.7706+0.1421j,  1.2110+0.1918j]])
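
Complex tensors can also be created directly from Python complex literals; the inferred
dtype follows the note below (a small illustrative example)::

    >>> torch.tensor([1+2j, 3-4j])
    tensor([1.+2.j, 3.-4.j])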

.. note::

    The default dtype for complex tensors is determined by the default floating point dtype.
    If the default floating point dtype is `torch.float64`, then complex numbers are inferred to
    have a dtype of `torch.complex128`; otherwise they are assumed to have a dtype of `torch.complex64`.
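
For example, a quick check of this inference rule (illustrative sketch; note that
`torch.set_default_dtype` changes a global default, so reset it afterwards in real code)::

    >>> torch.tensor(1 + 1j).dtype
    torch.complex64
    >>> torch.set_default_dtype(torch.float64)
    >>> torch.tensor(1 + 1j).dtype
    torch.complex128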

All factory functions apart from :func:`torch.linspace`, :func:`torch.logspace`, and :func:`torch.arange` are
supported for complex tensors.
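
For example (an illustrative sketch using two of the supported factory functions)::

    >>> torch.zeros(2, dtype=torch.cfloat)
    tensor([0.+0.j, 0.+0.j])
    >>> torch.ones(2, dtype=torch.cdouble)
    tensor([1.+0.j, 1.+0.j], dtype=torch.complex128)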

Transition from the old representation
--------------------------------------

Users who previously worked around the lack of complex tensors with real tensors of shape :math:`(..., 2)`
can easily switch to complex tensors in their code using :func:`torch.view_as_complex` and
:func:`torch.view_as_real`. Note that these functions don't perform any copy and
return a view of the input tensor.

::

    >>> x = torch.randn(3, 2)
    >>> x
    tensor([[ 0.6125, -0.1681],
            [-0.3773,  1.3487],
            [-0.0861, -0.7981]])
    >>> y = torch.view_as_complex(x)
    >>> y
    tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
    >>> torch.view_as_real(y)
    tensor([[ 0.6125, -0.1681],
            [-0.3773,  1.3487],
            [-0.0861, -0.7981]])
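
Because no data is copied, writes to the real tensor are visible through the complex
view (a small self-contained sketch; :func:`torch.view_as_complex` expects the last
dimension of its input to have size 2)::

    >>> a = torch.zeros(2, 2)
    >>> b = torch.view_as_complex(a)
    >>> a[0, 0] = 1
    >>> b
    tensor([1.+0.j, 0.+0.j])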

Accessing real and imag
-----------------------

The real and imaginary values of a complex tensor can be accessed using the :attr:`real` and
:attr:`imag` views.

::

    >>> y.real
    tensor([ 0.6125, -0.3773, -0.0861])
    >>> y.imag
    tensor([-0.1681,  1.3487, -0.7981])
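
These attributes are views as well, so (assuming in-place writes through them are
supported, as the view semantics above suggest) updates through :attr:`real` or
:attr:`imag` modify the underlying complex tensor::

    >>> c = torch.zeros(2, dtype=torch.cfloat)
    >>> c.real[0] = 1  # write through the real view
    >>> c
    tensor([1.+0.j, 0.+0.j])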

Angle and abs
-------------

The angle and absolute values of a complex tensor can be computed using :func:`torch.angle` and
:func:`torch.abs`.

::

    >>> x1 = torch.tensor([3j, 4+4j])
    >>> x1.abs()
    tensor([3.0000, 5.6569])
    >>> x1.angle()
    tensor([1.5708, 0.7854])
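
These agree with the usual definitions, e.g. the absolute value is
:math:`\sqrt{a^2 + b^2}` (a quick check using the views from the previous section)::

    >>> (x1.real ** 2 + x1.imag ** 2).sqrt()
    tensor([3.0000, 5.6569])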

Linear Algebra
--------------

Currently, linear algebra support for complex tensors is minimal. We support :func:`torch.mv`,
:func:`torch.svd`, :func:`torch.qr`, and :func:`torch.inverse` (the latter three only on CPU).
We are working to add support for more functions soon, including :func:`torch.matmul`,
:func:`torch.solve`, :func:`torch.eig`, and :func:`torch.symeig`. If any of these would help your use case, please
`search <https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+complex>`_
whether an issue has already been filed and, if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_.
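
For example, a matrix-vector product with one of the supported functions (a minimal
deterministic sketch)::

    >>> A = torch.tensor([[1+1j, 0j], [0j, 1-1j]])
    >>> v = torch.tensor([1+0j, 2+0j])
    >>> torch.mv(A, v)
    tensor([1.+1.j, 2.-2.j])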

Serialization
-------------

Complex tensors can be serialized, allowing data to be saved as complex values.

::

    >>> torch.save(y, 'complex_tensor.pt')
    >>> torch.load('complex_tensor.pt')
    tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])

Autograd
--------

PyTorch supports autograd for complex tensors. The autograd APIs can be
used for both holomorphic and non-holomorphic functions. For non-holomorphic
functions, the gradient is evaluated as if they were holomorphic. For more details,
check out the note :ref:`complex_autograd-doc`.
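
A minimal sketch of the API, using the same view-based pattern as the fuller example
below (gradient values are omitted here, since the exact convention is described in the
linked note and may evolve while complex autograd is in beta)::

    >>> x = torch.randn(2, dtype=torch.cfloat, requires_grad=True)
    >>> loss = torch.view_as_real(x).pow(2).sum()  # squared magnitude as a real loss
    >>> loss.backward()
    >>> x.grad is not None
    True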

Gradients can also be computed for functions not yet supported for complex tensors
by enclosing the unsupported operations between :func:`torch.view_as_real` and
:func:`torch.view_as_complex`. The example below computes the dot product
of two complex tensors by performing the operations on complex tensors viewed as real tensors.
As shown, the gradients computed this way have the same values you would get by performing
the operations on the complex tensors directly.

::

    >>> # Computes the complex dot product for complex vectors
    >>> # represented as float vectors of shape (..., 2).
    >>> # math: for complex numbers a and b, vdot(a, b) = a.conj() * b
    >>> def vdot(x, y):
    ...     z = torch.empty_like(x)
    ...     z[:, 0] = x[:, 0] * y[:, 0] + x[:, 1] * y[:, 1]  # real part: ar*br + ai*bi
    ...     z[:, 1] = x[:, 0] * y[:, 1] - x[:, 1] * y[:, 0]  # imaginary part: ar*bi - ai*br
    ...     return z

    >>> x = torch.randn(2, dtype=torch.cfloat, requires_grad=True)
    >>> y = torch.randn(2, dtype=torch.cfloat, requires_grad=True)

    >>> x1 = torch.view_as_real(x.clone())
    >>> y1 = torch.view_as_real(y.clone())
    >>> z = torch.view_as_complex(vdot(x1, y1))
    >>> z.sum().backward()

    >>> x.grad  # equals y.conj()
    tensor([0.5560+0.2285j, 1.5326-0.4576j])
    >>> y
    tensor([0.5560-0.2285j, 1.5326+0.4576j], requires_grad=True)
    >>> y.grad  # equals x.conj()
    tensor([ 0.0766-1.0273j, -0.4325+0.2226j])
    >>> x
    tensor([ 0.0766+1.0273j, -0.4325-0.2226j], requires_grad=True)

We do not support the following subsystems:

* Quantization

* JIT

* Sparse Tensors

* Distributed

If any of these would help your use case, please `search <https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+complex>`_
whether an issue has already been filed and, if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_.

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
     nn.init
     onnx
     optim
+    complex_numbers
     quantization
     rpc
     torch.random <random>

docs/source/notes/autograd.rst

Lines changed: 2 additions & 0 deletions

@@ -211,6 +211,8 @@ Autograd relies on the user to write thread safe C++ hooks. If you want the hook
 to be correctly applied in multithreading environment, you will need to write
 proper thread locking code to ensure the hooks are thread safe.

+.. _complex_autograd-doc:
+
 Autograd for Complex Numbers
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
