Skip to content

Commit fb3a025

Browse files
committed
Add MKL and CUFFT helpers
1 parent ec9009d commit fb3a025

File tree

5 files changed

+162
-0
lines changed

5 files changed

+162
-0
lines changed

aten/src/ATen/mkl/Descriptors.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#pragma once
2+
3+
#include "Exceptions.h"
4+
#include <mkl_dfti.h>
5+
#include <ATen/Tensor.h>
6+
7+
namespace at { namespace native {
8+
9+
struct DftiDescriptorDeleter {
10+
void operator()(DFTI_DESCRIPTOR* desc) {
11+
if (desc != nullptr) {
12+
MKL_DFTI_CHECK(DftiFreeDescriptor(&desc));
13+
}
14+
}
15+
};
16+
17+
class DftiDescriptor {
18+
public:
19+
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) {
20+
if (desc_ != nullptr) {
21+
throw std::runtime_error("DFTI DESCRIPTOR can only be initialized once");
22+
}
23+
DFTI_DESCRIPTOR *raw_desc;
24+
if (signal_ndim == 1) {
25+
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
26+
} else {
27+
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, signal_ndim, sizes));
28+
}
29+
desc_.reset(raw_desc);
30+
}
31+
32+
DFTI_DESCRIPTOR *get() const {
33+
if (desc_ == nullptr) {
34+
throw std::runtime_error("DFTI DESCRIPTOR has not been initialized");
35+
}
36+
return desc_.get();
37+
}
38+
39+
private:
40+
std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
41+
};
42+
43+
44+
}} // at::native

aten/src/ATen/mkl/Exceptions.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#pragma once
2+
3+
#include <string>
4+
#include <stdexcept>
5+
#include <sstream>
6+
#include <mkl_dfti.h>
7+
8+
namespace at { namespace native {
9+
10+
static inline void MKL_DFTI_CHECK(MKL_INT status)
11+
{
12+
if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) {
13+
std::ostringstream ss;
14+
ss << "MKL FFT error: " << DftiErrorMessage(status);
15+
throw std::runtime_error(ss.str());
16+
}
17+
}
18+
19+
}} // namespace at::native

aten/src/ATen/mkl/Limits.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
#include <mkl_types.h>
4+
5+
namespace at { namespace native {
6+
7+
// Since size of MKL_LONG varies on different platforms (linux 64 bit, windows
8+
// 32 bit), we need to programmatically calculate the max.
9+
static int64_t MKL_LONG_MAX = ((1LL << (sizeof(MKL_LONG) * 8 - 2)) - 1) * 2 + 1;
10+
11+
}} // namespace

aten/src/ATen/mkl/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
All files living in this directory are written with the assumption that MKL is available,
2+
which means that these code are not guarded by `#if AT_MKL_ENABLED()`. Therefore, whenever
3+
you need to use definitions from here, please guard the `#include<ATen/mkl/*.h>` and
4+
definition usages with `#if AT_MKL_ENABLED()` macro, e.g. [SpectralOps.cpp](native/mkl/SpectralOps.cpp).
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#pragma once
2+
3+
#include "ATen/ATen.h"
4+
#include "ATen/Config.h"
5+
6+
#include <string>
7+
#include <stdexcept>
8+
#include <sstream>
9+
#include <cufft.h>
10+
#include <cufftXt.h>
11+
12+
13+
namespace at { namespace native {
14+
15+
static inline std::string _cudaGetErrorEnum(cufftResult error)
16+
{
17+
switch (error)
18+
{
19+
case CUFFT_SUCCESS:
20+
return "CUFFT_SUCCESS";
21+
case CUFFT_INVALID_PLAN:
22+
return "CUFFT_INVALID_PLAN";
23+
case CUFFT_ALLOC_FAILED:
24+
return "CUFFT_ALLOC_FAILED";
25+
case CUFFT_INVALID_TYPE:
26+
return "CUFFT_INVALID_TYPE";
27+
case CUFFT_INVALID_VALUE:
28+
return "CUFFT_INVALID_VALUE";
29+
case CUFFT_INTERNAL_ERROR:
30+
return "CUFFT_INTERNAL_ERROR";
31+
case CUFFT_EXEC_FAILED:
32+
return "CUFFT_EXEC_FAILED";
33+
case CUFFT_SETUP_FAILED:
34+
return "CUFFT_SETUP_FAILED";
35+
case CUFFT_INVALID_SIZE:
36+
return "CUFFT_INVALID_SIZE";
37+
case CUFFT_UNALIGNED_DATA:
38+
return "CUFFT_UNALIGNED_DATA";
39+
case CUFFT_INCOMPLETE_PARAMETER_LIST:
40+
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
41+
case CUFFT_INVALID_DEVICE:
42+
return "CUFFT_INVALID_DEVICE";
43+
case CUFFT_PARSE_ERROR:
44+
return "CUFFT_PARSE_ERROR";
45+
case CUFFT_NO_WORKSPACE:
46+
return "CUFFT_NO_WORKSPACE";
47+
case CUFFT_NOT_IMPLEMENTED:
48+
return "CUFFT_NOT_IMPLEMENTED";
49+
case CUFFT_LICENSE_ERROR:
50+
return "CUFFT_LICENSE_ERROR";
51+
case CUFFT_NOT_SUPPORTED:
52+
return "CUFFT_NOT_SUPPORTED";
53+
default:
54+
std::ostringstream ss;
55+
ss << "unknown error " << error;
56+
return ss.str();
57+
}
58+
}
59+
60+
static inline void CUFFT_CHECK(cufftResult error)
61+
{
62+
if (error != CUFFT_SUCCESS) {
63+
std::ostringstream ss;
64+
ss << "cuFFT error: " << _cudaGetErrorEnum(error);
65+
throw std::runtime_error(ss.str());
66+
}
67+
}
68+
69+
class CufftHandle {
70+
public:
71+
explicit CufftHandle() {
72+
CUFFT_CHECK(cufftCreate(&raw_plan));
73+
}
74+
75+
const cufftHandle &get() const { return raw_plan; }
76+
77+
~CufftHandle() {
78+
CUFFT_CHECK(cufftDestroy(raw_plan));
79+
}
80+
private:
81+
cufftHandle raw_plan;
82+
};
83+
84+
}} // at::native

0 commit comments

Comments
 (0)