Commit 1168a1a
[RPC profiling] Extend RPC profiling to support async function execution over RPC.
Pull Request resolved: #44664
Closes #39971.

This PR adds support for functions decorated with `@rpc.functions.async_execution` to be profiled over RPC, as builtins, JIT functions, and blocking Python UDFs currently can be. The goal is complete RPC profiling coverage across the various types of functions users can run.

To enable this, the PR below this one in the stack makes it safe to call `disableProfiler()` from another thread. We use that functionality to defer disabling the profiler on the server until the future corresponding to the RPC request completes (rather than disabling it as soon as the blocking `processRpc` call returns, as was done previously). Since by the time that future completes the async function has been kicked off and the future it returned has finished, we are able to capture any RPCs the function called as well as the actual work done on the other node.

For example, if the following async function is run on a server over RPC:

```python
def slow_add(x, y):
    time.sleep(1)
    return torch.add(x, y)


@rpc.functions.async_execution
def slow_async_add(to, x, y):
    return rpc.rpc_async(to, slow_add, args=(x, y))
```

we expect to see the original RPC profiled, the nested RPC profiled, and the actual `torch.add()` work, all recorded with the correct node id. Here is an example profiling output:

| Name | Self CPU total % | Self CPU total | CPU total % | CPU total | CPU time avg | Number of Calls | Node ID |
| --- | --- | --- | --- | --- | --- | --- | --- |
| rpc_async#slow_async_add(worker1 -> worker2) | 0.00% | 0.000us | 0 | 1.012s | 1.012s | 1 | 1 |
| aten::empty | 7.02% | 11.519us | 7.02% | 11.519us | 11.519us | 1 | 1 |
| rpc_async#slow_async_add(worker1 -> worker2)#remote_op: rpc_async#slow_add(worker2 -> worker3) | 0.00% | 0.000us | 0 | 1.006s | 1.006s | 1 | 2 |
| rpc_async#slow_async_add(worker1 -> worker2)#remote_op: aten::empty | 7.21% | 11.843us | 7.21% | 11.843us | 11.843us | 1 | 2 |
| rpc_async#slow_async_add(worker1 -> worker2)#remote_op: rpc_async#slow_add(worker2 -> worker3)#remote_op: aten::add | 71.94% | 118.107us | 85.77% | 140.802us | 140.802us | 1 | 3 |
| rpc_async#slow_async_add(worker1 -> worker2)#remote_op: rpc_async#slow_add(worker2 -> worker3)#remote_op: aten::empty | 13.82% | 22.695us | 13.82% | 22.695us | 22.695us | 1 | 3 |

Self CPU time total: 164.164us

This PR also moves much of the profiling logic into `rpc/utils.cpp` to declutter the `request_callback` code.

ghstack-source-id: 112032360
Differential Revision: [D23638387](https://our.internmc.facebook.com/intern/diff/D23638387/)
1 parent b5d5689 commit 1168a1a
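For illustration, here is a minimal sketch (not part of this commit) of how a run like the one above could be driven from the client side. It assumes `rpc.init_rpc` has already been called on workers named `worker1`, `worker2`, and `worker3`, and that `slow_add`/`slow_async_add` from the commit message are defined at module scope on every worker; enabling the autograd profiler on the caller is what propagates profiling to the remote workers.

```python
import torch
import torch.distributed.rpc as rpc
from torch.autograd.profiler import profile


# Runs on worker1. Worker names and setup are illustrative assumptions;
# slow_async_add is assumed importable on all workers (see the snippet above).
def profile_slow_async_add():
    x, y = torch.ones(2), torch.ones(2)
    with profile() as prof:
        # worker2 defers disabling its profiler until the future returned by
        # slow_async_add completes, so the nested RPC to worker3 and the
        # remote aten::add are captured as well.
        fut = rpc.rpc_async("worker2", slow_async_add, args=("worker3", x, y))
        fut.wait()
    # Remote events appear prefixed with the originating RPC, e.g.
    # "rpc_async#slow_async_add(worker1 -> worker2)#remote_op: aten::add".
    print(prof.key_averages().table(sort_by="cpu_time_total"))
```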

7 files changed: +254, -117 lines

torch/csrc/autograd/profiler.h (23 additions, 3 deletions)

```diff
@@ -370,6 +370,14 @@ struct TORCH_API RecordProfile {
   void processEvents(const std::vector<Event*>& events);
 };
 
+// A struct to control settings of disableProfiler options, to be used in
+// conjunction with TlSProfilerGuard.
+
+struct TORCH_API ProfilerDisableOptions {
+  bool cleanupTLSState = true;
+  bool consolidate = true;
+};
+
 // A guard that enables the profiler, taking in an optional callback to process
 // the results
 // Usage:
@@ -379,23 +387,35 @@ struct TORCH_API RecordProfile {
 // });
 // Code to profile
 // }
+
 struct TORCH_API TLSProfilerGuard {
   explicit TLSProfilerGuard(
       const ProfilerConfig& cfg,
       c10::optional<std::function<void(const thread_event_lists&)>>
-          resultCallback = c10::nullopt)
-      : cb_(std::move(resultCallback)) {
+          resultCallback = c10::nullopt,
+      c10::optional<ProfilerDisableOptions> profilerDisableOptions =
+          c10::nullopt)
+      : cb_(std::move(resultCallback)),
+        profilerDisableOptions_(std::move(profilerDisableOptions)) {
     enableProfiler(cfg);
   }
   ~TLSProfilerGuard() {
-    thread_event_lists event_lists = disableProfiler();
+    thread_event_lists event_lists;
+    if (profilerDisableOptions_) {
+      event_lists = disableProfiler(
+          profilerDisableOptions_->cleanupTLSState,
+          profilerDisableOptions_->consolidate);
+    } else {
+      event_lists = disableProfiler();
+    }
     if (cb_) {
       (*cb_)(event_lists);
     }
   }
 
  private:
   c10::optional<std::function<void(const thread_event_lists&)>> cb_;
+  c10::optional<ProfilerDisableOptions> profilerDisableOptions_;
 };
 
 } // namespace profiler
```

torch/csrc/distributed/rpc/request_callback_no_python.cpp (45 additions, 103 deletions)

```diff
@@ -486,93 +486,12 @@ void RequestCallbackNoPython::processRpc(
       const auto profilingKeyId = rpcWithProfilingReq.getProfilingId();
       auto wrappedRpcResponseFuture = std::make_shared<FutureMessage>();
       // Enable the profiler with the config from the sender.
-      std::vector<torch::autograd::profiler::Event> profiledEvents;
+      torch::autograd::profiler::ProfilerDisableOptions requestThreadOptions;
+      requestThreadOptions.cleanupTLSState = true;
+      requestThreadOptions.consolidate = false;
       {
         torch::autograd::profiler::TLSProfilerGuard g(
-            profilingConfig,
-            [&profiledEvents, profilingConfig](
-                const std::vector<std::vector<
-                    torch::autograd::profiler::Event>>& event_lists) {
-              // Gather all events into a vector
-              for (auto& l : event_lists) {
-                for (auto& e : l) {
-                  profiledEvents.push_back(e);
-                }
-              }
-              // find __start_profile event and __cuda_start_event.
-              bool cuda_profiling_enabled = profilingConfig.state ==
-                  torch::autograd::profiler::ProfilerState::CUDA;
-              bool found_cpu_start = false;
-              const torch::autograd::profiler::Event* profilerStart = nullptr;
-              // Each device has its own cudaProfilerStart, so we must take
-              // care to use the correct one depending on the device the
-              // operation ran on.
-              std::unordered_map<int, const torch::autograd::profiler::Event*>
-                  cudaProfilerStarts;
-              for (auto& e : profiledEvents) {
-                if (!found_cpu_start &&
-                    0 == strcmp(e.name(), "__start_profile")) {
-                  profilerStart = &e;
-                  found_cpu_start = true;
-                }
-                if (cuda_profiling_enabled &&
-                    0 == strcmp(e.name(), "__cuda_start_event")) {
-                  e.setCudaUs(e.cpu_us());
-                  auto device = e.device();
-                  TORCH_CHECK(
-                      device != -1,
-                      "CUDA profiling was enabled but could not find CUDA device.");
-                  TORCH_CHECK(
-                      cudaProfilerStarts.find(device) ==
-                          cudaProfilerStarts.end(),
-                      c10::str(
-                          "Duplicate __cuda_start_event found for ", device));
-                  cudaProfilerStarts[device] = &e;
-                }
-                // TODO: determine no. of CUDA devices and break here if we have
-                // a cudaProfilerStart for all of them, in the case of cuda
-                // profiling.
-                if (found_cpu_start && !cuda_profiling_enabled) {
-                  break;
-                }
-              }
-              // We should always find __start_profile.
-              TORCH_CHECK(
-                  profilerStart != nullptr,
-                  "Expected to find __start_profile event.");
-              // Should have >= 1 CUDA start event.
-              // TODO: we can enhance this assert by ensuring we have found a
-              // start for every available CUDA device.
-              TORCH_CHECK(
-                  !cuda_profiling_enabled || cudaProfilerStarts.size() > 0,
-                  "Profiler was enabled with CUDA recording, but did not find __cuda_start_event.");
-
-              if (cuda_profiling_enabled) {
-                // Compute and set global time for when this CUDA kernel was
-                // launched/ended, since deserialized event will not have a
-                // corresponding CUDA event.
-                for (auto& e : profiledEvents) {
-                  if (e.has_cuda()) {
-                    auto cuda_device = e.device();
-                    TORCH_CHECK(
-                        cuda_device != -1,
-                        "CUDA profiling was enabled but could not find CUDA device.");
-                    auto it = cudaProfilerStarts.find(cuda_device);
-                    TORCH_CHECK(
-                        it != cudaProfilerStarts.end(),
-                        c10::str(
-                            "Failed to find __cuda_start_event for device ",
-                            cuda_device));
-                    auto cudaProfilerStartEvent = it->second;
-                    double cuda_elapsed_us =
-                        cudaProfilerStartEvent->cuda_elapsed_us(e);
-                    int64_t cuda_us =
-                        cuda_elapsed_us + cudaProfilerStartEvent->cpu_us();
-                    e.setCudaUs(cuda_us);
-                  }
-                }
-              }
-            });
+            profilingConfig, c10::nullopt, requestThreadOptions);
         TORCH_INTERNAL_ASSERT(
             torch::autograd::profiler::profilerEnabled(),
             "Expected profiler to be enabled!");
@@ -583,25 +502,48 @@ void RequestCallbackNoPython::processRpc(
             wrappedMsgType,
             messageId,
             wrappedRpcResponseFuture);
-      }
-      wrappedRpcResponseFuture->addCallback([wrappedRpcResponseFuture,
+
+        auto tid = std::this_thread::get_id();
+        wrappedRpcResponseFuture->addCallback(
+            at::wrapPropagateTLSState<void>([wrappedRpcResponseFuture,
                                              responseFuture,
-                                             profiledEvents =
-                                                 std::move(profiledEvents),
-                                             profilingKeyId] {
-        if (wrappedRpcResponseFuture->hasError()) {
-          // Propagate error
-          responseFuture->setError(wrappedRpcResponseFuture->error()->what());
-        } else {
-          auto rpcWithProfilingResp = std::make_unique<RpcWithProfilingResp>(
-              MessageType::RUN_WITH_PROFILING_RESP,
-              std::move(*wrappedRpcResponseFuture).moveValue(),
-              profiledEvents,
-              profilingKeyId);
-          responseFuture->markCompleted(
-              std::move(*rpcWithProfilingResp).toMessage());
-        }
-      });
+                                             profilingKeyId,
+                                             profilingConfig,
+                                             tid] {
+              std::vector<torch::autograd::profiler::Event> profiledEvents;
+              // Defer consolidation of profiler events until async work has
+              // completed (such as async UDF)
+
+              TORCH_INTERNAL_ASSERT(
+                  torch::autograd::profiler::profilerEnabled(),
+                  "Expected profiler to be enabled!");
+
+              // Only clean up TLS states of profiler if we are disabling on
+              // the main thread.
+              bool shouldCleanUpTLSStates = (std::this_thread::get_id() == tid);
+              auto event_lists = torch::autograd::profiler::disableProfiler(
+                  shouldCleanUpTLSStates, true);
+              if (wrappedRpcResponseFuture->hasError()) {
+                // Propagate error
+                // No need to propagate remote events in the case of an error.
+                responseFuture->setError(
+                    wrappedRpcResponseFuture->error()->what());
+              } else {
+                populateRemoteProfiledEvents(
+                    profiledEvents, profilingConfig, event_lists);
+                auto rpcWithProfilingResp =
+                    std::make_unique<RpcWithProfilingResp>(
+                        MessageType::RUN_WITH_PROFILING_RESP,
+                        std::move(*wrappedRpcResponseFuture).moveValue(),
+                        profiledEvents,
+                        profilingKeyId);
+                responseFuture->markCompleted(
+                    std::move(*rpcWithProfilingResp).toMessage());
+              }
+            }));
+        // Exiting the scope will disable the profiler on this thread with the
+        // options specified above.
+      }
       return;
     }
     default: {
```

torch/csrc/distributed/rpc/utils.cpp (79 additions, 0 deletions)

```diff
@@ -713,6 +713,85 @@ std::vector<at::IValue> readWrappedPayload(
   payload.resize(payload.size() - additionalPayloadSize);
   return tupleElements;
 }
+
+void populateRemoteProfiledEvents(
+    std::vector<torch::autograd::profiler::Event>& profiledEvents,
+    const torch::autograd::profiler::ProfilerConfig& profilingConfig,
+    const std::vector<std::vector<torch::autograd::profiler::Event>>&
+        event_lists) {
+  // Gather all events into a vector
+  for (auto& l : event_lists) {
+    for (auto& e : l) {
+      profiledEvents.push_back(e);
+    }
+  }
+  // find __start_profile event and __cuda_start_event.
+  bool cuda_profiling_enabled =
+      profilingConfig.state == torch::autograd::profiler::ProfilerState::CUDA;
+  bool found_cpu_start = false;
+  const torch::autograd::profiler::Event* profilerStart = nullptr;
+  // Each device has its own cudaProfilerStart, so we must take
+  // care to use the correct one depending on the device the
+  // operation ran on.
+  std::unordered_map<int, const torch::autograd::profiler::Event*>
+      cudaProfilerStarts;
+  for (auto& e : profiledEvents) {
+    if (!found_cpu_start && 0 == strcmp(e.name(), "__start_profile")) {
+      profilerStart = &e;
+      found_cpu_start = true;
+    }
+    if (cuda_profiling_enabled && 0 == strcmp(e.name(), "__cuda_start_event")) {
+      e.setCudaUs(e.cpu_us());
+      auto device = e.device();
+      TORCH_CHECK(
+          device != -1,
+          "CUDA profiling was enabled but could not find CUDA device.");
+      TORCH_CHECK(
+          cudaProfilerStarts.find(device) == cudaProfilerStarts.end(),
+          c10::str("Duplicate __cuda_start_event found for ", device));
+      cudaProfilerStarts[device] = &e;
+    }
+    // TODO: determine no. of CUDA devices and break here if we have
+    // a cudaProfilerStart for all of them, in the case of cuda
+    // profiling.
+    if (found_cpu_start && !cuda_profiling_enabled) {
+      break;
+    }
+  }
+  // We should always find __start_profile.
+  TORCH_CHECK(
+      profilerStart != nullptr, "Expected to find __start_profile event.");
+  // Should have >= 1 CUDA start event.
+  // TODO: we can enhance this assert by ensuring we have found a
+  // start for every available CUDA device.
+  TORCH_CHECK(
+      !cuda_profiling_enabled || cudaProfilerStarts.size() > 0,
+      "Profiler was enabled with CUDA recording, but did not find __cuda_start_event.");
+
+  if (cuda_profiling_enabled) {
+    // Compute and set global time for when this CUDA kernel was
+    // launched/ended, since deserialized event will not have a
+    // corresponding CUDA event.
+    for (auto& e : profiledEvents) {
+      if (e.has_cuda()) {
+        auto cuda_device = e.device();
+        TORCH_CHECK(
+            cuda_device != -1,
+            "CUDA profiling was enabled but could not find CUDA device.");
+        auto it = cudaProfilerStarts.find(cuda_device);
+        TORCH_CHECK(
+            it != cudaProfilerStarts.end(),
+            c10::str(
+                "Failed to find __cuda_start_event for device ", cuda_device));
+        auto cudaProfilerStartEvent = it->second;
+        double cuda_elapsed_us = cudaProfilerStartEvent->cuda_elapsed_us(e);
+        int64_t cuda_us = cuda_elapsed_us + cudaProfilerStartEvent->cpu_us();
+        e.setCudaUs(cuda_us);
+      }
+    }
+  }
+}
+
 } // namespace rpc
 } // namespace distributed
 } // namespace torch
```

torch/csrc/distributed/rpc/utils.h (7 additions, 0 deletions)

```diff
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <c10/core/Device.h>
+#include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
 #include <torch/csrc/jit/serialization/pickle.h>
 #include <torch/csrc/utils/byte_order.h>
@@ -125,6 +126,12 @@ TORCH_API std::vector<at::IValue> readWrappedPayload(
     std::vector<char>& payload,
     const rpc::Message& message);
 
+TORCH_API void populateRemoteProfiledEvents(
+    std::vector<torch::autograd::profiler::Event>& profiledEvents,
+    const torch::autograd::profiler::ProfilerConfig& profilerConfig,
+    const std::vector<std::vector<torch::autograd::profiler::Event>>&
+        event_lists);
+
 } // namespace rpc
 } // namespace distributed
 } // namespace torch
```

torch/testing/_internal/dist_utils.py (20 additions, 0 deletions)

```diff
@@ -18,6 +18,26 @@
 INIT_METHOD_TEMPLATE = "file://{file_name}"
 
 
+
+def single_threaded_process_group_agent(f):
+    """
+    Forces ProcessGroupAgent to use only a single thread in the ThreadPool for
+    sending and processing requests.
+    """
+    @wraps(f)
+    def wrapper(self, *args, **kwargs):
+        backend_type = self.rpc_backend
+        if backend_type == rpc.backend_registry.BackendType["PROCESS_GROUP"]:
+            self.rpc_backend_options = rpc.backend_registry.construct_rpc_backend_options(
+                self.rpc_backend,
+                init_method=self.init_method,
+                num_send_recv_threads=1,
+            )
+        return_value = f(self, *args, **kwargs)
+        return return_value
+    return wrapper
+
+
 def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True,
               faulty_messages=None, messages_to_delay=None):
     """
```

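As an illustration of how the new decorator might be used (this test is hypothetical, not part of the commit), it is stacked outside `dist_init` so that the single-thread backend options are installed before RPC is initialized; the fixture base class, worker names, and test body below are assumptions.

```python
import torch
import torch.distributed.rpc as rpc
from torch.testing._internal.dist_utils import dist_init, single_threaded_process_group_agent
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture


class ExampleProfilingTest(RpcAgentTestFixture):  # hypothetical test class
    @single_threaded_process_group_agent  # runs first: sets num_send_recv_threads=1
    @dist_init                            # then initializes RPC with those options
    def test_profile_async_udf(self):
        # slow_async_add is assumed to be the module-level function from the
        # commit message, available on all workers.
        if self.rank == 0:
            with torch.autograd.profiler.profile() as prof:
                fut = rpc.rpc_async(
                    "worker1",
                    slow_async_add,
                    args=("worker2", torch.ones(2), torch.ones(2)),
                )
                fut.wait()
            print(prof.key_averages().table(sort_by="cpu_time_total"))
```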
torch/testing/_internal/distributed/rpc/process_group_agent_test_fixture.py (13 additions, 6 deletions)

```diff
@@ -13,12 +13,19 @@ def rpc_backend(self):
 
     @property
     def rpc_backend_options(self):
-        return rpc.backend_registry.construct_rpc_backend_options(
-            self.rpc_backend,
-            init_method=self.init_method,
-            # Some tests need additional threads (ex: test_trainer_ps)
-            num_send_recv_threads=8,
-        )
+        try:
+            return self._rpc_backend_options
+        except AttributeError:
+            return rpc.backend_registry.construct_rpc_backend_options(
+                self.rpc_backend,
+                init_method=self.init_method,
+                # Some tests need additional threads (ex: test_trainer_ps)
+                num_send_recv_threads=8,
+            )
+
+    @rpc_backend_options.setter
+    def rpc_backend_options(self, new_rpc_backend_options):
+        self._rpc_backend_options = new_rpc_backend_options
 
     def get_shutdown_error_regex(self):
         error_regexes = [
```

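Usage note (hypothetical, not part of this commit): with the setter in place, a test can also override the agent options directly rather than through the decorator, for example in its own setup method on the same fixture, assuming `import torch.distributed.rpc as rpc`:

```python
def setUp(self):
    super().setUp()
    # Overrides the default returned by the rpc_backend_options property above;
    # construct_rpc_backend_options is the existing helper, and the thread
    # count is an illustrative value.
    self.rpc_backend_options = rpc.backend_registry.construct_rpc_backend_options(
        self.rpc_backend,
        init_method=self.init_method,
        num_send_recv_threads=2,
    )
```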