[jiterator] reduce kernel code duplication #73908
💊 CI failures summary and remediations
As of commit 07dc80a (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚
This comment was automatically generated by Dr. CI.
| @@ -0,0 +1,33 @@ | |||
| #pragma once | |||
Moved to a new file as jit_macros.h includes CUDAConfig.h which is only available in CUDA build.
lezcano
left a comment
Nice to see the direction this is taking! Now, I have one question really. I don't see the equivalent of some bits of code that were removed in this PR. Where did those go?
I also added a small performance nit that could be fixed in this PR or a separate PR.
aten/src/ATen/native/Math.h
Outdated
| T x = fabs(_x); | ||
|
|
||
| if (x <= T{8.0}) { | ||
| T coefficients[] = { |
Nit unrelated to this PR: make this `constexpr` or, at the very least, `static` (provided CUDA allows it).
Makes sense. Thanks!
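For illustration, the nit above could look roughly like the following. This is a host-only sketch with a made-up function name (`horner_example`) and made-up coefficients, not the actual table in Math.h; whether the same declaration is accepted in CUDA device code is exactly the open question raised in the comment.

```cpp
// Hypothetical example: a small coefficient table declared once as
// `static constexpr` instead of being rebuilt on every call.
template <typename T>
T horner_example(T x) {
  // Made-up coefficients for 1*x^2 + 2*x + 3, highest order first.
  static constexpr T coefficients[] = {T(1), T(2), T(3)};
  T acc = T(0);
  for (T c : coefficients) {
    acc = acc * x + c;  // Horner's rule
  }
  return acc;
}
```

With a plain `T coefficients[] = {...}` the array is a local that is initialized on each call; `static constexpr` makes it a single compile-time constant shared across calls.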
| /* | ||
| * This function is derived from the implementation of the i0e function in the Cephes Math Library. | ||
| * See note [3-Clause BSD License for the Cephes Math Library]. | ||
| * | ||
| * Computes an approximation of the exponentially scaled zeroth order modified Bessel function of the first kind. | ||
| * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. | ||
| * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value | ||
| * of all inputs to convert them into the domain of the approximation. | ||
| */ | ||
| template <typename T> | ||
| static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type | ||
| calc_i0e(T _x) { | ||
| T x = std::abs(_x); | ||
|
|
||
| if (x <= T{8.0}) { | ||
| auto coeff_pair = chebyshev_coefficients_i0e_A<T>(); | ||
| auto A = std::get<0>(coeff_pair); | ||
| auto len = std::get<1>(coeff_pair); | ||
| T y = (x / T{2.0}) - T{2.0}; | ||
| return chbevl(y, A, len); | ||
| } | ||
|
|
||
| auto coeff_pair = chebyshev_coefficients_i0e_B<T>(); | ||
| auto B = std::get<0>(coeff_pair); | ||
| auto len = std::get<1>(coeff_pair); | ||
| return chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x); | ||
| } | ||
|
|
||
| // Upcast bfloat16 input to float for numerical accuracy purposes | ||
| static inline c10::BFloat16 calc_i0e(c10::BFloat16 a) { return calc_i0e(static_cast<float>(a)); } |
Where did all this code go?
The jiterated version of the code has the coefficients (which we get from chebyshev_coefficients_i0e_B and chebyshev_coefficients_i0e_A).
The BFloat16 upcasting here was redundant, as it is handled elsewhere.
Existing tests in test_unary_ufuncs verify correctness against the SciPy implementation for CPU and CUDA.
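For readers following the thread, here is a minimal sketch of the kind of Chebyshev evaluation that `chbevl(y, A, len)` performs in the snippets above. It assumes the Cephes convention (coefficients stored from the highest-order term down to the constant term, with the constant term halved). The names `chbevl_sketch` and `toy_series` and the three coefficients are hypothetical and are not the real i0e tables.

```cpp
#include <cassert>
#include <cstddef>

// Evaluate a Chebyshev series with the Clenshaw recurrence.
// Assumed convention (as in Cephes' chbevl): coefficients[0] belongs to the
// highest-order term, coefficients[len - 1] to the constant term, and the
// constant term contributes with weight 1/2.
template <typename T>
T chbevl_sketch(T x, const T* coefficients, std::size_t len) {
  assert(len > 0);
  T b0 = coefficients[0];
  T b1 = T(0);
  T b2 = T(0);
  for (std::size_t i = 1; i < len; ++i) {
    b2 = b1;
    b1 = b0;
    b0 = x * b1 - b2 + coefficients[i];
  }
  return T(0.5) * (b0 - b2);
}

// Toy series with made-up coefficients, only to exercise the recurrence.
inline double toy_series(double x) {
  static const double coefficients[] = {1.0, 2.0, 3.0};
  return chbevl_sketch(x, coefficients, 3);
}
```

In `calc_i0e` the argument is first mapped into the approximation's domain (`x / 2 - 2` on [0, 8], `32 / x - 2` above 8) before a recurrence of this shape is applied.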
| template <typename scalar_t> | ||
| static inline C10_HOST_DEVICE scalar_t calc_i0e(scalar_t _x) { | ||
| static_assert(!std::is_same<scalar_t, Half>() && !std::is_same<scalar_t, BFloat16>(), "don't instantiate with low precision type"); | ||
| scalar_t x = ::abs(_x); | ||
| if (x <= scalar_t{8.0}) { | ||
| auto coeff_pair = chebyshev_coefficients_i0e_A<scalar_t>(); | ||
| auto A = std::get<0>(coeff_pair); | ||
| auto len = std::get<1>(coeff_pair); | ||
| scalar_t y = (x / scalar_t{2.0}) - scalar_t{2.0}; | ||
| return (chbevl(y, A, len)); | ||
| } | ||
|
|
||
| auto coeff_pair = chebyshev_coefficients_i0e_B<scalar_t>(); | ||
| auto B = std::get<0>(coeff_pair); | ||
| auto len = std::get<1>(coeff_pair); | ||
| return (chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x)); | ||
| } |
Same here: I don't see the equivalent of this code in the new PR.
bc72491 to
6f06a62
aten/src/ATen/native/Math.h
Outdated
| * function takes the absolute value of all inputs to convert them into the | ||
| * domain of the approximation. | ||
| */ | ||
| jiterator_code_stringify( |
In the future, I would write `jiterator_code(` on the same line as `jiterator_code_stringify` to avoid the extra level of indentation.
aten/src/ATen/jiterator_macros.h
Outdated
| #if defined(__CUDACC__) | ||
| // CPU and CUDA case | ||
| #define stringify_code(...) #__VA_ARGS__ | ||
| #define jiterator_code_stringify(code, str_name) \ |
Naming suggestion: `jiterator_also_stringify_as`.
That might make it clearer that the code is preserved?
Sure.
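To make the "code is preserved" point concrete, here is a minimal host-only sketch of the pattern under discussion: one macro both emits the code for the regular build and keeps the same tokens as a string that could be handed to a runtime compiler. The upper-case macro names, `add_scaled`, and `add_scaled_string` are hypothetical stand-ins, not the definitions in jiterator_macros.h.

```cpp
#include <string>

// Turn the argument tokens into a string literal.
#define STRINGIFY_CODE(...) #__VA_ARGS__

// Wrapper whose parentheses protect commas inside the code, so the whole
// body reaches ALSO_STRINGIFY_AS as a single macro argument.
#define CODE_BLOCK(...) __VA_ARGS__

// Emit `code` as ordinary C++ and also define `str_name` holding its text.
#define ALSO_STRINGIFY_AS(code, str_name) \
  code                                    \
  const std::string str_name = STRINGIFY_CODE(code);

// Opening CODE_BLOCK( on the same line avoids an extra indentation level.
ALSO_STRINGIFY_AS(CODE_BLOCK(
  template <typename T>
  T add_scaled(T a, T b) { return a + T(2) * b; }
), add_scaled_string)
```

After expansion, `add_scaled` is a normal function template and `add_scaled_string` contains its source text, so the kernel body is written only once.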
Looks pretty good to me -- any reason this is still in draft? cc @anjali411
Forgot to mark it as ready 😅
No worries -- just tweak the name and ping me when this is ready to merge
@mruberry I have addressed the review. Should be ready once the CI is green. Thanks :)!
@mruberry has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Ping @mruberry
Ping @mruberry
@pytorchbot merge this please
Merge failed due to PR 73908 does not match merge rules
Summary: Introduce `jiterator_code_stringify` to reduce duplication of kernel code used with jiterator.
Pull Request resolved: #73908
Reviewed By: ngimel
Differential Revision: D34858716
Pulled By: mruberry
fbshipit-source-id: f87a34e4966b31620bbc5c7d93f0387fc1980ded
|
Hey @kshitij12345. |
Introduce `jiterator_code_stringify` to reduce duplication of kernel code used with jiterator.