|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "caffe2/utils/flat_hash_map/flat_hash_map.h" |
| 4 | +#include "caffe2/utils/Metaprogramming.h" |
| 5 | +#include "caffe2/core/dispatch/OpSchema.h" |
| 6 | + |
| 7 | +#include <type_traits> |
| 8 | +#include <array> |
| 9 | +#include <unordered_map> |
| 10 | +#include <iostream> |
| 11 | +#include <mutex> |
| 12 | + |
| 13 | +namespace c10 { |
| 14 | + |
| 15 | +namespace details { |
| 16 | + |
| 17 | +/// Kernel implementations in a thread-safe hash table. |
| 18 | +template<class Key> |
| 19 | +class ThreadsafeOperatorTable_ final { |
| 20 | +public: |
| 21 | + // TODO The current implementation below does not have the correct correctness characteristics |
| 22 | + // which we need. It's worth spelling out exactly what we need: |
| 23 | + // |
| 24 | + // - We need LOCK FREE read access to the table (as per the performance benchmark |
| 25 | + // at https://fb.quip.com/hvz3AGnx8MQ8 |
| 26 | + // |
| 27 | + // - We need to support writes which are possibly concurrent with reads, occurring when |
| 28 | + // a dynamic library is loaded or unloaded. |
| 29 | + // |
| 30 | + // - We probably can require that dynamic library loads/unloads be synchronized (so |
| 31 | + // there are never two concurrent loads.) |
| 32 | + |
| 33 | + template<class Key_> |
| 34 | + void emplace(Key_&& key, void* value) { |
| 35 | + using std::to_string; |
| 36 | + // TODO Locking |
| 37 | + //std::unique_lock<std::shared_timed_mutex> lock(mutex_); |
| 38 | + |
| 39 | + auto result = map_.emplace(std::forward<Key>(key), value); |
| 40 | + if (!result.second) { |
| 41 | + std::ostringstream msg; |
| 42 | + msg << "Tried to register conflicting kernels to the dispatcher: " << key; |
| 43 | + throw std::logic_error(msg.str()); |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + void erase(const Key& key) { |
| 48 | + // TODO Locking |
| 49 | + //std::unique_lock<std::shared_timed_mutex> lock(mutex_); |
| 50 | + |
| 51 | + size_t num_removed = map_.erase(key); |
| 52 | + assert(num_removed <= 1); //This is not a multi-map |
| 53 | + if (num_removed == 0) { |
| 54 | + throw std::logic_error("Tried to deregister a kernel that isn't registered."); |
| 55 | + } |
| 56 | + } |
| 57 | + |
| 58 | + void* lookup(const Key& key) const { |
| 59 | + // TODO (lock needed but slow perf. Find better way) |
| 60 | + // std::shared_lock<std::shared_timed_mutex> lock(mutex_); |
| 61 | + auto found = map_.find(key); |
| 62 | + if (found == map_.end()) { |
| 63 | + return nullptr; |
| 64 | + } else { |
| 65 | + return found->second; |
| 66 | + } |
| 67 | + } |
| 68 | + |
| 69 | +private: |
| 70 | + ska::flat_hash_map<Key, void*> map_; |
| 71 | + // TODO Figure out how to get fast locking in C++11 (use boost::shared_timed_mutex? folly::SharedMutex? LR pattern?) |
| 72 | + //mutable std::shared_timed_mutex mutex_; |
| 73 | +}; |
| 74 | +} // namespace details |
| 75 | + |
| 76 | +/** |
| 77 | + * Per-operator dispatch table. |
| 78 | + * |
| 79 | + * Given an operator specified by 'OpSchemaDef', this class records a dispatch table for |
| 80 | + * various kernels provided for this operator. For example, if we consider the operator |
| 81 | + * add(Tensor, Tensor), the dispatch table for this operator may contain implementations |
| 82 | + * for various dynamic tensor types, such as (CPUFloatTensor, CPUFloatTensor), |
| 83 | + * (CUDAFloatTensor, CUDAFloatTensor), etc. |
| 84 | + * |
| 85 | + * @tparam OpSchemaDef The operator signature this dispatch table encodes. |
| 86 | + */ |
| 87 | +// TODO: Support dispatch for meta-operators (which apply to all dynamic types) |
| 88 | +template<class OpSchemaDef> |
| 89 | +class DispatchTable final { |
| 90 | +private: |
| 91 | + using Schema = OpSchema<OpSchemaDef>; |
| 92 | + |
| 93 | +public: |
| 94 | + DispatchTable(): kernels_() {} |
| 95 | + |
| 96 | + /** |
| 97 | + * Register a kernel in the table at some dispatch key. |
| 98 | + * @param func Concrete kernel function implementation to register |
| 99 | + * @param dispatch_key Dispatch key to define when this kernel is selected |
| 100 | + */ |
| 101 | + void registerKernel(typename Schema::signature::func_type* func, typename Schema::dispatch::dispatch_key_type dispatch_key) { |
| 102 | + kernels_.emplace(std::move(dispatch_key), reinterpret_cast<void*>(func)); |
| 103 | + } |
| 104 | + |
| 105 | + /** |
| 106 | + * Deregister the kernel for some dispatch key. |
| 107 | + * |
| 108 | + * @param dispatch_key Dispatch key to unregister. |
| 109 | + */ |
| 110 | + // TODO: This isn't going to work so well when we get more complicated override patterns! |
| 111 | + // In this case, an operator will show up in multiple slots, and erasing them one-by-one |
| 112 | + // is probably not such a good idea. |
| 113 | + void deregisterKernel(const typename Schema::dispatch::dispatch_key_type& dispatch_key) { |
| 114 | + kernels_.erase(dispatch_key); |
| 115 | + } |
| 116 | + |
| 117 | + /** |
| 118 | + * Perform a dynamic dispatch on this table. |
| 119 | + * |
| 120 | + * @tparam Args Perfect forwarding template arguments to the dispatch |
| 121 | + * @param args Arguments to invoke the function with |
| 122 | + * @return Returned value of the operator |
| 123 | + */ |
| 124 | + template<class... Args> |
| 125 | + typename Schema::signature::return_type call(Args&&... args) const { |
| 126 | + // TODO Better error message, but need to take care that reference arguments match non-reference arguments and so on. |
| 127 | + // static_assert(std::is_same<typename Schema::return_type (Args...), typename Schema::func_type>::value, "Argument types don't match operator signature"); |
| 128 | + auto kernel_func = lookupKernelFunc_(args...); |
| 129 | + return kernel_func(std::forward<Args>(args)...); |
| 130 | + } |
| 131 | + |
| 132 | +private: |
| 133 | + template<class... Args> |
| 134 | + typename Schema::signature::func_type* lookupKernelFunc_(const Args&... args) const { |
| 135 | + auto dispatch_key = Schema::dispatch::dispatch_key(args...); |
| 136 | + void* found = kernels_.lookup(dispatch_key); |
| 137 | + if (found == nullptr) { |
| 138 | + // TODO Better error message - include op name and dispatch key (i.e. argument types) |
| 139 | + throw std::logic_error(std::string() + "Didn't find kernel to dispatch to for operator '" + Schema::metadata::name() + "'"); |
| 140 | + } |
| 141 | + return reinterpret_cast<typename Schema::signature::func_type*>(found); |
| 142 | + } |
| 143 | + |
| 144 | + details::ThreadsafeOperatorTable_<typename Schema::dispatch::dispatch_key_type> kernels_; |
| 145 | +}; |
| 146 | + |
| 147 | +} // namespace c10 |
| 148 | + |
| 149 | +/* |
| 150 | + * Use this to access the dispatch table singleton for a given op schema. |
| 151 | + * It has an implementation for each op schema def in a cpp file, because |
| 152 | + * we can't rely on the one-definition-rule. |
| 153 | + */ |
| 154 | +template<class OpSchemaDef> c10::DispatchTable<OpSchemaDef>& c10_dispatch_table(); |
0 commit comments