13 changes: 8 additions & 5 deletions aten/src/ATen/CPUApplyUtils.h
@@ -253,16 +253,15 @@ apply_op(int64_t numel, int64_t offset, const Op& op, Args... iters) {
}
}


inline void apply_kernel(){};

// TODO: Deal elegantly with 0-dim tensors. iters.strides_ of 0-dim
// strided_tensor_iter will be of size 0 for dim 0 and iters.strides_[iters.dim_
// - 1] will index at -1. C++14 integer_sequence could be of use here.
template <typename Op, typename... Args>
inline void
apply_kernel(int64_t numel, int64_t offset, const Op& op, Args... iters) {
// For 0-dim tensors
if (numel == 1 && max_dim(iters...) == 0) {
op(1, iters.data_..., iters.strides_[iters.dim_ - 1]...);
return;
}
if (offset > 0)
forward(offset, iters...);
int64_t size = std::min(numel, max_iterate_size(iters...));
@@ -284,6 +283,10 @@ inline void
CPU_tensor_parallel_kernel_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
if (!_apply_preamble({tensor1, tensor2}))
return;
if (tensor1.numel() == 1) {
op(1, tensor1.data<scalar1>(), tensor2.data<scalar2>(), 0, 0);
return;
}
if (tensor1.ndimension() < 8 && tensor2.ndimension() < 8) {
parallel_for(
0,
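The scalar shortcut added to CPU_tensor_parallel_kernel_apply2 above invokes the op once with a count of 1 and zero strides instead of going through the strided iteration machinery. A minimal standalone sketch of that calling convention (toy_apply2 and the square lambda are illustrative names, not ATen code):

#include <cstdint>
#include <cstdio>

template <typename Op>
void toy_apply2(int64_t numel, float* out, float* in, const Op& op) {
  if (numel == 1) {          // mirrors the tensor1.numel() == 1 branch above
    op(1, out, in, 0, 0);    // one element, zero strides
    return;
  }
  op(numel, out, in, 1, 1);  // contiguous case for this toy
}

int main() {
  float x = 2.0f, y = 0.0f;
  // op(n, out, in, out_stride, in_stride): writes in[i]^2 into out.
  auto square = [](int64_t n, float* out, float* in, int64_t os, int64_t is) {
    for (int64_t i = 0; i < n; ++i)
      out[i * os] = in[i * is] * in[i * is];
  };
  toy_apply2(1, &y, &x, square);
  std::printf("%f\n", y);  // prints 4.000000
  return 0;
}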
16 changes: 1 addition & 15 deletions aten/src/ATen/Declarations.cwrap
@@ -1090,24 +1090,10 @@
- THTensor* self
]]
[[
name: sigmoid_
name: _th_sigmoid
types:
- floating_point
backends:
- CPU
- CUDA
cname: sigmoid
return: self
arguments:
- THTensor* self
- THTensor* self
]]
[[
name: sigmoid
types:
- floating_point
backends:
- CPU
- CUDA
cname: sigmoid
variants:
6 changes: 6 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_base.h
@@ -125,6 +125,9 @@ struct Vec256 {
Vec256<T> floor() const {
return map(std::floor);
}
Vec256<T> neg() const {
return map([](T x) { return -x; });
}
Vec256<T> round() const {
return map(std::round);
}
@@ -146,6 +149,9 @@ struct Vec256 {
Vec256<T> sqrt() const {
return map(std::sqrt);
}
Vec256<T> reciprocal() const {
return map([](T x) { return (T)(1) / x; });
}
Vec256<T> rsqrt() const {
return map([](T x) { return 1 / std::sqrt(x); });
}
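The two additions above follow the pattern the generic Vec256<T> uses for every element-wise op: map a scalar lambda over the lanes, leaving the architecture-specific specializations (in the files below) to override the method with intrinsics. A standalone sketch of that pattern with a toy vector type rather than the real class:

#include <cstddef>
#include <cstdio>

template <typename T, std::size_t N = 32 / sizeof(T)>
struct ToyVec {
  T values[N];
  template <typename F>
  ToyVec map(F f) const {  // apply f lane by lane, as the base Vec256 does
    ToyVec out;
    for (std::size_t i = 0; i < N; ++i) out.values[i] = f(values[i]);
    return out;
  }
  ToyVec neg() const { return map([](T x) { return -x; }); }
  ToyVec reciprocal() const { return map([](T x) { return T(1) / x; }); }
};

int main() {
  ToyVec<float> v = {{1.f, 2.f, 4.f, 8.f, 1.f, 2.f, 4.f, 8.f}};
  // prints -2.000000 0.125000
  std::printf("%f %f\n", v.neg().values[1], v.reciprocal().values[3]);
  return 0;
}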
6 changes: 6 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_double.h
@@ -121,6 +121,9 @@ template <> class Vec256<double> {
Vec256<double> floor() const {
return _mm256_floor_pd(values);
}
Vec256<double> neg() const {
return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
}
Vec256<double> round() const {
return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
@@ -136,6 +139,9 @@ template <> class Vec256<double> {
Vec256<double> sqrt() const {
return _mm256_sqrt_pd(values);
}
Vec256<double> reciprocal() const {
return _mm256_div_pd(_mm256_set1_pd(1), values);
}
Vec256<double> rsqrt() const {
return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
}
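A standalone sketch of the two AVX idioms used above, assuming an AVX-capable CPU and compilation with -mavx: negation flips only the sign bit by XOR-ing with -0.0, and reciprocal divides a vector of ones by the input. The float specialization in the next file uses the same idioms with the _ps intrinsics.

#include <immintrin.h>
#include <cstdio>

int main() {
  __m256d v = _mm256_set_pd(8.0, 4.0, 2.0, 1.0);            // lanes: 1, 2, 4, 8
  __m256d neg = _mm256_xor_pd(_mm256_set1_pd(-0.), v);       // sign-bit flip
  __m256d rec = _mm256_div_pd(_mm256_set1_pd(1), v);         // 1 / v per lane
  double n[4], r[4];
  _mm256_storeu_pd(n, neg);
  _mm256_storeu_pd(r, rec);
  std::printf("%f %f\n", n[0], r[3]);  // prints -1.000000 0.125000
  return 0;
}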
6 changes: 6 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_float.h
@@ -126,6 +126,9 @@ template <> class Vec256<float> {
Vec256<float> floor() const {
return _mm256_floor_ps(values);
}
Vec256<float> neg() const {
return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
}
Vec256<float> round() const {
return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
@@ -141,6 +144,9 @@ template <> class Vec256<float> {
Vec256<float> sqrt() const {
return _mm256_sqrt_ps(values);
}
Vec256<float> reciprocal() const {
return _mm256_div_ps(_mm256_set1_ps(1), values);
}
Vec256<float> rsqrt() const {
return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values));
}
8 changes: 4 additions & 4 deletions aten/src/ATen/cpu/vec256/vec256_int.h
@@ -29,13 +29,13 @@ struct Vec256<int64_t> : public Vec256i {
__at_align32__ int64_t tmp_values[size];
a.store(tmp_values);
if (mask & 0x01)
tmp_values[0] = _mm256_extract_epi16(b.values, 0);
tmp_values[0] = _mm256_extract_epi64(b.values, 0);
if (mask & 0x02)
tmp_values[1] = _mm256_extract_epi16(b.values, 1);
tmp_values[1] = _mm256_extract_epi64(b.values, 1);
if (mask & 0x04)
tmp_values[2] = _mm256_extract_epi16(b.values, 2);
tmp_values[2] = _mm256_extract_epi64(b.values, 2);
if (mask & 0x08)
tmp_values[3] = _mm256_extract_epi16(b.values, 3);
tmp_values[3] = _mm256_extract_epi64(b.values, 3);
return loadu(tmp_values);
}
static Vec256<int64_t>
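A standalone sketch of why the extract fix above matters, assuming AVX2 and an x86-64 target: with 64-bit lanes, _mm256_extract_epi16 returns the i-th 16-bit word of the register, not the i-th int64_t lane, so the old blend code picked up truncated, misaligned values.

#include <immintrin.h>
#include <cstdio>

int main() {
  // 64-bit lanes, low to high: 0x1111, 0x2222, 0x3333, 0x4444
  __m256i v = _mm256_set_epi64x(0x4444, 0x3333, 0x2222, 0x1111);
  long long lane1 = _mm256_extract_epi64(v, 1);  // 0x2222: the second int64_t lane
  int word1 = _mm256_extract_epi16(v, 1);        // 0: the second 16-bit word of lane 0
  std::printf("%llx %x\n", lane1, word1);        // prints "2222 0"
  return 0;
}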
150 changes: 98 additions & 52 deletions aten/src/ATen/cpu/vml.h
@@ -8,6 +8,21 @@
// This header implements various unary operations using a MKL VML style
// interface.

// It implements various functions with a simple interface
// For example it enables the user to call vsin(float* out, const float* in,
// size). This function takes a pointer to a contiguous output array of floats
// and a const input array. It will then apply sin to each value in the input
// array and write the result into the output array. out and in may point to
// the same memory, i.e. this fully supports in-place operations. These
// functions also implement their own parallelization, so take precautions when
// calling these from threaded functions.

// When MKL is available it will call into MKL's VML library similar to NumPy
// If MKL is not available it will use SLEEF.

// This file might be compiled under AVX or AVX2 when called from e.g.
// UnaryOpsKernel.cpp

#include <algorithm>
#include <cstddef>
#include <cstdint>
@@ -16,7 +31,19 @@

#if AT_MKL_ENABLED() && !defined(__APPLE__)
#include <mkl.h>
#include <mkl_vml.h>
#endif

// [Note SSE-AVX transitions]
// There is a bug in glibc 2.23
// (https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280). Calling
// zeroall after running AVX/AVX2 code resolves this.
#if defined(__AVX__) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23
#define DL_RUNTIME_BUG(op, type) \
volatile type x = (type)(1); \
x = std::op(x); \
_mm256_zeroall();
#else
#define DL_RUNTIME_BUG(op, type)
#endif

namespace at {
Expand All @@ -40,9 +67,16 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {

// NB: We ignore numerical errors by convention and leave them to the user

#define IMPLEMENT_VML(op) \
// We unfortunately need to duplicate code here to deal with the SSE-AVX
// transition bug (see [Note SSE-AVX transitions]). As soon as we can expect
// users to use a version of glibc newer than 2.23 we will be able to ditch
// this. The duplication is also necessary because some functions (e.g. rsqrt)
// are not part of cmath.

#define IMPLEMENT_VML_BUG(op) \
template <typename scalar_t> \
inline void v##op(scalar_t* out, scalar_t* in, int64_t size) { \
inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \
DL_RUNTIME_BUG(op, scalar_t) \
parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { \
map([](const Vec256<scalar_t>& x) { return x.op(); }, \
out + begin, \
@@ -51,70 +85,82 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
}); \
}

#define IMPLEMENT_FLOAT_MKL_VML(op, mklop) \
template <typename scalar_t> \
inline void v##op(scalar_t* out, scalar_t* in, int64_t size); \
#define IMPLEMENT_VML(op) \
template <typename scalar_t> \
inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \
parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { \
map([](const Vec256<scalar_t>& x) { return x.op(); }, \
out + begin, \
in + begin, \
end - begin); \
}); \
}

IMPLEMENT_VML_BUG(abs)
IMPLEMENT_VML_BUG(acos)
IMPLEMENT_VML_BUG(asin)
IMPLEMENT_VML_BUG(atan)
IMPLEMENT_VML_BUG(ceil)
IMPLEMENT_VML_BUG(cos)
// IMPLEMENT_VML_BUG(cosh)
IMPLEMENT_VML_BUG(erf)
IMPLEMENT_VML_BUG(exp)
IMPLEMENT_VML_BUG(expm1)
IMPLEMENT_VML_BUG(floor)
IMPLEMENT_VML(reciprocal)
IMPLEMENT_VML_BUG(log)
IMPLEMENT_VML_BUG(log10)
IMPLEMENT_VML_BUG(log1p)
IMPLEMENT_VML_BUG(log2)
IMPLEMENT_VML(neg)
IMPLEMENT_VML_BUG(sin)
// IMPLEMENT_VML_BUG(sinh)
IMPLEMENT_VML_BUG(sqrt)
IMPLEMENT_VML_BUG(round)
IMPLEMENT_VML(rsqrt)
IMPLEMENT_VML_BUG(tan)
IMPLEMENT_VML_BUG(tanh)
IMPLEMENT_VML_BUG(trunc)

#if AT_MKL_ENABLED() && !defined(__APPLE__)

#define IMPLEMENT_VML_MKL(op, mklop) \
template <> \
inline void v##op(float* out, float* in, int64_t size) { \
inline void v##op(float* out, const float* in, int64_t size) { \
vms##mklop(size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
} \
template <> \
inline void v##op(double* out, double* in, int64_t size) { \
inline void v##op(double* out, const double* in, int64_t size) { \
vmd##mklop(size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
}

// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple clang

#if AT_MKL_ENABLED() && !defined(__APPLE__)
IMPLEMENT_FLOAT_MKL_VML(acos, Acos)
IMPLEMENT_FLOAT_MKL_VML(asin, Asin)
IMPLEMENT_FLOAT_MKL_VML(atan, Atan)
IMPLEMENT_FLOAT_MKL_VML(cos, Cos)
// IMPLEMENT_FLOAT_MKL_VML(cosh, Cosh)
IMPLEMENT_FLOAT_MKL_VML(erf, Erf)
IMPLEMENT_FLOAT_MKL_VML(exp, Exp)
IMPLEMENT_FLOAT_MKL_VML(expm1, Expm1)
IMPLEMENT_FLOAT_MKL_VML(log, Ln)
IMPLEMENT_FLOAT_MKL_VML(log10, Log10)
IMPLEMENT_FLOAT_MKL_VML(log1p, Log1p)
IMPLEMENT_FLOAT_MKL_VML(sin, Sin)
// IMPLEMENT_FLOAT_MKL_VML(sinh, Sinh)
IMPLEMENT_FLOAT_MKL_VML(sqrt, Sqrt)
IMPLEMENT_FLOAT_MKL_VML(tan, Tan)
IMPLEMENT_FLOAT_MKL_VML(tanh, Tanh)
IMPLEMENT_FLOAT_MKL_VML(trunc, Trunc)
IMPLEMENT_VML_MKL(abs, Abs)
IMPLEMENT_VML_MKL(acos, Acos)
IMPLEMENT_VML_MKL(asin, Asin)
IMPLEMENT_VML_MKL(atan, Atan)
IMPLEMENT_VML_MKL(cos, Cos)
// IMPLEMENT_VML_MKL(cosh, Cosh)
IMPLEMENT_VML_MKL(erf, Erf)
IMPLEMENT_VML_MKL(exp, Exp)
IMPLEMENT_VML_MKL(expm1, Expm1)
IMPLEMENT_VML_MKL(log, Ln)
IMPLEMENT_VML_MKL(log10, Log10)
IMPLEMENT_VML_MKL(log1p, Log1p)
IMPLEMENT_VML_MKL(sin, Sin)
// IMPLEMENT_VML_MKL(sinh, Sinh)
IMPLEMENT_VML_MKL(sqrt, Sqrt)
IMPLEMENT_VML_MKL(tan, Tan)
IMPLEMENT_VML_MKL(tanh, Tanh)
IMPLEMENT_VML_MKL(trunc, Trunc)

#if INTEL_MKL_VERSION >= 20180406
IMPLEMENT_FLOAT_MKL_VML(log2, Log2)
#else
IMPLEMENT_VML(log2)
IMPLEMENT_VML_MKL(log2, Log2)
#endif

#else
IMPLEMENT_VML(acos)
IMPLEMENT_VML(asin)
IMPLEMENT_VML(atan)
IMPLEMENT_VML(cos)
// IMPLEMENT_VML(cosh)
IMPLEMENT_VML(erf)
IMPLEMENT_VML(exp)
IMPLEMENT_VML(expm1)
IMPLEMENT_VML(log)
IMPLEMENT_VML(log10)
IMPLEMENT_VML(log1p)
IMPLEMENT_VML(log2)
IMPLEMENT_VML(sin)
// IMPLEMENT_VML(sinh)
IMPLEMENT_VML(sqrt)
IMPLEMENT_VML(tan)
IMPLEMENT_VML(tanh)
#endif

IMPLEMENT_VML(ceil)
IMPLEMENT_VML(floor)
IMPLEMENT_VML(round)
IMPLEMENT_VML(trunc)

} // namespace
} // namespace vml
} // namespace at
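A standalone sketch of the calling convention the header comment at the top of vml.h describes, using a plain scalar loop in place of the vectorized/MKL dispatch (toy_vsin is an illustrative name, not the function defined above): the op takes an output pointer, a const input pointer and a length, operates on contiguous data, and may be called in place.

#include <cmath>
#include <cstdint>
#include <cstdio>

template <typename scalar_t>
void toy_vsin(scalar_t* out, const scalar_t* in, int64_t size) {
  for (int64_t i = 0; i < size; ++i) out[i] = std::sin(in[i]);
}

int main() {
  float buf[4] = {0.0f, 1.0f, 2.0f, 3.0f};
  toy_vsin(buf, buf, 4);        // in-place call is allowed: out == in
  std::printf("%f\n", buf[1]);  // ~0.841471
  return 0;
}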
3 changes: 2 additions & 1 deletion aten/src/ATen/native/UnaryOps.cpp
@@ -35,7 +35,7 @@ Tensor& fill_(Tensor& self, const Tensor& value) {
// NB: If you use this macro, you may also need to add a CUDA forwarding
// stub in CUDAUnaryOps

#define IMPLEMENT_UNARY_OP_VEC(op) \
#define IMPLEMENT_UNARY_OP_VEC(op) \
Tensor op(const Tensor& self) { \
Tensor result = self.type().tensor(); \
return at::op##_out(result, self); \
@@ -87,6 +87,7 @@ IMPLEMENT_UNARY_OP_VEC(log1p)
IMPLEMENT_UNARY_OP_VEC(log2)
IMPLEMENT_UNARY_OP_VEC(round)
IMPLEMENT_UNARY_OP_VEC(rsqrt)
IMPLEMENT_UNARY_OP_VEC(sigmoid)
IMPLEMENT_UNARY_OP_VEC(sin)
IMPLEMENT_UNARY_OP_TH(sinh)
IMPLEMENT_UNARY_OP_VEC(sqrt)
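Roughly, each IMPLEMENT_UNARY_OP_VEC(op) invocation above generates an op(const Tensor&) wrapper that allocates a result and forwards to the corresponding _out variant. A standalone sketch of that shape for the newly added sigmoid entry, with a std::vector standing in for at::Tensor (not ATen's actual types or kernels):

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

using Tensor = std::vector<float>;  // toy stand-in for at::Tensor

Tensor& sigmoid_out(Tensor& result, const Tensor& self) {
  result.resize(self.size());
  for (std::size_t i = 0; i < self.size(); ++i)
    result[i] = 1.0f / (1.0f + std::exp(-self[i]));
  return result;
}

// Shape of what IMPLEMENT_UNARY_OP_VEC(sigmoid) generates, modulo ATen's types.
Tensor sigmoid(const Tensor& self) {
  Tensor result;  // self.type().tensor() in ATen
  return sigmoid_out(result, self);
}

int main() {
  Tensor t = {0.0f};
  std::printf("%f\n", sigmoid(t)[0]);  // prints 0.500000
  return 0;
}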
3 changes: 2 additions & 1 deletion aten/src/ATen/native/cpu/CapabilityDispatch.h
@@ -48,7 +48,8 @@ struct DispatchStub {
#ifndef __powerpc__
if (cpuinfo_initialize()) {
int avx2 = static_cast<int>(CPUCapability::AVX2);
if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() && table[avx2]) {
if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() &&
cpuinfo_has_x86_fma3() && table[avx2]) {
return table[avx2];
}
int avx = static_cast<int>(CPUCapability::AVX);
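A standalone sketch of the runtime-dispatch pattern above, assuming the cpuinfo library and its <cpuinfo.h> header are available: the AVX2 kernel is selected only when the CPU also reports FMA3, mirroring the added cpuinfo_has_x86_fma3() check (the ATEN_DISABLE_AVX2 override and the AVX fallback tier are omitted here).

#include <cpuinfo.h>
#include <cstdio>

void kernel_default() { std::puts("default kernel"); }
void kernel_avx2_fma() { std::puts("AVX2+FMA kernel"); }

using Fn = void (*)();

Fn choose_kernel() {
  if (cpuinfo_initialize() && cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3())
    return kernel_avx2_fma;
  return kernel_default;
}

int main() {
  choose_kernel()();  // runs the best kernel this CPU supports
  return 0;
}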