Skip to content

Commit 44128e0

Browse files
smessmerfacebook-github-bot
authored andcommitted
Speed up op lookup and registration (#21806)
Summary: Pull Request resolved: #21806 Dispatcher::findSchema(op_name) now uses a lookup table instead of iterating through the list of operators to find it. This speeds up op lookup (as in finding the operator handle from the name, not as in finding a kernel when you already have the operator handle) and it also speeds up op registration since that needs to look if an op with the same name already eists. Differential Revision: D15834256 fbshipit-source-id: c3639d7b567e4ed5e3627c3ebfd01b7d08b55ac1
1 parent d1c8030 commit 44128e0

14 files changed

+260
-230
lines changed

aten/src/ATen/core/dispatch/Dispatcher.cpp

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ OpRegistrationListener::~OpRegistrationListener() {}
3030

3131
Dispatcher::Dispatcher()
3232
: operators_()
33+
, operatorLookupTable_()
3334
, listeners_(guts::make_unique<detail::RegistrationListenerList>())
3435
, mutex_() {}
3536

@@ -40,20 +41,18 @@ C10_EXPORT Dispatcher& Dispatcher::singleton() {
4041
return _singleton;
4142
}
4243

43-
c10::optional<OperatorHandle> Dispatcher::findSchema(const char* operator_name, const char* overload_name) {
44-
const auto found = std::find_if(operators_.begin(), operators_.end(), [&] (const OperatorDef& opDef) {
45-
return opDef.op.schema().name() == operator_name && opDef.op.schema().overload_name() == overload_name;
44+
c10::optional<OperatorHandle> Dispatcher::findSchema(const OperatorName& overload_name) {
45+
return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> c10::optional<OperatorHandle> {
46+
auto found = operatorLookupTable.find(overload_name);
47+
if (found == operatorLookupTable.end()) {
48+
return c10::nullopt;
49+
}
50+
return found->second;
4651
});
47-
48-
if (found == operators_.end()) {
49-
return c10::nullopt;
50-
}
51-
52-
return OperatorHandle(found);
5352
}
5453

5554
OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema, OperatorOptions&& options) {
56-
const auto found = findSchema(schema.name().c_str(), schema.overload_name().c_str());
55+
const auto found = findSchema(schema.operator_name());
5756
if (found != c10::nullopt) {
5857
if (found->schema() != schema) {
5958
std::ostringstream str;
@@ -66,14 +65,22 @@ OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema, Operat
6665
return *found;
6766
}
6867

68+
OperatorName op_name = schema.operator_name();
6969
operators_.emplace_back(std::move(schema), std::move(options));
70-
return OperatorHandle(--operators_.end());
70+
OperatorHandle handle(--operators_.end());
71+
operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
72+
operatorLookupTable.emplace(op_name, handle);
73+
});
74+
75+
return handle;
7176
}
7277

7378
SchemaRegistrationHandleRAII Dispatcher::registerSchema(FunctionSchema schema, OperatorOptions options) {
7479
// we need a lock to avoid concurrent writes
7580
std::lock_guard<std::mutex> lock(mutex_);
7681

82+
OperatorName op_name = schema.operator_name();
83+
7784
auto op = findOrRegisterSchema_(std::move(schema), std::move(options));
7885

7986
++op.operatorIterator_->refcount;
@@ -82,15 +89,17 @@ SchemaRegistrationHandleRAII Dispatcher::registerSchema(FunctionSchema schema, O
8289
listeners_->callOnOperatorRegistered(op);
8390
}
8491

85-
return SchemaRegistrationHandleRAII {op, RegistrationHandleRAII([this, op] {
86-
deregisterSchema_(op);
92+
return SchemaRegistrationHandleRAII {op, RegistrationHandleRAII([this, op, op_name] {
93+
deregisterSchema_(op, op_name);
8794
})};
8895
}
8996

90-
void Dispatcher::deregisterSchema_(const OperatorHandle& op) {
97+
void Dispatcher::deregisterSchema_(const OperatorHandle& op, const OperatorName& op_name) {
9198
// we need a lock to avoid concurrent writes
9299
std::lock_guard<std::mutex> lock(mutex_);
93100

101+
TORCH_INTERNAL_ASSERT(op.schema().operator_name() == op_name);
102+
94103
// reduce refcount and actually deregister if no references left
95104
TORCH_INTERNAL_ASSERT(op.operatorIterator_->refcount > 0);
96105
--op.operatorIterator_->refcount;
@@ -101,6 +110,9 @@ void Dispatcher::deregisterSchema_(const OperatorHandle& op) {
101110
listeners_->callOnOperatorDeregistered(op);
102111

103112
operators_.erase(op.operatorIterator_);
113+
operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
114+
operatorLookupTable.erase(op_name);
115+
});
104116
}
105117
}
106118

aten/src/ATen/core/dispatch/Dispatcher.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class CAFFE2_API Dispatcher final {
107107
* and returns it if it is registered.
108108
* Returns nullopt otherwise.
109109
*/
110-
c10::optional<OperatorHandle> findSchema(const char* operator_name, const char* overload_name);
110+
c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
111111

112112
/**
113113
* Register a kernel to the dispatch table for an operator.
@@ -146,9 +146,10 @@ class CAFFE2_API Dispatcher final {
146146

147147
OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema, OperatorOptions&& options);
148148

149-
void deregisterSchema_(const OperatorHandle& op);
149+
void deregisterSchema_(const OperatorHandle& op, const OperatorName& op_name);
150150

151151
std::list<OperatorDef> operators_;
152+
LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
152153
std::unique_ptr<detail::RegistrationListenerList> listeners_;
153154
std::mutex mutex_;
154155
};

aten/src/ATen/core/function_schema.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ struct FunctionSchema {
160160
void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const;
161161

162162
public:
163+
const OperatorName& operator_name() const {
164+
return name_;
165+
}
163166
const std::string& name() const {
164167
return name_.name;
165168
}

aten/src/ATen/core/function_schema_inl.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,4 +180,21 @@ inline FunctionSchema FunctionSchema::cloneWithRemappedTypes(
180180
is_varret());
181181
}
182182

183+
inline bool operator==(const OperatorName& lhs, const OperatorName& rhs) {
184+
return lhs.name == rhs.name && lhs.overload_name == rhs.overload_name;
185+
}
186+
187+
inline bool operator!=(const OperatorName& lhs, const OperatorName& rhs) {
188+
return !operator==(lhs, rhs);
189+
}
190+
183191
} // namespace c10
192+
193+
namespace std {
194+
template <>
195+
struct hash<::c10::OperatorName> {
196+
size_t operator()(const ::c10::OperatorName& x) const {
197+
return std::hash<std::string>()(x.name) ^ (~ std::hash<std::string>()(x.overload_name));
198+
}
199+
};
200+
}

0 commit comments

Comments
 (0)