Skip to content

Commit 42b2f56

Browse files
Yuri Putivskyfacebook-github-bot
authored andcommitted
Fixing race condition at Module::forward method (#21398)
Summary: Pull Request resolved: #21398 Module::forward method calls find_method() function potentially in multiple threads. Internally it calls find_offset() method and reads dict_ object. If the correspondent name is not in a dictionary thread call insert() method and modifies dict_ object. At the same time when first thread modifies dict_ object another thread can enter forward()->find_method()->find_offset() path and access dict_ object for reading while it have been modified -> crash. Moved mutex protection up to protect both calls find_offset() and insert(). Consider to use C++ 17 shared_mutex locking object instead of recursive_mutex object. Reviewed By: bddppq Differential Revision: D15638942 fbshipit-source-id: ca6a453448302a0b3666c87724755fa4e9ce242f
1 parent 95eb933 commit 42b2f56

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torch/csrc/jit/script/module.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,12 @@ struct TORCH_API Module {
304304
return offset ? modules_[*offset] : nullptr;
305305
}
306306
Method* find_method(const std::string& name) const {
307+
// find_offset() method reads "dict_" object.
308+
// Lock because another thread can modify "dict_" object at the same time
309+
// calling insert() method.
310+
// Ideally recursive_mutex should be replaced with shared_mutex (C++ 17)
311+
// for the performance reasons.
312+
std::unique_lock<std::recursive_mutex> keeper(create_method_guard_);
307313
auto offset = find_offset(name, EntityType::METHOD);
308314
if (offset) {
309315
return methods_[*offset].get();
@@ -314,7 +320,6 @@ struct TORCH_API Module {
314320
// but we have to update the internal Method cache.
315321
// This can be removed when class_compilation_unit() is the source of
316322
// truth for methods.
317-
std::lock_guard<std::recursive_mutex> guard(create_method_guard_);
318323
Module* mutable_this = const_cast<Module*>(this);
319324
std::unique_ptr<Method> m(new Method(mutable_this, fn));
320325
return mutable_this

0 commit comments

Comments
 (0)