@@ -30,6 +30,7 @@ OpRegistrationListener::~OpRegistrationListener() {}
3030
3131Dispatcher::Dispatcher ()
3232: operators_()
33+ , operatorLookupTable_()
3334, listeners_(guts::make_unique<detail::RegistrationListenerList>())
3435, mutex_() {}
3536
@@ -40,20 +41,18 @@ C10_EXPORT Dispatcher& Dispatcher::singleton() {
4041 return _singleton;
4142}
4243
43- c10::optional<OperatorHandle> Dispatcher::findSchema (const char * operator_name, const char * overload_name) {
44- const auto found = std::find_if (operators_.begin (), operators_.end (), [&] (const OperatorDef& opDef) {
45- return opDef.op .schema ().name () == operator_name && opDef.op .schema ().overload_name () == overload_name;
44+ c10::optional<OperatorHandle> Dispatcher::findSchema (const OperatorName& overload_name) {
45+ return operatorLookupTable_.read ([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> c10::optional<OperatorHandle> {
46+ auto found = operatorLookupTable.find (overload_name);
47+ if (found == operatorLookupTable.end ()) {
48+ return c10::nullopt ;
49+ }
50+ return found->second ;
4651 });
47-
48- if (found == operators_.end ()) {
49- return c10::nullopt ;
50- }
51-
52- return OperatorHandle (found);
5352}
5453
5554OperatorHandle Dispatcher::findOrRegisterSchema_ (FunctionSchema&& schema, OperatorOptions&& options) {
56- const auto found = findSchema (schema.name (). c_str (), schema. overload_name (). c_str ());
55+ const auto found = findSchema (schema.operator_name ());
5756 if (found != c10::nullopt ) {
5857 if (found->schema () != schema) {
5958 std::ostringstream str;
@@ -66,14 +65,22 @@ OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema, Operat
6665 return *found;
6766 }
6867
68+ OperatorName op_name = schema.operator_name ();
6969 operators_.emplace_back (std::move (schema), std::move (options));
70- return OperatorHandle (--operators_.end ());
70+ OperatorHandle handle (--operators_.end ());
71+ operatorLookupTable_.write ([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
72+ operatorLookupTable.emplace (op_name, handle);
73+ });
74+
75+ return handle;
7176}
7277
7378SchemaRegistrationHandleRAII Dispatcher::registerSchema (FunctionSchema schema, OperatorOptions options) {
7479 // we need a lock to avoid concurrent writes
7580 std::lock_guard<std::mutex> lock (mutex_);
7681
82+ OperatorName op_name = schema.operator_name ();
83+
7784 auto op = findOrRegisterSchema_ (std::move (schema), std::move (options));
7885
7986 ++op.operatorIterator_ ->refcount ;
@@ -82,15 +89,17 @@ SchemaRegistrationHandleRAII Dispatcher::registerSchema(FunctionSchema schema, O
8289 listeners_->callOnOperatorRegistered (op);
8390 }
8491
85- return SchemaRegistrationHandleRAII {op, RegistrationHandleRAII ([this , op] {
86- deregisterSchema_ (op);
92+ return SchemaRegistrationHandleRAII {op, RegistrationHandleRAII ([this , op, op_name ] {
93+ deregisterSchema_ (op, op_name );
8794 })};
8895}
8996
90- void Dispatcher::deregisterSchema_ (const OperatorHandle& op) {
97+ void Dispatcher::deregisterSchema_ (const OperatorHandle& op, const OperatorName& op_name ) {
9198 // we need a lock to avoid concurrent writes
9299 std::lock_guard<std::mutex> lock (mutex_);
93100
101+ TORCH_INTERNAL_ASSERT (op.schema ().operator_name () == op_name);
102+
94103 // reduce refcount and actually deregister if no references left
95104 TORCH_INTERNAL_ASSERT (op.operatorIterator_ ->refcount > 0 );
96105 --op.operatorIterator_ ->refcount ;
@@ -101,6 +110,9 @@ void Dispatcher::deregisterSchema_(const OperatorHandle& op) {
101110 listeners_->callOnOperatorDeregistered (op);
102111
103112 operators_.erase (op.operatorIterator_ );
113+ operatorLookupTable_.write ([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
114+ operatorLookupTable.erase (op_name);
115+ });
104116 }
105117}
106118
0 commit comments