|
1 | 1 | #pragma once |
2 | 2 |
|
3 | 3 | #include <ATen/core/Tensor.h> |
| 4 | +#include <c10/macros/Macros.h> |
4 | 5 | #include <c10/util/Half.h> |
5 | 6 | #include <c10/util/Exception.h> |
6 | 7 | #include <ATen/core/DeprecatedTypeProperties.h> |
|
11 | 12 | return __VA_ARGS__(); \ |
12 | 13 | } |
13 | 14 |
|
| 15 | +#define AT_QINT_PRIVATE_CASE_TYPE(enum_type, type, underlying_type, ...) \ |
| 16 | + case enum_type: { \ |
| 17 | + using scalar_t C10_UNUSED = type; \ |
| 18 | + using underlying_t C10_UNUSED = underlying_type; \ |
| 19 | + return __VA_ARGS__(); \ |
| 20 | + } |
| 21 | + |
14 | 22 | namespace detail { |
15 | 23 |
|
16 | 24 | template <at::ScalarType N> |
@@ -211,14 +219,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} |
211 | 219 | #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \ |
212 | 220 | [&] { \ |
213 | 221 | switch (TYPE) { \ |
214 | | - AT_PRIVATE_CASE_TYPE( \ |
215 | | - at::ScalarType::QInt8, qint8, __VA_ARGS__) \ |
216 | | - AT_PRIVATE_CASE_TYPE( \ |
217 | | - at::ScalarType::QUInt8, quint8, __VA_ARGS__) \ |
218 | | - AT_PRIVATE_CASE_TYPE( \ |
219 | | - at::ScalarType::QInt32, qint32, __VA_ARGS__) \ |
| 222 | + AT_QINT_PRIVATE_CASE_TYPE( \ |
| 223 | + at::ScalarType::QInt8, qint8, int8_t, __VA_ARGS__) \ |
| 224 | + AT_QINT_PRIVATE_CASE_TYPE( \ |
| 225 | + at::ScalarType::QUInt8, quint8, uint8_t, __VA_ARGS__) \ |
| 226 | + AT_QINT_PRIVATE_CASE_TYPE( \ |
| 227 | + at::ScalarType::QInt32, qint32, int, __VA_ARGS__) \ |
220 | 228 | default: \ |
221 | | - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ |
| 229 | + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ |
222 | 230 | } \ |
223 | 231 | }() |
224 | 232 |
|
|
0 commit comments