Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
13560f9
implement bessel function
muthuArivoli Aug 16, 2020
e922c86
implement kaiser
muthuArivoli Aug 16, 2020
8029d52
add tests for i0
muthuArivoli Aug 16, 2020
60eb59d
bound tests
muthuArivoli Aug 16, 2020
6615770
add docs
muthuArivoli Aug 16, 2020
df513f2
autograd
muthuArivoli Aug 16, 2020
1601252
attempt to fix cuda
muthuArivoli Aug 17, 2020
bbca196
use c10 cuda compat
muthuArivoli Aug 17, 2020
95d92df
try with function implementation in cuda
muthuArivoli Aug 18, 2020
d6aeb9f
fix casting
muthuArivoli Aug 18, 2020
b68bdb6
Merge remote-tracking branch 'upstream/master' into implement-kaiser
muthuArivoli Aug 18, 2020
e49d160
use doubles and test against scipy
muthuArivoli Aug 19, 2020
b54a701
fix casting in cuda
muthuArivoli Aug 20, 2020
4c52a39
fix tests
muthuArivoli Aug 20, 2020
3586f98
fix tests 2
muthuArivoli Aug 20, 2020
19839fc
fix bfloat16
muthuArivoli Aug 21, 2020
852ae9d
fix float16 test
muthuArivoli Aug 21, 2020
d055a65
remove kaiser window
muthuArivoli Aug 27, 2020
30a6a9b
template for other dtypes
muthuArivoli Aug 28, 2020
a43cfba
Merge remote-tracking branch 'upstream/master' into implement-kaiser
muthuArivoli Aug 28, 2020
0b20f1a
fix comments and docs
muthuArivoli Aug 28, 2020
806760d
add ranged tests
muthuArivoli Aug 28, 2020
2043ace
use precision override
muthuArivoli Aug 28, 2020
e2cf848
licensing and documentation updates
muthuArivoli Aug 31, 2020
008248d
compute in given type
muthuArivoli Aug 31, 2020
fde8113
update tests
muthuArivoli Aug 31, 2020
243209b
Revert "compute in given type"
muthuArivoli Sep 1, 2020
8ae978a
compute in float for bfloat
muthuArivoli Sep 1, 2020
23c3c36
Merge remote-tracking branch 'upstream/master' into implement-kaiser
muthuArivoli Sep 1, 2020
8c7e095
Merge remote-tracking branch 'upstream/master' into implement-kaiser
muthuArivoli Sep 1, 2020
1bfd21d
review fixes
muthuArivoli Sep 1, 2020
3b5c0f3
updates
muthuArivoli Sep 2, 2020
087c13e
fix comments
muthuArivoli Sep 2, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aten/src/ATen/core/NamedRegistrations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("hypot", CppFunction::makeFallthrough());
m.impl("hypot.out", CppFunction::makeFallthrough());
m.impl("hypot_", CppFunction::makeFallthrough());
m.impl("i0", CppFunction::makeFallthrough());
m.impl("i0.out", CppFunction::makeFallthrough());
m.impl("i0_", CppFunction::makeFallthrough());
m.impl("imag", CppFunction::makeFallthrough());
m.impl("index_fill.Dimname_Scalar", CppFunction::makeFallthrough());
m.impl("index_fill.Dimname_Tensor", CppFunction::makeFallthrough());
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/core/aten_interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ _(aten, histc) \
_(aten, hspmm) \
_(aten, hstack) \
_(aten, hypot) \
_(aten, i0) \
_(aten, i0_) \
_(aten, ifft) \
_(aten, index) \
_(aten, index_add) \
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,9 @@ struct Vec256 {
}
return ret;
}
Vec256<T> i0() const {
return map(calc_i0);
}
Vec256<T> neg() const {
// NB: the trailing return type is needed because we need to coerce the
// return value back to T in the case of unary operator- incuring a
Expand Down
14 changes: 14 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,20 @@ template <> class Vec256<BFloat16> {
auto o2 = Sleef_hypotf8_u05(hi, b2);
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> i0() const {
__m256 lo, hi;
cvtbf16_fp32(values, lo, hi);
__at_align32__ float tmp1[size() / 2], tmp2[size() / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (int64_t i = 0; i < size() / 2; i++) {
tmp1[i] = calc_i0(tmp1[i]);
tmp2[i] = calc_i0(tmp2[i]);
}
auto o1 = _mm256_loadu_ps(tmp1);
auto o2 = _mm256_loadu_ps(tmp2);
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> log() const {
return map(Sleef_logf8_u10);
}
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_double.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ template <> class Vec256<double> {
Vec256<double> hypot(const Vec256<double> &b) const {
return Vec256<double>(Sleef_hypotd4_u05(values, b));
}
Vec256<double> i0() const {
return map(calc_i0);
}
Vec256<double> log() const {
return Vec256<double>(Sleef_logd4_u10(values));
}
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ template <> class Vec256<float> {
Vec256<float> hypot(const Vec256<float> &b) const {
return Vec256<float>(Sleef_hypotf8_u05(values, b));
}
Vec256<float> i0() const {
return map(calc_i0);
}
Vec256<float> neg() const {
return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
}
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_float_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ template <> class Vec256<float> {
}
return loadu(tmp);
}
Vec256<float> i0() const {
return map(calc_i0);
}
Vec256<float> log() const {
return map(std::log);
}
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/cpu/vml.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ IMPLEMENT_VML(erfinv)
IMPLEMENT_VML_BUG(exp)
IMPLEMENT_VML_BUG(expm1)
IMPLEMENT_VML_BUG(floor)
IMPLEMENT_VML(i0)
IMPLEMENT_VML(reciprocal)
IMPLEMENT_VML_BUG(log)
IMPLEMENT_VML_BUG(log10)
Expand Down
7 changes: 2 additions & 5 deletions aten/src/ATen/native/Distributions.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,8 @@ C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<a
}

/*
* The following function comes with the following copyright notice.
* It has been released under the BSD license.
*
* Cephes Math Library Release 2.8: June, 2000
* Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
*/
template<typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t digamma_one(scalar_t x) {
Expand Down
186 changes: 171 additions & 15 deletions aten/src/ATen/native/Math.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,40 @@ Date: February 1996
#undef CENTRAL_RANGE

/*
* The following function comes with the following copyright notice.
* It has been released under the BSD license.
* Note [3-Clause BSD License for the Cephes Math Library]
* Code derived from implementations in the Cephes Math Library should mention its derivation and reference
* this note (ex. 'This function is derived from the implementation of X in the Cephes Math Library. See note
* [3-Clause BSD License for the Cephes Math Library]. The license is:
* Copyright (c) 2018, Steven Moshier
* All rights reserved.
*
* Cephes Math Library Release 2.8: June, 2000
* Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/*
* This function is derived from the implementation of the zeta function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*/
static inline double zeta(double x, double q) {
static double MACHEP = 1.11022302462515654042E-16;
static double A[] = {
Expand Down Expand Up @@ -244,14 +271,11 @@ static inline float trigamma(float x) {
result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x;
return sign * result;
}

/*
* The following function comes with the following copyright notice.
* It has been released under the BSD license.
*
* Cephes Math Library Release 2.8: June, 2000
* Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*/

static inline double calc_digamma(double x) {
static double PSI_10 = 2.25175258906672110764;
if (x == 0) {
Expand Down Expand Up @@ -296,11 +320,8 @@ static inline double calc_digamma(double x) {
}

/*
* The following function comes with the following copyright notice.
* It has been released under the BSD license.
*
* Cephes Math Library Release 2.8: June, 2000
* Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*/
static inline float calc_digamma(float x) {
static float PSI_10 = 2.25175258906672110764f;
Expand Down Expand Up @@ -384,3 +405,138 @@ calc_gcd(T a, T b) {
}
return b;
}

/*
* This function is derived from the implementation of the chbevl function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*
* Evaluates the series
*
* len-1
* - '
* y = > array[i] T (x/2)
* - i
* i=0
*
* of Chebyshev polynomials Ti at argument x/2.
*
* Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note len is the number of
* coefficients, not the order.
*
* If coefficients are for the interval a to b, x must have been transformed to x -> 2(2x - b - a)/(b-a) before
* entering the routine. This maps x from (a, b) to (-1, 1), over which the Chebyshev polynomials are defined.
*
* If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, 1/a), the transformation
* required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1.
*/
template <typename T>
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
chbevl(T x, T array[], size_t len) {
T b0, b1, b2;

b0 = array[0];
b1 = static_cast<T>(0.0);

for (size_t i = 1; i < len; ++i) {
b2 = b1;
b1 = b0;
b0 = x * b1 - b2 + array[i];
}

return (static_cast<T>(0.5) * (b0 - b2));
}

/*
* This function is derived from the implementation of the i0 function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*
* Computes an approximation of the zeroth order modified Bessel function of the first kind.
* The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion.
* One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value
* of all inputs to convert them into the domain of the approximation.
*/
template <typename T>
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
calc_i0(T _x) {
T x = std::abs(_x);
/* Chebyshev coefficients for exp(-x) I0(x)
* in the interval [0,8].
*
* lim(x->0){ exp(-x) I0(x) } = 1.
*/
static T A[] = {
-4.41534164647933937950E-18,
3.33079451882223809783E-17,
-2.43127984654795469359E-16,
1.71539128555513303061E-15,
-1.16853328779934516808E-14,
7.67618549860493561688E-14,
-4.85644678311192946090E-13,
2.95505266312963983461E-12,
-1.72682629144155570723E-11,
9.67580903537323691224E-11,
-5.18979560163526290666E-10,
2.65982372468238665035E-9,
-1.30002500998624804212E-8,
6.04699502254191894932E-8,
-2.67079385394061173391E-7,
1.11738753912010371815E-6,
-4.41673835845875056359E-6,
1.64484480707288970893E-5,
-5.75419501008210370398E-5,
1.88502885095841655729E-4,
-5.76375574538582365885E-4,
1.63947561694133579842E-3,
-4.32430999505057594430E-3,
1.05464603945949983183E-2,
-2.37374148058994688156E-2,
4.93052842396707084878E-2,
-9.49010970480476444210E-2,
1.71620901522208775349E-1,
-3.04682672343198398683E-1,
6.76795274409476084995E-1
};

/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
* in the inverted interval [8,infinity].
*
* lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
*/
static T B[] = {
-7.23318048787475395456E-18,
-4.83050448594418207126E-18,
4.46562142029675999901E-17,
3.46122286769746109310E-17,
-2.82762398051658348494E-16,
-3.42548561967721913462E-16,
1.77256013305652638360E-15,
3.81168066935262242075E-15,
-9.55484669882830764870E-15,
-4.15056934728722208663E-14,
1.54008621752140982691E-14,
3.85277838274214270114E-13,
7.18012445138366623367E-13,
-1.79417853150680611778E-12,
-1.32158118404477131188E-11,
-3.14991652796324136454E-11,
1.18891471078464383424E-11,
4.94060238822496958910E-10,
3.39623202570838634515E-9,
2.26666899049817806459E-8,
2.04891858946906374183E-7,
2.89137052083475648297E-6,
6.88975834691682398426E-5,
3.36911647825569408990E-3,
8.04490411014108831608E-1
};

if (x <= 8.0) {
T y = (x / 2.0) - 2.0;
return static_cast<T>(std::exp(x) * chbevl(y, A, 30));
}

return static_cast<T>(std::exp(x) * chbevl(static_cast<T>(32.0 / x - 2.0), B, 25) / std::sqrt(x));
}

// Upcast bfloat16 input to float for numerical accuracy purposes
inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
Copy link
Collaborator

@mruberry mruberry Sep 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a comment explaining the upcasting behavior (and our reasoning for it)

5 changes: 5 additions & 0 deletions aten/src/ATen/native/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ Tensor& floor_out(Tensor& result, const Tensor& self) {
Tensor floor(const Tensor& self) { return unary_op_impl(self, at::floor_out); }
Tensor& floor_(Tensor& self) { return unary_op_impl_(self, at::floor_out); }

Tensor& i0_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, i0_stub); }
Tensor i0(const Tensor& self) { return unary_op_impl(self, at::i0_out); }
Tensor& i0_(Tensor& self) { return unary_op_impl_(self, at::i0_out); }

Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log_stub); }
Tensor log(const Tensor& self) { return unary_op_impl(self, at::log_out); }
Tensor& log_(Tensor& self) { return unary_op_impl_(self, at::log_out); }
Expand Down Expand Up @@ -607,6 +611,7 @@ DEFINE_DISPATCH(exp_stub);
DEFINE_DISPATCH(expm1_stub);
DEFINE_DISPATCH(floor_stub);
DEFINE_DISPATCH(frac_stub);
DEFINE_DISPATCH(i0_stub);
DEFINE_DISPATCH(log_stub);
DEFINE_DISPATCH(log10_stub);
DEFINE_DISPATCH(log1p_stub);
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/native/UnaryOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ DECLARE_DISPATCH(unary_fn, exp_stub);
DECLARE_DISPATCH(unary_fn, expm1_stub);
DECLARE_DISPATCH(unary_fn, floor_stub);
DECLARE_DISPATCH(unary_fn, frac_stub);
DECLARE_DISPATCH(unary_fn, i0_stub);
DECLARE_DISPATCH(unary_fn, log_stub);
DECLARE_DISPATCH(unary_fn, log10_stub);
DECLARE_DISPATCH(unary_fn, log1p_stub);
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ IMPLEMENT_COMPLEX_KERNEL(FLOATING, log)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, log10)
IMPLEMENT_FLOAT_KERNEL(FLOATING, log1p)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, log2)
IMPLEMENT_FLOAT_KERNEL(FLOATING, i0)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, round)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, sin)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, sqrt)
Expand Down
Loading