Commit cc18f74

Author: royboy
Replace Type dispatch with ATenDispatch
gh-metadata: pytorch pytorch 21320 gh/li-roy/25/head
1 parent 44128e0

22 files changed: +1125 −989

aten/src/ATen/ATen.h

Lines changed: 1 addition & 0 deletions

@@ -24,3 +24,4 @@
 #include <c10/core/Storage.h>
 #include <c10/core/TensorOptions.h>
 #include <c10/util/Exception.h>
+#include <ATen/core/ATenDispatch.h>
aten/src/ATen/core/ATenDispatch.cpp

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+#include <ATen/core/ATenDispatch.h>
+
+namespace at {
+
+ATenDispatch& globalATenDispatch() {
+  static ATenDispatch singleton;
+  return singleton;
+}
+
+} // namespace at
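globalATenDispatch() is a Meyers singleton: since C++11, a function-local static is guaranteed to be initialized exactly once, thread-safely, on first call, which gives lazy, race-free construction without an explicit lock. A minimal standalone sketch of the same idiom (the Registry type here is illustrative, not part of this commit):

#include <iostream>

struct Registry {
  Registry() { std::cout << "constructed once\n"; }
};

Registry& globalRegistry() {
  // Constructed on first call; C++11 guarantees thread-safe initialization.
  static Registry singleton;
  return singleton;
}

int main() {
  globalRegistry();  // constructs the singleton
  globalRegistry();  // returns the same instance, no second construction
  return 0;
}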

aten/src/ATen/core/ATenDispatch.h

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+#pragma once
+
+#include <c10/core/Backend.h>
+#include <unordered_map>
+#include <c10/util/C++17.h>
+#include <memory>
+#include <mutex>
+
+// This dispatch class serves as a replacement for our previous dispatch
+// mechanism, in which all functions were members of a Type class. A derived
+// class existed for each backend (and Variable), and the vtable was used to
+// dispatch to the correct implementation. This class is to be replaced by
+// the c10 dispatcher when it supports all argument and return types.
+// This implementation opts to store implementations in a table of void*.
+
+namespace at {
+
+// ATenOpTable stores the implementations for each backend, in addition to
+// an implementation for variables.
+class CAFFE2_API ATenOpTable {
+ public:
+  ATenOpTable(std::string schema)
+    : schema_(std::move(schema)) {}
+
+  template<class FuncType>
+  FuncType* getOp(Backend backend, bool is_variable) const {
+    if (is_variable) {
+      return reinterpret_cast<FuncType*>(getVariableOp());
+    }
+    return reinterpret_cast<FuncType*>(getBaseOp(backend));
+  }
+ private:
+  void registerOp(Backend backend, void* fn) {
+    TORCH_CHECK(function_table_[static_cast<int64_t>(backend)] == nullptr,
+        "Attempting to register function for schema ", schema_,
+        " and backend ", toString(backend),
+        " but there is already a function registered");
+    function_table_[static_cast<int64_t>(backend)] = fn;
+  }
+
+  void registerVariableOp(void* fn) {
+    TORCH_CHECK(variable_function_ == nullptr,
+        "Attempting to register variable function for schema ", schema_,
+        " but there is already a function registered");
+    variable_function_ = fn;
+  }
+
+  void* getBaseOp(Backend backend) const {
+    if (function_table_[static_cast<int64_t>(backend)] == nullptr) {
+      TORCH_CHECK(function_table_[static_cast<int64_t>(Backend::Undefined)] != nullptr,
+          "No function is registered for schema ", schema_, " on backend ", toString(backend));
+      return function_table_[static_cast<int64_t>(Backend::Undefined)];
+    }
+    return function_table_[static_cast<int64_t>(backend)];
+  }
+
+  void* getVariableOp() const {
+    TORCH_CHECK(variable_function_ != nullptr,
+        "No variable function registered for ", schema_);
+    return variable_function_;
+  }
+
+  friend class ATenDispatch;
+
+  std::string schema_;
+  void* function_table_[static_cast<int64_t>(Backend::NumOptions)] = {nullptr};
+  void* variable_function_ = nullptr;
+};
+
+class CAFFE2_API ATenDispatch {
+ public:
+  template<class FuncType>
+  ATenDispatch& registerOp(Backend backend, const char* schema, FuncType* fn) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (op_tables_.find(schema) == op_tables_.end()) {
+      op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
+    }
+    op_tables_.at(schema).registerOp(backend, reinterpret_cast<void*>(fn));
+    return *this;
+  }
+
+  template <class FuncType>
+  ATenDispatch& registerVariableOp(const char* schema, FuncType* fn) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (op_tables_.find(schema) == op_tables_.end()) {
+      op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
+    }
+    op_tables_.at(schema).registerVariableOp(reinterpret_cast<void*>(fn));
+    return *this;
+  }
+
+  const ATenOpTable* getOpTable(const char* schema) const {
+    auto iter = op_tables_.find(schema);
+    TORCH_CHECK(iter != op_tables_.end(),
+        "No functions are registered for schema ", schema);
+    return &iter->second;
+  }
+
+ private:
+  std::unordered_map<std::string, ATenOpTable> op_tables_;
+  std::mutex mutex_;
+};
+
+CAFFE2_API ATenDispatch& globalATenDispatch();
+
+} // namespace at
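Putting the pieces together: registration goes through the global singleton, and dispatch goes through getOpTable/getOp. A hypothetical end-to-end sketch; the "aten::add" schema string and the add_cpu/add_variable kernels below are illustrative stand-ins, not functions added by this commit:

// Assumed kernels, for illustration only.
Tensor add_cpu(const Tensor& a, const Tensor& b);
Tensor add_variable(const Tensor& a, const Tensor& b);

// Registration, typically run from a static initializer, once per backend.
static auto& dispatch = globalATenDispatch()
    .registerOp(Backend::CPU, "aten::add(Tensor a, Tensor b) -> Tensor", &add_cpu)
    .registerVariableOp("aten::add(Tensor a, Tensor b) -> Tensor", &add_variable);

// Dispatch: look up the per-schema table, then select by backend and
// variable-ness. getOp reinterpret_casts the stored void* back to the
// caller-supplied function type, so FuncType must exactly match the
// registered signature; the table itself cannot verify this.
Tensor call_add(Backend backend, bool is_variable, const Tensor& a, const Tensor& b) {
  const ATenOpTable* table =
      globalATenDispatch().getOpTable("aten::add(Tensor a, Tensor b) -> Tensor");
  auto* fn = table->getOp<Tensor (const Tensor&, const Tensor&)>(backend, is_variable);
  return fn(a, b);
}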

aten/src/ATen/core/LegacyTypeDispatch.h

Lines changed: 7 additions & 11 deletions

@@ -59,8 +59,7 @@ class CAFFE2_API LegacyTypeDispatch {
   }
   Type * getNonVariableTypeOpt(Backend p, ScalarType s) {
     if (p != Backend::Undefined) {
-      initForDeviceType(backendToDeviceType(p));
-      initForScalarType(s);
+      initForBackend(p);
     }
     auto type = getNonVariableTypeRaw(p, s);

@@ -103,10 +102,11 @@
     type_registry[static_cast<int>(b)] = std::move(t);
     detail::getVariableHooks().registerVariableTypeFor(this, b);
   }
- private:
-  void initForDeviceType(DeviceType p) {
+  void initForBackend(Backend b) {
+    auto p = backendToDeviceType(b);
     static std::once_flag cpu_once;
     static std::once_flag cuda_once;
+    static std::once_flag complex_once;
     if (p == DeviceType::CPU) {
       std::call_once(cpu_once, [] {
         getLegacyDeviceTypeInit().initCPU();

@@ -120,17 +120,13 @@
         getLegacyDeviceTypeInit().initHIP();
       });
     }
-  }
-  void initForScalarType(ScalarType s) {
-    static std::once_flag once;
-    // Only complex may need initialization
-    if (isComplexType(s)) {
-      std::call_once(once, [] {
+    if (b == Backend::ComplexCPU || b == Backend::ComplexCUDA) {
+      std::call_once(complex_once, [] {
        getLegacyDeviceTypeInit().initComplex();
      });
    }
  }
-
+ private:
  // NB: type_registry has nullptr for all CUDA backends until
  // CUDA initialization has occurred
  TypeUniquePtr type_registry
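The refactor folds initForDeviceType and initForScalarType into a single initForBackend keyed on Backend, with one std::once_flag per backend family, so each expensive initialization runs at most once even under concurrent callers. A standalone sketch of the pattern; the Backend enum and init bodies below are simplified stand-ins, not the real ATen code:

#include <iostream>
#include <mutex>

enum class Backend { CPU, CUDA, ComplexCPU, ComplexCUDA };

void initForBackend(Backend b) {
  static std::once_flag cpu_once;
  static std::once_flag complex_once;
  if (b == Backend::CPU) {
    std::call_once(cpu_once, [] { std::cout << "CPU init\n"; });
  }
  // The two complex backends share one flag: whichever is requested first
  // performs the initialization, and every later call is a no-op.
  if (b == Backend::ComplexCPU || b == Backend::ComplexCUDA) {
    std::call_once(complex_once, [] { std::cout << "complex init\n"; });
  }
}

int main() {
  initForBackend(Backend::ComplexCPU);   // prints "complex init"
  initForBackend(Backend::ComplexCUDA);  // no-op: flag already set
  return 0;
}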
