Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
40cf01a
Add cosine window
alvgaona Oct 11, 2022
241037e
Add chebyshev window
alvgaona Oct 11, 2022
d46d227
Cast int
alvgaona Oct 11, 2022
56aebc7
Add signal module and a few windows
alvgaona Oct 11, 2022
1603573
update
alvgaona Oct 11, 2022
e911e40
update
alvgaona Oct 11, 2022
e637105
remove
alvgaona Oct 11, 2022
949cab0
Update docs
alvgaona Oct 11, 2022
b0c49ec
Update docstrings
alvgaona Oct 11, 2022
e0864e7
Undo C++ impl
alvgaona Oct 11, 2022
4798e85
Undo old changes
alvgaona Oct 11, 2022
02b0fa5
Undo old changes
alvgaona Oct 11, 2022
ca8450d
Add try-except
alvgaona Oct 11, 2022
ef1f952
Remove dummy code
alvgaona Oct 11, 2022
a02f669
Remove unusued import
alvgaona Oct 11, 2022
caa6a32
Add center
alvgaona Oct 11, 2022
c44c4c0
fix lin issues
alvgaona Oct 11, 2022
de61473
remove unused code
alvgaona Oct 11, 2022
ad8850c
fixing changes
alvgaona Oct 11, 2022
4fd93cb
fixing lints
alvgaona Oct 11, 2022
fa26443
fixing lints
alvgaona Oct 11, 2022
18910eb
fixing lints
alvgaona Oct 11, 2022
d0845af
fixing lints
alvgaona Oct 11, 2022
ae8f46a
update docs
alvgaona Oct 11, 2022
8666c9d
Reduce computation time and improve numeric stability
alvgaona Oct 11, 2022
19ac119
Update gaussian and exp windows
alvgaona Oct 11, 2022
6225f64
Address review
alvgaona Oct 11, 2022
bf5f8bb
Update docstring
alvgaona Oct 11, 2022
e9deb0e
Add docstr differently
alvgaona Oct 11, 2022
0bb072f
Update cosine
alvgaona Oct 11, 2022
330ba1d
Add note
alvgaona Oct 11, 2022
1960b2b
Update signature
alvgaona Oct 11, 2022
c69bddd
Fix lint
alvgaona Oct 11, 2022
75a80e7
Update docs
alvgaona Oct 11, 2022
3b9ddd8
Fix numeric errors
alvgaona Oct 11, 2022
0f0db71
Fix numeric errors
alvgaona Oct 11, 2022
65d2f35
Add comment
alvgaona Oct 11, 2022
a00a9e4
Fix lint
alvgaona Oct 11, 2022
9b88ffe
Address review
alvgaona Oct 11, 2022
cb28057
Move docstrs
alvgaona Oct 11, 2022
23485c5
Revert test changes
alvgaona Oct 11, 2022
7a6f207
Update OpInfo and tests
alvgaona Oct 11, 2022
f5e1c45
Update sqrt and ref kwargs
alvgaona Oct 11, 2022
66f5bd0
Remove unnecesary tests
alvgaona Oct 11, 2022
75b8010
Address comments
alvgaona Oct 11, 2022
359e62b
Solve lint issues
alvgaona Oct 11, 2022
7aa10c4
Update signal.py
alvgaona Oct 11, 2022
bb198a2
Update tests
alvgaona Oct 11, 2022
93ab755
Update tests
alvgaona Oct 11, 2022
6cd2592
Add make reference
alvgaona Oct 11, 2022
5bb9243
Update docstrings and errors
alvgaona Oct 11, 2022
e08461c
Correct docstrings
alvgaona Oct 11, 2022
f564293
Update examples
alvgaona Oct 11, 2022
3fbecdd
Update docstrings
alvgaona Oct 11, 2022
9c116df
Skip dtype torch.float16
alvgaona Oct 11, 2022
3364f0b
Fix lint
alvgaona Oct 11, 2022
e1af0e8
Use math instead of numpy
alvgaona Oct 11, 2022
75e1887
Update DecorateInfo
alvgaona Oct 11, 2022
bbb20f7
Fix OpInfo
alvgaona Oct 11, 2022
761ead2
Expect failure
alvgaona Oct 11, 2022
3cfee59
Add decorateinfo
alvgaona Oct 11, 2022
e030843
Add device_type
alvgaona Oct 11, 2022
1e2e1a6
Address review
alvgaona Oct 11, 2022
ca79241
Fix lint err
alvgaona Oct 11, 2022
2f82579
Address review and fix lint
alvgaona Oct 11, 2022
04abf81
Merge branch 'master' of https://github.com/pytorch/pytorch into new-…
alvgaona Oct 11, 2022
32de5f6
Merge branch 'master' of https://github.com/pytorch/pytorch into new-…
alvgaona Oct 11, 2022
f29e59d
Expected failure test_meta torch.half
alvgaona Oct 11, 2022
a6103f9
Update tests and index
alvgaona Oct 12, 2022
04851f8
Address comments
alvgaona Oct 13, 2022
d7ea5a3
Remove random
alvgaona Oct 13, 2022
db86f80
Reorder op_db
alvgaona Oct 13, 2022
ce4bb65
Fix bug and add tests
alvgaona Oct 13, 2022
10a0b14
Add Optional
alvgaona Oct 13, 2022
77bf43b
Replace arange with linspace
alvgaona Oct 13, 2022
21dd563
Change signature
alvgaona Oct 13, 2022
7df83fc
rerun checks
alvgaona Oct 14, 2022
96378f1
Merge branch 'master' of https://github.com/pytorch/pytorch into new-…
alvgaona Oct 14, 2022
e399648
Skip test_dispatch_symbolic_meta
alvgaona Oct 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Features described in this documentation are classified by release status:
torch.jit <jit>
torch.linalg <linalg>
torch.monitor <monitor>
torch.signal <signal>
torch.special <special>
torch.overrides
torch.package <package>
Expand Down
23 changes: 23 additions & 0 deletions docs/source/signal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. role:: hidden
:class: hidden-section

torch.signal
============
.. automodule:: torch.signal
.. currentmodule:: torch.signal

The `torch.signal` module, modeled after SciPy's `signal <https://docs.scipy.org/doc/scipy/reference/signal.html>`_ module.

torch.signal.windows
--------------------

.. automodule:: torch.signal.windows
.. currentmodule:: torch.signal.windows

.. autosummary::
:toctree: generated
:nosignatures:

cosine
exponential
gaussian
1 change: 1 addition & 0 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ def _assert(condition, message):
from torch import futures as futures
from torch import nested as nested
from torch import nn as nn
from torch.signal import windows as windows
from torch import optim as optim
import torch.optim._multi_tensor
from torch import multiprocessing as multiprocessing
Expand Down
5 changes: 5 additions & 0 deletions torch/signal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import windows

__all__ = [
'windows'
]
10 changes: 10 additions & 0 deletions torch/signal/windows/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import warnings

from .windows import cosine, exponential, gaussian


__all__ = [
'cosine',
'exponential',
'gaussian',
]
298 changes: 298 additions & 0 deletions torch/signal/windows/windows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
from typing import Optional

import torch
from math import sqrt

from torch import Tensor
from torch._prims_common import is_float_dtype
from torch._torch_docs import factory_common_args, parse_kwargs, merge_dicts

__all__ = [
'cosine',
'exponential',
'gaussian',
]

window_common_args = merge_dicts(
parse_kwargs(
"""
M (int): the length of the window.
In other words, the number of points of the returned window.
sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
"""
),
factory_common_args,
{"normalization": "The window is normalized to 1 (maximum value is 1). However, the 1 doesn't appear if "
"`M` is even and `sym` is `True`."}
)


def _add_docstr(*args):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just reuse the existing _add_docstr?

Copy link
Contributor Author

@alvgaona alvgaona Oct 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conventional _add_docstr is under the torch._C namespace, which I reckon doesn't work for Pure Python functions. I tried it before and I get this error (trying to follow the _torch_docs.py way of adding docstrings).

Input:

def dummy(a):
    print(a)


_add_docstr(
    dummy,
    r"""
    Ha1
    """
)

Output:

TypeError: don't know how to add docstring to type 'function'

r"""Adds docstrings to a given decorated function.

Specially useful when then docstrings needs string interpolation, e.g., with
str.format().
REMARK: Do not use this function if the docstring doesn't need string
interpolation, just write a conventional docstring.

Args:
args (str):
"""

def decorator(o):
o.__doc__ = "".join(args)
return o

return decorator


def _window_function_checks(function_name: str, M: int, dtype: torch.dtype, layout: torch.layout) -> None:
r"""Performs common checks for all the defined windows.
This function should be called before computing any window.

Args:
function_name (str): name of the window function.
M (int): length of the window.
dtype (:class:`torch.dtype`): the desired data type of returned tensor.
layout (:class:`torch.layout`): the desired layout of returned tensor.
"""
if M < 0:
raise ValueError(f'{function_name} requires non-negative window length, got M={M}')
if layout is not torch.strided:
raise ValueError(f'{function_name} is implemented for strided tensors only, got: {layout}')
if not is_float_dtype(dtype):
raise ValueError(f'{function_name} expects floating point dtypes, got: {dtype}')


@_add_docstr(
r"""
Computes a window with an exponential waveform.
Also known as Poisson window.

The exponential window is defined as follows:

.. math::
w(n) = \exp{\left(-\frac{|n - c|}{\tau}\right)}

where `c` is the center of the window.
""",
r"""

{normalization}

Args:
{M}

Keyword args:
center (float, optional): where the center of the window will be located.
Default: `M / 2` if `sym` is `False`, else `(M - 1) / 2`.
tau (float, optional): the decay value.
Tau is generally associated with a percentage, that means, that the value should
vary within the interval (0, 100]. If tau is 100, it is considered the uniform window.
Default: 1.0.
{sym}
{dtype}
{layout}
{device}
{requires_grad}

Examples::

>>> # Generate a symmetric exponential window of size 10 and with a decay value of 1.0.
>>> # The center will be at (M - 1) / 2, where M is 10.
>>> torch.signal.windows.exponential(10)
tensor([0.0111, 0.0302, 0.0821, 0.2231, 0.6065, 0.6065, 0.2231, 0.0821, 0.0302, 0.0111])

>>> # Generate a periodic exponential window and decay factor equal to .5
>>> torch.signal.windows.exponential(10,sym=False,tau=.5)
tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
""".format(
**window_common_args
),
)
def exponential(
M: int,
*,
center: Optional[float] = None,
tau: float = 1.0,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
requires_grad: bool = False
) -> Tensor:
if dtype is None:
dtype = torch.get_default_dtype()

_window_function_checks('exponential', M, dtype, layout)

if M == 0:
return torch.empty((0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)

if tau <= 0:
raise ValueError(f'Tau must be positive, got: {tau} instead.')

if sym and center is not None:
raise ValueError('Center must be None for symmetric windows')

if center is None:
center = (M if not sym and M > 1 else M - 1) / 2.0

constant = 1 / tau

"""
Note that non-integer step is subject to floating point rounding errors when comparing against end;
thus, to avoid inconsistency, we added an epsilon equal to `step / 2` to `end`.
"""
k = torch.linspace(start=-center * constant,
end=(-center + (M - 1)) * constant,
steps=M,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad)

return torch.exp(-torch.abs(k))


@_add_docstr(
r"""
Computes a window with a simple cosine waveform.
Also known as the sine window.

The cosine window is defined as follows:

.. math::
w(n) = \cos{\left(\frac{\pi n}{M} - \frac{\pi}{2}\right)} = \sin{\left(\frac{\pi n}{M}\right)}
""",
r"""

{normalization}

Args:
{M}

Keyword args:
{sym}
{dtype}
{layout}
{device}
{requires_grad}

Examples::

>>> # Generate a symmetric cosine window.
>>> torch.signal.windows.cosine(10)
tensor([0.1564, 0.4540, 0.7071, 0.8910, 0.9877, 0.9877, 0.8910, 0.7071, 0.4540, 0.1564])

>>> # Generate a periodic cosine window.
>>> torch.signal.windows.cosine(10,sym=False)
tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549, 0.4154])
""".format(
**window_common_args,
),
)
def cosine(
M: int,
*,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
requires_grad: bool = False
) -> Tensor:
if dtype is None:
dtype = torch.get_default_dtype()

_window_function_checks('cosine', M, dtype, layout)

if M == 0:
return torch.empty((0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)

start = 0.5
constant = torch.pi / (M + 1 if not sym and M > 1 else M)

k = torch.linspace(start=start * constant,
end=(start + (M - 1)) * constant,
steps=M,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad)

return torch.sin(k)


@_add_docstr(
r"""
Computes a window with a gaussian waveform.

The gaussian window is defined as follows:

.. math::
w(n) = \exp{\left(-\left(\frac{n}{2\sigma}\right)^2\right)}
""",
r"""

{normalization}

Args:
{M}

Keyword args:
std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is.
Default: 1.0.
{sym}
{dtype}
{layout}
{device}
{requires_grad}

Examples::

>>> # Generate a symmetric gaussian window with a standard deviation of 1.0.
>>> torch.signal.windows.gaussian(10)
tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05])

>>> # Generate a periodic gaussian window and standard deviation equal to 0.9.
>>> torch.signal.windows.gaussian(10,sym=False,std=0.9)
tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
""".format(
**window_common_args,
),
)
def gaussian(
M: int,
*,
std: float = 1.0,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
requires_grad: bool = False
) -> Tensor:
if dtype is None:
dtype = torch.get_default_dtype()

_window_function_checks('gaussian', M, dtype, layout)

if M == 0:
return torch.empty((0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)

if std <= 0:
raise ValueError(f'Standard deviation must be positive, got: {std} instead.')

start = -(M if not sym and M > 1 else M - 1) / 2.0

constant = 1 / (std * sqrt(2))

k = torch.linspace(start=start * constant,
end=(start + (M - 1)) * constant,
steps=M,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad)

return torch.exp(-k ** 2)
9 changes: 8 additions & 1 deletion torch/testing/_internal/opinfo/definitions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from typing import List

from torch.testing._internal.opinfo.core import OpInfo
from torch.testing._internal.opinfo.definitions import _masked, fft, linalg, special
from torch.testing._internal.opinfo.definitions import (
_masked,
fft,
linalg,
signal,
special,
)

# Operator database
op_db: List[OpInfo] = [
*fft.op_db,
*linalg.op_db,
*signal.op_db,
*special.op_db,
*_masked.op_db,
]
Expand Down
Loading