33#include < c10/util/ArrayRef.h>
44#include < c10/util/complex.h>
55#include < c10/util/Half.h>
6+ #include < c10/util/qint32.h>
7+ #include < c10/util/qint8.h>
8+ #include < c10/util/quint8.h>
69#include < c10/util/BFloat16.h>
710#include < c10/util/Optional.h>
8- #include < c10/util/typeid.h>
911
1012#include < complex>
1113#include < cstdint>
@@ -67,6 +69,8 @@ enum class ScalarType : int8_t {
6769 NumOptions
6870};
6971
72+ constexpr uint16_t NumScalarTypes = static_cast <uint16_t >(ScalarType::NumOptions);
73+
7074namespace impl {
7175
7276// These are used to map ScalarTypes to C++ types.
@@ -93,7 +97,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
9397
9498#undef SPECIALIZE_ScalarTypeToCPPType
9599
96- }
100+ } // namespace impl
97101
98102template <typename T>
99103struct CppTypeToScalarType ;
@@ -160,64 +164,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
160164 _ (c10::complex <float >, ComplexFloat) \
161165 _ (c10::complex <double >, ComplexDouble)
162166
163- static inline caffe2::TypeMeta scalarTypeToTypeMeta (ScalarType scalar_type) {
164- #define DEFINE_CASE (ctype, name ) \
165- case ScalarType::name: \
166- return caffe2::TypeMeta::Make<ctype>();
167-
168- switch (scalar_type) {
169- AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS (DEFINE_CASE)
170- case ScalarType::Undefined:
171- return caffe2::TypeMeta ();
172- default :
173- AT_ERROR (
174- " Unrecognized Scalartype " ,
175- scalar_type,
176- " (please report this error)" );
177- }
178- #undef DEFINE_CASE
179- }
180-
181- static inline c10::optional<ScalarType> tryTypeMetaToScalarType (
182- caffe2::TypeMeta dtype) {
183- #define DEFINE_IF (ctype, name ) \
184- if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
185- return {ScalarType::name}; \
186- }
187- AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS (DEFINE_IF)
188- #undef DEFINE_IF
189- if (dtype == caffe2::TypeMeta ()) {
190- return {ScalarType::Undefined};
191- }
192- return c10::nullopt ;
193- }
194-
195- static inline ScalarType typeMetaToScalarType (caffe2::TypeMeta dtype) {
196- if (auto scalar_type = tryTypeMetaToScalarType (dtype)) {
197- return *scalar_type;
198- }
199- AT_ERROR (
200- " Unsupported TypeMeta in ATen: " , dtype, " (please report this error)" );
201- }
202-
203- inline optional<at::ScalarType> optTypeMetaToScalarType (optional<caffe2::TypeMeta> type_meta) {
204- if (!type_meta.has_value ()) {
205- return c10::nullopt ;
206- }
207- return typeMetaToScalarType (*type_meta);
208- }
209-
210- static inline bool operator ==(ScalarType t, caffe2::TypeMeta m) {
211- if (auto mt = tryTypeMetaToScalarType (m)) {
212- return (*mt) == t;
213- }
214- return false ;
215- }
216-
217- static inline bool operator ==(caffe2::TypeMeta m, ScalarType t) {
218- return t == m;
219- }
220-
221167#define DEFINE_CONSTANT (_, name ) \
222168 constexpr ScalarType k##name = ScalarType::name;
223169
0 commit comments