
Commit cca2476

First version of dispatcher (#8713)

1 parent: 2b926aa

19 files changed, +1412 -3 lines

caffe2/core/dispatch/CMakeLists.txt

Lines changed: 4 additions & 0 deletions

@@ -1,8 +1,12 @@
 set(LIB_SOURCES
   DeviceId.cpp
+  Dispatcher.cpp
   DispatchKey.cpp
+  DispatchTable.cpp
+  KernelRegistration.cpp
   LayoutId.cpp
   OpSchema.cpp
+  OpSchemaRegistration.cpp
   TensorTypeId.cpp
   TensorTypeIdRegistration.cpp
 )

caffe2/core/dispatch/DeviceId.h

Lines changed: 2 additions & 0 deletions

@@ -2,6 +2,7 @@
 
 #include <functional>
 #include <iostream>
+#include "caffe2/utils/C++17.h"
 
 namespace c10 {
 
@@ -19,6 +20,7 @@ inline std::ostream& operator<<(std::ostream& stream, DeviceTypeId device_type_id) {
     case DeviceTypeId::CUDA: return stream << "DeviceTypeId(CUDA)";
     case DeviceTypeId::UNDEFINED: return stream << "DeviceTypeId(UNDEFINED)";
   }
+  throw std::logic_error("Unknown DeviceTypeId: " + guts::to_string(static_cast<int>(device_type_id)));
 }
 
 }
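
The added throw covers the path where device_type_id holds a value outside the enumerated cases: without it, control could fall off the end of a non-void function (undefined behavior), and compilers warn about a missing return. Below is a minimal, self-contained sketch of the same pattern; Color and its operator<< are hypothetical stand-ins, and std::to_string is used where the diff uses the guts::to_string helper from C++17.h.

#include <iostream>
#include <stdexcept>
#include <string>

enum class Color { RED, GREEN };

inline std::ostream& operator<<(std::ostream& stream, Color c) {
  switch (c) {
    case Color::RED:   return stream << "Color(RED)";
    case Color::GREEN: return stream << "Color(GREEN)";
  }
  // Only reached for values outside the enumerators (e.g. produced by a bad cast);
  // throwing avoids falling off the end of a non-void function.
  throw std::logic_error("Unknown Color: " + std::to_string(static_cast<int>(c)));
}

int main() {
  std::cout << Color::GREEN << "\n";             // prints "Color(GREEN)"
  try {
    std::cout << static_cast<Color>(42) << "\n"; // hits the throw
  } catch (const std::logic_error& e) {
    std::cout << e.what() << "\n";
  }
}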

caffe2/core/dispatch/DispatchKey.h

Lines changed: 16 additions & 0 deletions

@@ -6,6 +6,7 @@
 
 #include <vector>
 #include <functional>
+#include <sstream>
 #include "caffe2/utils/Array.h"
 
 namespace c10 {
@@ -21,6 +22,9 @@ struct TensorParameterDispatchKey final {
 inline constexpr bool operator==(const TensorParameterDispatchKey& lhs, const TensorParameterDispatchKey& rhs) {
   return lhs.deviceTypeId == rhs.deviceTypeId && lhs.layoutId == rhs.layoutId && lhs.dataType == rhs.dataType;
 }
+inline std::ostream& operator<<(std::ostream& stream, const TensorParameterDispatchKey& key) {
+  return stream << "TensorKey(" << key.deviceTypeId << ", " << key.layoutId.value() << ", " << key.dataType << ")";
+}
 } // namespace details
 } // namespace c10
 
@@ -58,6 +62,18 @@ inline constexpr bool operator==(const DispatchKey<num_dispatch_args> &lhs, const DispatchKey<num_dispatch_args> &rhs) {
   // TODO: Use AVX instructions to perform this equality test more quickly
   return lhs.argTypes == rhs.argTypes;
 }
+template<size_t num_dispatch_args>
+inline std::ostream& operator<<(std::ostream& stream, const DispatchKey<num_dispatch_args>& key) {
+  stream << "DispatchKey(";
+  if (num_dispatch_args > 0) {
+    stream << "DispatchKey(" << key.argTypes[0];
+    for (size_t i = 1; i < num_dispatch_args; ++i) {
+      stream << ", " << key.argTypes[i];
+    }
+    stream << ")";
+  }
+  return stream << ")";
+}
 
 } // namespace c10
 
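
These stream operators exist so dispatch keys can be rendered in diagnostics; the "Tried to register conflicting kernels" error in DispatchTable.h below builds its text by streaming the key into an ostringstream, which is also why <sstream> is added here. A small self-contained sketch of that pattern, where MiniKey is a hypothetical miniature of TensorParameterDispatchKey (the real key also carries a LayoutId):

#include <iostream>
#include <sstream>

// Hypothetical miniature of TensorParameterDispatchKey, just enough to stream.
struct MiniKey {
  int deviceTypeId;
  int dataType;
};

inline std::ostream& operator<<(std::ostream& stream, const MiniKey& key) {
  return stream << "TensorKey(" << key.deviceTypeId << ", " << key.dataType << ")";
}

int main() {
  MiniKey key{0, 6};
  // Same pattern as ThreadsafeOperatorTable_::emplace in DispatchTable.h below:
  // build the message by streaming the key, then use the formatted text in an error.
  std::ostringstream msg;
  msg << "Tried to register conflicting kernels to the dispatcher: " << key;
  std::cout << msg.str() << "\n";
}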

caffe2/core/dispatch/DispatchTable.cpp

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+#include "caffe2/core/dispatch/DispatchTable.h"

caffe2/core/dispatch/DispatchTable.h

Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
+#pragma once
+
+#include "caffe2/utils/flat_hash_map/flat_hash_map.h"
+#include "caffe2/utils/Metaprogramming.h"
+#include "caffe2/core/dispatch/OpSchema.h"
+
+#include <type_traits>
+#include <array>
+#include <unordered_map>
+#include <iostream>
+#include <mutex>
+
+namespace c10 {
+
+namespace details {
+
+/// Kernel implementations in a thread-safe hash table.
+template<class Key>
+class ThreadsafeOperatorTable_ final {
+public:
+  // TODO The current implementation below does not have the correctness characteristics
+  // which we need. It's worth spelling out exactly what we need:
+  //
+  // - We need LOCK FREE read access to the table (as per the performance benchmark
+  //   at https://fb.quip.com/hvz3AGnx8MQ8
+  //
+  // - We need to support writes which are possibly concurrent with reads, occurring when
+  //   a dynamic library is loaded or unloaded.
+  //
+  // - We probably can require that dynamic library loads/unloads be synchronized (so
+  //   there are never two concurrent loads.)
+
+  template<class Key_>
+  void emplace(Key_&& key, void* value) {
+    using std::to_string;
+    // TODO Locking
+    //std::unique_lock<std::shared_timed_mutex> lock(mutex_);
+
+    auto result = map_.emplace(std::forward<Key>(key), value);
+    if (!result.second) {
+      std::ostringstream msg;
+      msg << "Tried to register conflicting kernels to the dispatcher: " << key;
+      throw std::logic_error(msg.str());
+    }
+  }
+
+  void erase(const Key& key) {
+    // TODO Locking
+    //std::unique_lock<std::shared_timed_mutex> lock(mutex_);
+
+    size_t num_removed = map_.erase(key);
+    assert(num_removed <= 1); //This is not a multi-map
+    if (num_removed == 0) {
+      throw std::logic_error("Tried to deregister a kernel that isn't registered.");
+    }
+  }
+
+  void* lookup(const Key& key) const {
+    // TODO (lock needed but slow perf. Find better way)
+    // std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+    auto found = map_.find(key);
+    if (found == map_.end()) {
+      return nullptr;
+    } else {
+      return found->second;
+    }
+  }
+
+private:
+  ska::flat_hash_map<Key, void*> map_;
+  // TODO Figure out how to get fast locking in C++11 (use boost::shared_timed_mutex? folly::SharedMutex? LR pattern?)
+  //mutable std::shared_timed_mutex mutex_;
+};
+} // namespace details
+
+/**
+ * Per-operator dispatch table.
+ *
+ * Given an operator specified by 'OpSchemaDef', this class records a dispatch table for
+ * various kernels provided for this operator. For example, if we consider the operator
+ * add(Tensor, Tensor), the dispatch table for this operator may contain implementations
+ * for various dynamic tensor types, such as (CPUFloatTensor, CPUFloatTensor),
+ * (CUDAFloatTensor, CUDAFloatTensor), etc.
+ *
+ * @tparam OpSchemaDef The operator signature this dispatch table encodes.
+ */
+// TODO: Support dispatch for meta-operators (which apply to all dynamic types)
+template<class OpSchemaDef>
+class DispatchTable final {
+private:
+  using Schema = OpSchema<OpSchemaDef>;
+
+public:
+  DispatchTable(): kernels_() {}
+
+  /**
+   * Register a kernel in the table at some dispatch key.
+   * @param func Concrete kernel function implementation to register
+   * @param dispatch_key Dispatch key to define when this kernel is selected
+   */
+  void registerKernel(typename Schema::signature::func_type* func, typename Schema::dispatch::dispatch_key_type dispatch_key) {
+    kernels_.emplace(std::move(dispatch_key), reinterpret_cast<void*>(func));
+  }
+
+  /**
+   * Deregister the kernel for some dispatch key.
+   *
+   * @param dispatch_key Dispatch key to unregister.
+   */
+  // TODO: This isn't going to work so well when we get more complicated override patterns!
+  // In this case, an operator will show up in multiple slots, and erasing them one-by-one
+  // is probably not such a good idea.
+  void deregisterKernel(const typename Schema::dispatch::dispatch_key_type& dispatch_key) {
+    kernels_.erase(dispatch_key);
+  }
+
+  /**
+   * Perform a dynamic dispatch on this table.
+   *
+   * @tparam Args Perfect forwarding template arguments to the dispatch
+   * @param args Arguments to invoke the function with
+   * @return Returned value of the operator
+   */
+  template<class... Args>
+  typename Schema::signature::return_type call(Args&&... args) const {
+    // TODO Better error message, but need to take care that reference arguments match non-reference arguments and so on.
+    // static_assert(std::is_same<typename Schema::return_type (Args...), typename Schema::func_type>::value, "Argument types don't match operator signature");
+    auto kernel_func = lookupKernelFunc_(args...);
+    return kernel_func(std::forward<Args>(args)...);
+  }
+
+private:
+  template<class... Args>
+  typename Schema::signature::func_type* lookupKernelFunc_(const Args&... args) const {
+    auto dispatch_key = Schema::dispatch::dispatch_key(args...);
+    void* found = kernels_.lookup(dispatch_key);
+    if (found == nullptr) {
+      // TODO Better error message - include op name and dispatch key (i.e. argument types)
+      throw std::logic_error(std::string() + "Didn't find kernel to dispatch to for operator '" + Schema::metadata::name() + "'");
+    }
+    return reinterpret_cast<typename Schema::signature::func_type*>(found);
+  }
+
+  details::ThreadsafeOperatorTable_<typename Schema::dispatch::dispatch_key_type> kernels_;
+};
+
+} // namespace c10
+
+/*
+ * Use this to access the dispatch table singleton for a given op schema.
+ * It has an implementation for each op schema def in a cpp file, because
+ * we can't rely on the one-definition-rule.
+ */
+template<class OpSchemaDef> c10::DispatchTable<OpSchemaDef>& c10_dispatch_table();
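
The core trick in DispatchTable is type erasure: kernels are stored as void* in the hash table and cast back to the schema's concrete function-pointer type at call time, so a single table implementation can serve operators with arbitrary signatures. The sketch below is not the c10 code above but a minimal, self-contained illustration of that pattern under assumed names: SimpleDispatchTable, DeviceKind and add_cpu are hypothetical stand-ins for the OpSchema/DispatchKey machinery that lives in other files of this commit.

#include <cstdio>
#include <stdexcept>
#include <unordered_map>
#include <utility>

enum class DeviceKind { CPU, CUDA };                 // stand-in for the real dispatch key

template<class FuncType>                             // plays the role of Schema::signature::func_type
class SimpleDispatchTable {
public:
  void registerKernel(DeviceKind key, FuncType* func) {
    // Store the kernel type-erased, the same trick ThreadsafeOperatorTable_ uses with void*.
    kernels_[static_cast<int>(key)] = reinterpret_cast<void*>(func);
  }

  template<class... Args>
  auto call(DeviceKind key, Args&&... args) const {
    auto found = kernels_.find(static_cast<int>(key));
    if (found == kernels_.end()) {
      throw std::logic_error("Didn't find kernel to dispatch to");
    }
    // Cast back to the concrete signature before calling, as lookupKernelFunc_ does.
    return reinterpret_cast<FuncType*>(found->second)(std::forward<Args>(args)...);
  }

private:
  std::unordered_map<int, void*> kernels_;           // dispatch key -> type-erased kernel
};

float add_cpu(float a, float b) { return a + b; }    // hypothetical CPU kernel for add

int main() {
  SimpleDispatchTable<float(float, float)> add_table;
  add_table.registerKernel(DeviceKind::CPU, &add_cpu);
  std::printf("%f\n", add_table.call(DeviceKind::CPU, 1.0f, 2.0f));  // 3.000000
}

In the real header, the per-operator table is not held in a local variable but obtained through the c10_dispatch_table<OpSchemaDef>() function declared at the bottom of the file, which is defined once per op schema in a .cpp file so the singleton has exactly one definition across shared libraries.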

caffe2/core/dispatch/Dispatcher.cpp

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+#include "caffe2/core/dispatch/Dispatcher.h"

caffe2/core/dispatch/Dispatcher.h

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+#pragma once
+
+#include "caffe2/core/dispatch/DispatchTable.h"
+
+namespace c10 {
+
+/**
+ * Top-level dispatch interface for dispatching via the dynamic dispatcher.
+ */
+template<class OpSchemaDef>
+class Dispatcher final {
+public:
+  // Implementation note: this class abstracts over the fact that we have per-operator
+  // dispatch tables. This could be easily adjusted to have a single global hash
+  // table.
+
+  /**
+   * Register an operator to the dispatch table for some operator schema.
+   *
+   * @tparam OpSchemaDef Operator schema to register this operator to (mandatory)
+   * @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::registerOp (inferred)
+   * @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::registerOp
+   * @return void
+   */
+  template<class... Args>
+  static void registerKernel(Args&&... args) {
+    auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
+    return dispatch_table_for_this_op.registerKernel(std::forward<Args>(args)...);
+  }
+
+  /**
+   * Remove an operator from the dispatch table for some operator schema.
+   *
+   * @tparam OpSchemaDef Operator schema to deregister from (mandatory)
+   * @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::deregisterOp (inferred)
+   * @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::deregisterOp
+   * @return void
+   */
+  template<class... Args>
+  static void deregisterKernel(Args&&... args) {
+    auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
+    return dispatch_table_for_this_op.deregisterKernel(std::forward<Args>(args)...);
+  }
+
+  /**
+   * Perform a dynamic dispatch to some operator
+   *
+   * @tparam OpSchemaDef Operator schema to dispatch with (mandatory)
+   * @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::call (inferred)
+   * @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::call
+   * @return Return type of this operator
+   */
+  template<class... Args>
+  static typename OpSchema<OpSchemaDef>::signature::return_type call(Args&&... args) {
+    auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
+    return dispatch_table_for_this_op.call(std::forward<Args>(args)...);
+  }
+};
+
+} // namespace c10
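
Dispatcher<OpSchemaDef> is a stateless facade: every static method fetches the per-operator singleton through c10_dispatch_table<OpSchemaDef>() and perfect-forwards its arguments to the matching DispatchTable method. The sketch below is a hypothetical, self-contained illustration of that shape only; MyOpSchema, my_table, MyFacade and registerEntry are made-up names, a map from strings to ints stands in for the real kernel table, and the header/.cpp split (plus the registration macros in KernelRegistration.h / OpSchemaRegistration.h, not shown in this excerpt) is collapsed into one file.

#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

struct MyOpSchema {};                             // stand-in for a real OpSchemaDef

// Declared once for everyone (the real analogue is c10_dispatch_table in DispatchTable.h) ...
template<class OpSchemaDef> std::unordered_map<std::string, int>& my_table();

// ... and defined once per schema, normally in that operator's .cpp file.
template<> std::unordered_map<std::string, int>& my_table<MyOpSchema>() {
  static std::unordered_map<std::string, int> singleton;
  return singleton;
}

// Stateless facade in the style of c10::Dispatcher: every member is static and
// simply forwards to the per-schema singleton table.
template<class OpSchemaDef>
struct MyFacade {
  template<class... Args>
  static void registerEntry(Args&&... args) {
    my_table<OpSchemaDef>().emplace(std::forward<Args>(args)...);
  }
  static int call(const std::string& key) {
    return my_table<OpSchemaDef>().at(key);
  }
};

int main() {
  MyFacade<MyOpSchema>::registerEntry("answer", 42);
  std::cout << MyFacade<MyOpSchema>::call("answer") << "\n";   // prints 42
}

Because the facade holds no state of its own, adding or removing kernels from a dynamically loaded library only touches the per-operator table, which is exactly the concurrency scenario the ThreadsafeOperatorTable_ TODOs in DispatchTable.h are written around.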

caffe2/core/dispatch/KernelRegistration.cpp

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+#include "caffe2/core/dispatch/KernelRegistration.h"
