Skip to content

Commit 1ed1a2f

Browse files
Basil Hosmerfacebook-github-bot
authored andcommitted
[wip] fast typeMeta/ScalarType conversion approach 2 (#44965)
Summary: Pull Request resolved: #44965 Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D23789657 Pulled By: bhosmer fbshipit-source-id: 5afdd52d24bd097891ff4a7313033f7bd400165e
1 parent 489af4d commit 1ed1a2f

File tree

14 files changed

+251
-167
lines changed

14 files changed

+251
-167
lines changed

aten/src/ATen/native/DispatchStub.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
#include <c10/core/Backend.h>
44
#include <c10/core/ScalarType.h>
55
#include <c10/util/Exception.h>
6+
67
#include <type_traits>
8+
#include <atomic>
79

810
// Implements instruction set specific function dispatch.
911
//

aten/src/ATen/templates/TensorBody.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <c10/core/QScheme.h>
77
#include <c10/core/Scalar.h>
88
#include <c10/core/ScalarType.h>
9+
#include <c10/core/ScalarTypeToTypeMeta.h>
910
#include <c10/core/Storage.h>
1011
#include <ATen/core/TensorAccessor.h>
1112
#include <c10/core/TensorImpl.h>

aten/src/TH/THStorageFunctions.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <TH/THStorageFunctions.h>
99

1010
#include <c10/core/ScalarType.h>
11+
#include <c10/core/ScalarTypeToTypeMeta.h>
1112

1213
// Note [Weak references for intrusive refcounting]
1314
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

c10/core/DefaultDtype.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
namespace c10 {
55
static auto default_dtype = caffe2::TypeMeta::Make<float>();
6-
static auto default_dtype_as_scalartype = typeMetaToScalarType(default_dtype);
6+
static auto default_dtype_as_scalartype = default_dtype.toScalarType();
77
static auto default_complex_dtype = caffe2::TypeMeta::Make<c10::complex<float>>();
88

99
void set_default_dtype(caffe2::TypeMeta dtype) {
1010
default_dtype = std::move(dtype);
11-
default_dtype_as_scalartype = typeMetaToScalarType(default_dtype);
11+
default_dtype_as_scalartype = default_dtype.toScalarType();
1212
if(default_dtype_as_scalartype == ScalarType::Double) {
1313
default_complex_dtype = std::move(caffe2::TypeMeta::Make<c10::complex<double>>());
1414
} else {

c10/core/ScalarType.h

Lines changed: 6 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
#include <c10/util/ArrayRef.h>
44
#include <c10/util/complex.h>
55
#include <c10/util/Half.h>
6+
#include <c10/util/qint32.h>
7+
#include <c10/util/qint8.h>
8+
#include <c10/util/quint8.h>
69
#include <c10/util/BFloat16.h>
710
#include <c10/util/Optional.h>
8-
#include <c10/util/typeid.h>
911

1012
#include <complex>
1113
#include <cstdint>
@@ -67,6 +69,8 @@ enum class ScalarType : int8_t {
6769
NumOptions
6870
};
6971

72+
constexpr uint16_t NumScalarTypes = static_cast<uint16_t>(ScalarType::NumOptions);
73+
7074
namespace impl {
7175

7276
// These are used to map ScalarTypes to C++ types.
@@ -93,7 +97,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
9397

9498
#undef SPECIALIZE_ScalarTypeToCPPType
9599

96-
}
100+
} // namespace impl
97101

98102
template <typename T>
99103
struct CppTypeToScalarType;
@@ -160,64 +164,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
160164
_(c10::complex<float>, ComplexFloat) \
161165
_(c10::complex<double>, ComplexDouble)
162166

163-
static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
164-
#define DEFINE_CASE(ctype, name) \
165-
case ScalarType::name: \
166-
return caffe2::TypeMeta::Make<ctype>();
167-
168-
switch (scalar_type) {
169-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
170-
case ScalarType::Undefined:
171-
return caffe2::TypeMeta();
172-
default:
173-
AT_ERROR(
174-
"Unrecognized Scalartype ",
175-
scalar_type,
176-
" (please report this error)");
177-
}
178-
#undef DEFINE_CASE
179-
}
180-
181-
static inline c10::optional<ScalarType> tryTypeMetaToScalarType(
182-
caffe2::TypeMeta dtype) {
183-
#define DEFINE_IF(ctype, name) \
184-
if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
185-
return {ScalarType::name}; \
186-
}
187-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_IF)
188-
#undef DEFINE_IF
189-
if (dtype == caffe2::TypeMeta()) {
190-
return {ScalarType::Undefined};
191-
}
192-
return c10::nullopt;
193-
}
194-
195-
static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
196-
if (auto scalar_type = tryTypeMetaToScalarType(dtype)) {
197-
return *scalar_type;
198-
}
199-
AT_ERROR(
200-
"Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
201-
}
202-
203-
inline optional<at::ScalarType> optTypeMetaToScalarType(optional<caffe2::TypeMeta> type_meta) {
204-
if (!type_meta.has_value()) {
205-
return c10::nullopt;
206-
}
207-
return typeMetaToScalarType(*type_meta);
208-
}
209-
210-
static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
211-
if (auto mt = tryTypeMetaToScalarType(m)) {
212-
return (*mt) == t;
213-
}
214-
return false;
215-
}
216-
217-
static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
218-
return t == m;
219-
}
220-
221167
#define DEFINE_CONSTANT(_, name) \
222168
constexpr ScalarType k##name = ScalarType::name;
223169

c10/core/ScalarTypeToTypeMeta.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#pragma once
2+
3+
#include <c10/core/ScalarType.h>
4+
#include <c10/util/typeid.h>
5+
6+
// these just expose TypeMeta/ScalarType bridge functions in c10
7+
// TODO move to typeid.h (or codemod away) when TypeMeta et al
8+
// are moved from caffe2 to c10 (see note at top of typeid.h)
9+
10+
namespace c10 {
11+
12+
/**
13+
* convert ScalarType enum values to TypeMeta handles
14+
*/
15+
static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
16+
return caffe2::TypeMeta::fromScalarType(scalar_type);
17+
}
18+
19+
/**
20+
* convert TypeMeta handles to ScalarType enum values
21+
*/
22+
static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
23+
return dtype.toScalarType();
24+
}
25+
26+
/**
27+
* typeMetaToScalarType(), lifted to optional
28+
*/
29+
static inline optional<at::ScalarType> optTypeMetaToScalarType(optional<caffe2::TypeMeta> type_meta) {
30+
if (!type_meta.has_value()) {
31+
return c10::nullopt;
32+
}
33+
return type_meta->toScalarType();
34+
}
35+
36+
/**
37+
* convenience: equality across TypeMeta/ScalarType conversion
38+
*/
39+
static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
40+
return m.isScalarType(t);
41+
}
42+
43+
static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
44+
return t == m;
45+
}
46+
47+
} // namespace c10

c10/core/TensorImpl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
6262
data_type_(data_type),
6363
device_opt_(device_opt) {
6464
if (!key_set.empty()) {
65-
AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() ||
66-
device_opt_.has_value());
65+
TORCH_INTERNAL_ASSERT(data_type == ScalarType::Undefined || device_opt_.has_value());
6766
// UndefinedTensorImpl is a singleton, so we skip logging it
6867
C10_LOG_API_USAGE_ONCE("tensor.create");
6968
}

c10/core/TensorImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,13 +1777,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
17771777
// strides SmallVector (pre-allocated 4)
17781778
// storage offset
17791779
// numel
1780-
// data type pointer
1780+
// data type
17811781
// (optional) device
17821782
// tensor type id
17831783
// miscellaneous bitfield
17841784
//
17851785
static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit...
1786-
sizeof(TensorImpl) == sizeof(int64_t) * 31,
1786+
sizeof(TensorImpl) == sizeof(int64_t) * 30,
17871787
"You changed the size of TensorImpl on 64-bit arch."
17881788
"See Note [TensorImpl size constraints] on how to proceed.");
17891789
} // namespace c10

c10/core/TensorOptions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <c10/core/Backend.h>
55
#include <c10/core/Layout.h>
66
#include <c10/core/ScalarType.h>
7+
#include <c10/core/ScalarTypeToTypeMeta.h>
78
#include <c10/core/Device.h>
89
#include <c10/core/MemoryFormat.h>
910
#include <c10/core/DispatchKeySet.h>

c10/core/UndefinedTensorImpl.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
2828
private:
2929
UndefinedTensorImpl();
3030
static UndefinedTensorImpl _singleton;
31-
public:
32-
friend struct UndefinedType;
3331
};
3432

3533
} // namespace c10

0 commit comments

Comments
 (0)