-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Implement torch.i0 #43132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Implement torch.i0 #43132
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
13560f9
implement bessel function
muthuArivoli e922c86
implement kaiser
muthuArivoli 8029d52
add tests for i0
muthuArivoli 60eb59d
bound tests
muthuArivoli 6615770
add docs
muthuArivoli df513f2
autograd
muthuArivoli 1601252
attempt to fix cuda
muthuArivoli bbca196
use c10 cuda compat
muthuArivoli 95d92df
try with function implementation in cuda
muthuArivoli d6aeb9f
fix casting
muthuArivoli b68bdb6
Merge remote-tracking branch 'upstream/master' into implement-kaiser
muthuArivoli e49d160
use doubles and test against scipy
muthuArivoli b54a701
fix casting in cuda
muthuArivoli 4c52a39
fix tests
muthuArivoli 3586f98
fix tests 2
muthuArivoli 19839fc
fix bfloat16
muthuArivoli 852ae9d
fix float16 test
muthuArivoli d055a65
remove kaiser window
muthuArivoli 30a6a9b
template for other dtypes
muthuArivoli a43cfba
Merge remote-tracking branch 'upstream/master' into implement-kaiser
muthuArivoli 0b20f1a
fix comments and docs
muthuArivoli 806760d
add ranged tests
muthuArivoli 2043ace
use precision override
muthuArivoli e2cf848
licensing and documentation updates
muthuArivoli 008248d
compute in given type
muthuArivoli fde8113
update tests
muthuArivoli 243209b
Revert "compute in given type"
muthuArivoli 8ae978a
compute in float for bfloat
muthuArivoli 23c3c36
Merge remote-tracking branch 'upstream/master' into implement-kaiser
muthuArivoli 8c7e095
Merge remote-tracking branch 'upstream/master' into implement-kaiser
muthuArivoli 1bfd21d
review fixes
muthuArivoli 3b5c0f3
updates
muthuArivoli 087c13e
fix comments
muthuArivoli File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -114,13 +114,40 @@ Date: February 1996 | |
| #undef CENTRAL_RANGE | ||
|
|
||
| /* | ||
| * The following function comes with the following copyright notice. | ||
| * It has been released under the BSD license. | ||
| * Note [3-Clause BSD License for the Cephes Math Library] | ||
| * Code derived from implementations in the Cephes Math Library should mention its derivation and reference | ||
| * this note (ex. 'This function is derived from the implementation of X in the Cephes Math Library. See note | ||
| * [3-Clause BSD License for the Cephes Math Library]. The license is: | ||
| * Copyright (c) 2018, Steven Moshier | ||
| * All rights reserved. | ||
| * | ||
| * Cephes Math Library Release 2.8: June, 2000 | ||
| * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier | ||
| * Redistribution and use in source and binary forms, with or without | ||
| * modification, are permitted provided that the following conditions are met: | ||
| * * Redistributions of source code must retain the above copyright | ||
| * notice, this list of conditions and the following disclaimer. | ||
| * * Redistributions in binary form must reproduce the above copyright | ||
| * notice, this list of conditions and the following disclaimer in the | ||
| * documentation and/or other materials provided with the distribution. | ||
| * * Neither the name of the nor the | ||
| * names of its contributors may be used to endorse or promote products | ||
| * derived from this software without specific prior written permission. | ||
| * | ||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | ||
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | ||
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
| * DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY | ||
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
| */ | ||
|
|
||
| /* | ||
| * This function is derived from the implementation of the zeta function in the Cephes Math Library. | ||
| * See note [3-Clause BSD License for the Cephes Math Library]. | ||
| */ | ||
| static inline double zeta(double x, double q) { | ||
| static double MACHEP = 1.11022302462515654042E-16; | ||
| static double A[] = { | ||
|
|
@@ -244,14 +271,11 @@ static inline float trigamma(float x) { | |
| result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x; | ||
| return sign * result; | ||
| } | ||
|
|
||
| /* | ||
| * The following function comes with the following copyright notice. | ||
| * It has been released under the BSD license. | ||
| * | ||
| * Cephes Math Library Release 2.8: June, 2000 | ||
| * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier | ||
| * This function is derived from the implementation of the digamma function in the Cephes Math Library. | ||
| * See note [3-Clause BSD License for the Cephes Math Library]. | ||
| */ | ||
|
|
||
| static inline double calc_digamma(double x) { | ||
| static double PSI_10 = 2.25175258906672110764; | ||
| if (x == 0) { | ||
|
|
@@ -296,11 +320,8 @@ static inline double calc_digamma(double x) { | |
| } | ||
|
|
||
| /* | ||
| * The following function comes with the following copyright notice. | ||
| * It has been released under the BSD license. | ||
| * | ||
| * Cephes Math Library Release 2.8: June, 2000 | ||
| * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier | ||
| * This function is derived from the implementation of the digamma function in the Cephes Math Library. | ||
| * See note [3-Clause BSD License for the Cephes Math Library]. | ||
| */ | ||
| static inline float calc_digamma(float x) { | ||
| static float PSI_10 = 2.25175258906672110764f; | ||
|
|
@@ -384,3 +405,138 @@ calc_gcd(T a, T b) { | |
| } | ||
| return b; | ||
| } | ||
|
|
||
| /* | ||
| * This function is derived from the implementation of the chbevl function in the Cephes Math Library. | ||
| * See note [3-Clause BSD License for the Cephes Math Library]. | ||
| * | ||
| * Evaluates the series | ||
| * | ||
| * len-1 | ||
| * - ' | ||
| * y = > array[i] T (x/2) | ||
| * - i | ||
| * i=0 | ||
| * | ||
| * of Chebyshev polynomials Ti at argument x/2. | ||
| * | ||
| * Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note len is the number of | ||
| * coefficients, not the order. | ||
| * | ||
| * If coefficients are for the interval a to b, x must have been transformed to x -> 2(2x - b - a)/(b-a) before | ||
| * entering the routine. This maps x from (a, b) to (-1, 1), over which the Chebyshev polynomials are defined. | ||
| * | ||
| * If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, 1/a), the transformation | ||
| * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1. | ||
| */ | ||
| template <typename T> | ||
| static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type | ||
muthuArivoli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| chbevl(T x, T array[], size_t len) { | ||
| T b0, b1, b2; | ||
|
|
||
| b0 = array[0]; | ||
| b1 = static_cast<T>(0.0); | ||
|
|
||
| for (size_t i = 1; i < len; ++i) { | ||
| b2 = b1; | ||
| b1 = b0; | ||
| b0 = x * b1 - b2 + array[i]; | ||
| } | ||
|
|
||
| return (static_cast<T>(0.5) * (b0 - b2)); | ||
| } | ||
|
|
||
| /* | ||
| * This function is derived from the implementation of the i0 function in the Cephes Math Library. | ||
| * See note [3-Clause BSD License for the Cephes Math Library]. | ||
| * | ||
| * Computes an approximation of the zeroth order modified Bessel function of the first kind. | ||
| * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. | ||
| * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value | ||
| * of all inputs to convert them into the domain of the approximation. | ||
| */ | ||
| template <typename T> | ||
| static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type | ||
| calc_i0(T _x) { | ||
muthuArivoli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| T x = std::abs(_x); | ||
| /* Chebyshev coefficients for exp(-x) I0(x) | ||
| * in the interval [0,8]. | ||
| * | ||
| * lim(x->0){ exp(-x) I0(x) } = 1. | ||
| */ | ||
| static T A[] = { | ||
| -4.41534164647933937950E-18, | ||
| 3.33079451882223809783E-17, | ||
| -2.43127984654795469359E-16, | ||
| 1.71539128555513303061E-15, | ||
| -1.16853328779934516808E-14, | ||
| 7.67618549860493561688E-14, | ||
| -4.85644678311192946090E-13, | ||
| 2.95505266312963983461E-12, | ||
| -1.72682629144155570723E-11, | ||
| 9.67580903537323691224E-11, | ||
| -5.18979560163526290666E-10, | ||
| 2.65982372468238665035E-9, | ||
| -1.30002500998624804212E-8, | ||
| 6.04699502254191894932E-8, | ||
| -2.67079385394061173391E-7, | ||
| 1.11738753912010371815E-6, | ||
| -4.41673835845875056359E-6, | ||
| 1.64484480707288970893E-5, | ||
| -5.75419501008210370398E-5, | ||
| 1.88502885095841655729E-4, | ||
| -5.76375574538582365885E-4, | ||
| 1.63947561694133579842E-3, | ||
| -4.32430999505057594430E-3, | ||
| 1.05464603945949983183E-2, | ||
| -2.37374148058994688156E-2, | ||
| 4.93052842396707084878E-2, | ||
| -9.49010970480476444210E-2, | ||
| 1.71620901522208775349E-1, | ||
| -3.04682672343198398683E-1, | ||
| 6.76795274409476084995E-1 | ||
| }; | ||
|
|
||
| /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) | ||
| * in the inverted interval [8,infinity]. | ||
| * | ||
| * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). | ||
| */ | ||
| static T B[] = { | ||
| -7.23318048787475395456E-18, | ||
| -4.83050448594418207126E-18, | ||
| 4.46562142029675999901E-17, | ||
| 3.46122286769746109310E-17, | ||
| -2.82762398051658348494E-16, | ||
| -3.42548561967721913462E-16, | ||
| 1.77256013305652638360E-15, | ||
| 3.81168066935262242075E-15, | ||
| -9.55484669882830764870E-15, | ||
| -4.15056934728722208663E-14, | ||
| 1.54008621752140982691E-14, | ||
| 3.85277838274214270114E-13, | ||
| 7.18012445138366623367E-13, | ||
| -1.79417853150680611778E-12, | ||
| -1.32158118404477131188E-11, | ||
| -3.14991652796324136454E-11, | ||
| 1.18891471078464383424E-11, | ||
| 4.94060238822496958910E-10, | ||
| 3.39623202570838634515E-9, | ||
| 2.26666899049817806459E-8, | ||
| 2.04891858946906374183E-7, | ||
| 2.89137052083475648297E-6, | ||
| 6.88975834691682398426E-5, | ||
| 3.36911647825569408990E-3, | ||
| 8.04490411014108831608E-1 | ||
| }; | ||
|
|
||
| if (x <= 8.0) { | ||
| T y = (x / 2.0) - 2.0; | ||
| return static_cast<T>(std::exp(x) * chbevl(y, A, 30)); | ||
| } | ||
|
|
||
| return static_cast<T>(std::exp(x) * chbevl(static_cast<T>(32.0 / x - 2.0), B, 25) / std::sqrt(x)); | ||
| } | ||
|
|
||
| // Upcast bfloat16 input to float for numerical accuracy purposes | ||
| inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); } | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add a comment explaining the upcasting behavior (and our reasoning for it) |
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.