Skip to content

Commit 97c8e17

Browse files
vsx: complex types use std::sqrt
1 parent 8e1bc4d commit 97c8e17

File tree

2 files changed

+2
-23
lines changed

2 files changed

+2
-23
lines changed

aten/src/ATen/cpu/vec256/vsx/vec256_complex_double_vsx.h

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,17 +364,7 @@ class Vec256<ComplexDbl> {
364364
}
365365

366366
Vec256<ComplexDbl> sqrt() const {
367-
// sqrt(a + bi)
368-
// = sqrt(2)/2 * [sqrt(sqrt(a**2 + b**2) + a) + sgn(b)*sqrt(sqrt(a**2 +
369-
// b**2) - a)i] = sqrt(2)/2 * [sqrt(abs() + a) + sgn(b)*sqrt(abs() - a)i]
370-
371-
auto sign = *this & vd_isign_mask;
372-
auto factor = sign | vd_sqrt2_2;
373-
auto a_a = el_mergee();
374-
// a_a.dump();
375-
a_a = a_a ^ vd_isign_mask; // a -a
376-
auto res_re_im = (abs_() + a_a).elwise_sqrt(); // sqrt(abs + a) sqrt(abs - a)
377-
return factor.elwise_mult(res_re_im);
367+
return map(std::sqrt);
378368
}
379369

380370
Vec256<ComplexDbl> reciprocal() const {

aten/src/ATen/cpu/vec256/vsx/vec256_complex_float_vsx.h

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -417,18 +417,7 @@ class Vec256<ComplexFlt> {
417417
}
418418

419419
Vec256<ComplexFlt> sqrt() const {
420-
// sqrt(a + bi)
421-
// = sqrt(2)/2 * [sqrt(sqrt(a**2 + b**2) + a) + sgn(b)*sqrt(sqrt(a**2 +
422-
// b**2) - a)i] = sqrt(2)/2 * [sqrt(abs() + a) + sgn(b)*sqrt(abs() - a)i]
423-
424-
auto sign = *this & isign_mask;
425-
auto factor = sign | sqrt2_2;
426-
auto a_a = el_mergee();
427-
// a_a.dump();
428-
a_a = a_a ^ isign_mask; // a -a
429-
auto res_re_im =
430-
(abs_() + a_a).elwise_sqrt(); // sqrt(abs + a) sqrt(abs - a)
431-
return factor.elwise_mult(res_re_im);
420+
return map(std::sqrt);
432421
}
433422

434423
Vec256<ComplexFlt> reciprocal() const {

0 commit comments

Comments
 (0)