Skip to content

Commit aa4ea6e

Browse files
eqypytorchmergebot
authored andcommitted
[cuDNN][cuDNN V8 API] Fix incorrect use of emplace in the benchmark cache (#97838)
`emplace` does not overwrite the existing mapped value in a map if it already exists, which can lead to repeated execution of a plan that e.g., tries to allocate an OOM-inducing workspace size and retriggers either a heuristic run (or worse, a benchmark run). CC @ptrblck @ngimel @Fuzzkatt @syed-ahmed Pull Request resolved: #97838 Approved by: https://github.com/ngimel
1 parent 35be579 commit aa4ea6e

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

aten/src/ATen/native/cudnn/Conv_v8.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,9 @@ cudnn_frontend::ExecutionPlan* find(const KeyType& key) {
153153
return &(it->second);
154154
}
155155

156-
void emplace(const KeyType& key, T& results) {
156+
void update(const KeyType& key, T& results) {
157157
std::lock_guard<std::mutex> guard(mutex);
158+
engine_cache.erase(key);
158159
engine_cache.emplace(key, std::move(results));
159160
}
160161

@@ -548,7 +549,7 @@ void try_plans(cudnn_frontend::executionPlans_t& plans, const CacheKey& key, con
548549
for (auto & plan : plans) {
549550
try {
550551
run_conv_plan(handle, x, y, w, plan);
551-
benchmark_cache.emplace(key, plan);
552+
benchmark_cache.update(key, plan);
552553
return;
553554
} catch (cudnn_frontend::cudnnException &e) {} catch (CuDNNError &e) {}
554555
catch (c10::OutOfMemoryError &e) {
@@ -562,7 +563,7 @@ void try_plans_fused(cudnn_frontend::executionPlans_t& plans, const CacheKeyFuse
562563
for (auto & plan : plans) {
563564
try {
564565
run_conv_plan_fused(handle, x, y, w, z, b, plan);
565-
benchmark_cache_fused.emplace(key, plan);
566+
benchmark_cache_fused.update(key, plan);
566567
return;
567568
} catch (cudnn_frontend::cudnnException &e) {} catch (CuDNNError &e) {}
568569
catch (c10::OutOfMemoryError &e) {
@@ -583,7 +584,7 @@ bool try_configs(cudnn_frontend::EngineConfigList& configs, const std::string& o
583584
continue;
584585
}
585586
run_conv_plan(handle, x, y, w, plan);
586-
benchmark_cache.emplace(key, plan);
587+
benchmark_cache.update(key, plan);
587588
return true;
588589
} catch (cudnn_frontend::cudnnException &e) {} catch(CuDNNError &e) {}
589590
catch (c10::OutOfMemoryError &e) {
@@ -604,7 +605,7 @@ bool try_configs_fused(cudnn_frontend::EngineConfigList& configs, const std::str
604605
continue;
605606
}
606607
run_conv_plan_fused(handle, x, y, w, z, b, plan);
607-
benchmark_cache_fused.emplace(key, plan);
608+
benchmark_cache_fused.update(key, plan);
608609
return true;
609610
} catch (cudnn_frontend::cudnnException &e) {} catch(CuDNNError &e) {}
610611
catch (c10::OutOfMemoryError &e) {

0 commit comments

Comments
 (0)