Skip to content

Commit c2ab94f

Browse files
committed
changes
1 parent 0885dd2 commit c2ab94f

File tree

4 files changed

+26
-6
lines changed

4 files changed

+26
-6
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,6 @@
11931193
types:
11941194
- floating_point
11951195
backends:
1196-
- CPU
11971196
- CUDA
11981197
variants: function
11991198
return: argument 0
@@ -1236,7 +1235,6 @@
12361235
types:
12371236
- floating_point
12381237
backends:
1239-
- CPU
12401238
- CUDA
12411239
variants: function
12421240
return: argument 0

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ IMPLEMENT_UNARY_OP_VEC(asin)
181181
IMPLEMENT_UNARY_OP_VEC(atan)
182182
IMPLEMENT_UNARY_OP_VEC(ceil)
183183
IMPLEMENT_UNARY_OP_VEC(cos)
184-
IMPLEMENT_UNARY_OP_TH(cosh)
184+
IMPLEMENT_UNARY_OP_VEC(cosh)
185185
IMPLEMENT_UNARY_OP_VEC(erf)
186186
IMPLEMENT_UNARY_OP_VEC(erfc)
187187
IMPLEMENT_UNARY_OP_VEC(exp)
@@ -197,7 +197,7 @@ IMPLEMENT_UNARY_OP_VEC(reciprocal)
197197
IMPLEMENT_UNARY_OP_VEC(round)
198198
IMPLEMENT_UNARY_OP_VEC(rsqrt)
199199
IMPLEMENT_UNARY_OP_VEC(sin)
200-
IMPLEMENT_UNARY_OP_TH(sinh)
200+
IMPLEMENT_UNARY_OP_VEC(sinh)
201201
IMPLEMENT_UNARY_OP_VEC(sqrt)
202202
IMPLEMENT_UNARY_OP_VEC(tan)
203203
IMPLEMENT_UNARY_OP_VEC(tanh)
@@ -209,6 +209,7 @@ DEFINE_DISPATCH(asin_stub);
209209
DEFINE_DISPATCH(atan_stub);
210210
DEFINE_DISPATCH(ceil_stub);
211211
DEFINE_DISPATCH(cos_stub);
212+
DEFINE_DISPATCH(cosh_stub); // struct cosh_stub cosh_stub
212213
DEFINE_DISPATCH(erf_stub);
213214
DEFINE_DISPATCH(erfc_stub);
214215
DEFINE_DISPATCH(exp_stub);
@@ -225,6 +226,7 @@ DEFINE_DISPATCH(round_stub);
225226
DEFINE_DISPATCH(rsqrt_stub);
226227
DEFINE_DISPATCH(sigmoid_stub);
227228
DEFINE_DISPATCH(sin_stub);
229+
DEFINE_DISPATCH(sinh_stub);
228230
DEFINE_DISPATCH(sqrt_stub);
229231
DEFINE_DISPATCH(tan_stub);
230232
DEFINE_DISPATCH(tanh_stub);

aten/src/ATen/native/UnaryOps.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ DECLARE_DISPATCH(unary_fn, asin_stub);
1919
DECLARE_DISPATCH(unary_fn, atan_stub);
2020
DECLARE_DISPATCH(unary_fn, ceil_stub);
2121
DECLARE_DISPATCH(unary_fn, cos_stub);
22-
// DECLARE_DISPATCH(unary_fn, cosh_stub);
22+
DECLARE_DISPATCH(unary_fn, cosh_stub);
2323
DECLARE_DISPATCH(unary_fn, erf_stub);
2424
DECLARE_DISPATCH(unary_fn, erfc_stub);
2525
DECLARE_DISPATCH(unary_fn, exp_stub);
@@ -36,7 +36,7 @@ DECLARE_DISPATCH(unary_fn, round_stub);
3636
DECLARE_DISPATCH(unary_fn, rsqrt_stub);
3737
DECLARE_DISPATCH(unary_fn, sigmoid_stub);
3838
DECLARE_DISPATCH(unary_fn, sin_stub);
39-
// DECLARE_DISPATCH(unary_fn, sinh_stub);
39+
DECLARE_DISPATCH(unary_fn, sinh_stub);
4040
DECLARE_DISPATCH(unary_fn, sqrt_stub);
4141
DECLARE_DISPATCH(unary_fn, tan_stub);
4242
DECLARE_DISPATCH(unary_fn, tanh_stub);

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,24 @@ static void rsqrt_kernel(TensorIterator& iter) {
169169
});
170170
}
171171

172+
static void sinh_kernel(TensorIterator& iter) {
173+
AT_DISPATCH_ALL_TYPES(iter.dtype(), "sinh_cpu", [&]() {
174+
unary_kernel_vec(
175+
iter,
176+
[=](scalar_t a) -> scalar_t { return std::sinh(a); },
177+
[=](Vec256<scalar_t> a) { return a.sinh(); });
178+
});
179+
}
180+
181+
static void cosh_kernel(TensorIterator& iter) {
182+
AT_DISPATCH_ALL_TYPES(iter.dtype(), "cosh_cpu", [&]() {
183+
unary_kernel_vec(
184+
iter,
185+
[=](scalar_t a) -> scalar_t { return std::cosh(a); },
186+
[=](Vec256<scalar_t> a) { return a.cosh(); });
187+
});
188+
}
189+
172190
// TODO: Disable cont. branch to test more risky code
173191

174192
#define IMPLEMENT_FLOAT_KERNEL(dispatchtypes, op) \
@@ -212,6 +230,8 @@ REGISTER_DISPATCH(frac_stub, &frac_kernel);
212230
REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel);
213231
REGISTER_DISPATCH(neg_stub, &neg_kernel);
214232
REGISTER_DISPATCH(fill_stub, &fill_kernel);
233+
REGISTER_DISPATCH(sinh_stub, &sinh_kernel);
234+
REGISTER_DISPATCH(cosh_stub, &cosh_kernel);
215235

216236
// IMPLEMENT_FLOAT_KERNEL(ALL, abs)
217237
IMPLEMENT_FLOAT_KERNEL(FLOATING, acos)

0 commit comments

Comments
 (0)