Skip to content

Commit f73bb16

Browse files
fmassasoumith
authored andcommitted
Optimize pow for different exponents and add tests
1 parent 76a6529 commit f73bb16

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

lib/TH/generic/THTensorMath.c

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2856,13 +2856,6 @@ TENSOR_IMPLEMENT_LOGICAL(ne,!=)
28562856
TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data);); \
28572857
} \
28582858

2859-
#define LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(NAME, CFUNC) \
2860-
void THTensor_(NAME)(THTensor *r_, THTensor *t, real value) \
2861-
{ \
2862-
THTensor_(resizeAs)(r_, t); \
2863-
TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data, value);); \
2864-
} \
2865-
28662859
#if defined(TH_REAL_IS_LONG)
28672860
LAB_IMPLEMENT_BASIC_FUNCTION(abs,labs)
28682861
LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
@@ -2912,7 +2905,6 @@ LAB_IMPLEMENT_BASIC_FUNCTION(sinh,TH_MATH_NAME(sinh))
29122905
LAB_IMPLEMENT_BASIC_FUNCTION(tan,TH_MATH_NAME(tan))
29132906
LAB_IMPLEMENT_BASIC_FUNCTION(atan,TH_MATH_NAME(atan))
29142907
LAB_IMPLEMENT_BASIC_FUNCTION(tanh,TH_MATH_NAME(tanh))
2915-
LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(pow,TH_MATH_NAME(pow))
29162908
LAB_IMPLEMENT_BASIC_FUNCTION(sqrt,TH_MATH_NAME(sqrt))
29172909
LAB_IMPLEMENT_BASIC_FUNCTION(rsqrt,TH_MATH_NAME(TH_rsqrt))
29182910
LAB_IMPLEMENT_BASIC_FUNCTION(ceil,TH_MATH_NAME(ceil))
@@ -2925,6 +2917,35 @@ LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
29252917
LAB_IMPLEMENT_BASIC_FUNCTION(cinv, TH_MATH_NAME(1.0) / )
29262918

29272919

2920+
void THTensor_(pow)(THTensor *r_, THTensor *t, real value)
2921+
{
2922+
THTensor_(resizeAs)(r_, t);
2923+
if(value == 1){
2924+
THTensor_(copy)(r_, t);
2925+
}
2926+
else if(value == 2){
2927+
THTensor_(cmul)(r_, t, t);
2928+
}
2929+
else if(value == 3){
2930+
TH_TENSOR_APPLY2(real, t, real, r_, *r__data = *t_data * *t_data * *t_data;);
2931+
}
2932+
else if(value == 0.5){
2933+
THTensor_(sqrt)(r_, t);
2934+
}
2935+
else if(value == -0.5){
2936+
THTensor_(rsqrt)(r_, t);
2937+
}
2938+
else if(value == -1){
2939+
THTensor_(cinv)(r_, t);
2940+
}
2941+
else if(value == -2){
2942+
TH_TENSOR_APPLY2(real, t, real, r_, *r__data = TH_MATH_NAME(1.0) / (*t_data * *t_data););
2943+
}
2944+
else{
2945+
TH_TENSOR_APPLY2(real, t, real, r_, *r__data = TH_MATH_NAME(pow)(*t_data, value););
2946+
}
2947+
}
2948+
29282949
void THTensor_(atan2)(THTensor *r_, THTensor *tx, THTensor *ty)
29292950
{
29302951
THTensor_(resizeAs)(r_, tx);

0 commit comments

Comments
 (0)