🚀 Feature
The profiler is a useful tool for gaining insight into the operations run inside a model, and is commonly used to diagnose performance issues and optimize models. This RFC describes how the profiler will be integrated with torch.distributed.rpc so that RPC-based training jobs can benefit from it. In particular, the main motivation is to remotely collect profiles of operations run under RPC across different nodes.
Example use case
# Assume that RPC has been initialized across worker_1 and worker_2
import torch
import torch.distributed.rpc as rpc

# On worker_1
with torch.autograd.profiler.profile(profile_memory=True) as p:
    rpc.rpc_sync("worker_2", torch.add, args=(torch.tensor(1), torch.tensor(1)))

print(p.key_averages().table())  # should show that torch.add ran on the remote node, in addition to the local profiling output
p.export_chrome_trace("/tmp/dist_trace.json")  # should show torch.add in the trace and indicate that it ran on a remote node
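Note that the sample table below corresponds to a richer remote call than the single torch.add above: an rpc_async of a user-defined function containing several torch ops. A minimal sketch of what such a call might look like, with a hypothetical udf_with_torch_ops body inferred from the operator rows in the table:

```python
import torch
import torch.distributed.rpc as rpc

# Hypothetical UDF; its ops roughly map to the per-operator rows below
# (ones/empty/fill_, add, mul, relu/threshold, sigmoid/sigmoid_out/resize_).
def udf_with_torch_ops():
    x = torch.ones(2, 2)
    y = torch.add(x, x)
    y = y * 2
    y = torch.relu(y)
    return torch.sigmoid(y)

# On worker_1, profiled the same way as above:
with torch.autograd.profiler.profile() as p:
    fut = rpc.rpc_async("worker_2", udf_with_torch_ops)
    fut.wait()
```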
The printed table will look like:
| Name | Self CPU total % | Self CPU total | CPU total % | CPU total | CPU time avg | Number of Calls | Node ID |
| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| rpc_async#udf_with_torch_ops(worker1 -> worker2) | 0.00% | 0.000us | 0 | 4.549ms | 4.549ms | 1 | 1 |
| ones | 8.31% | 29.587us | 20.72% | 73.806us | 36.903us | 2 | 2 |
| empty | 16.42% | 58.487us | 16.42% | 58.487us | 9.748us | 6 | 2 |
| fill_ | 5.77% | 20.557us | 6.78% | 24.136us | 12.068us | 2 | 2 |
| is_complex | 1.00% | 3.579us | 1.00% | 3.579us | 1.789us | 2 | 2 |
| add | 18.33% | 65.315us | 22.42% | 79.883us | 79.883us | 1 | 2 |
| mul | 13.40% | 47.737us | 15.98% | 56.943us | 56.943us | 1 | 2 |
| relu | 7.34% | 26.154us | 15.34% | 54.637us | 54.637us | 1 | 2 |
| threshold | 5.77% | 20.561us | 8.00% | 28.483us | 28.483us | 1 | 2 |
| sigmoid | 11.06% | 39.392us | 25.54% | 90.965us | 90.965us | 1 | 2 |
| sigmoid_out | 10.46% | 37.260us | 12.59% | 44.865us | 44.865us | 1 | 2 |
| resize_ | 2.13% | 7.605us | 2.13% | 7.605us | 7.605us | 1 | 2 |
The trace would look like:

(Note that the events on Node 2 are offset by an approximate time delta, described below.)
Design and Implementation
Similar to how autograd information is propagated in RPC, we will wrap internal RPC messages with profiling information when RPCs are invoked under the profiler context. This wrapping/unwrapping of messages is done transparently, so users will not have to change their code or how they use the profiler at all. The flow is as follows:
- RPC APIs detect that the profiler is enabled in the current thread (through `torch::autograd::profiler::profilerEnabled()`).
- If the profiler is enabled, we profile the current RPC request using the `record_function` API provided by the profiler. We also attach callbacks to the future corresponding to the RPC, so that the profiler-related callbacks can run at the appropriate time.
- After building the internal message to send over the wire via RPC, we wrap this message into another message that contains the profiling metadata (a toy sketch of this wrapping appears after this list). For example, this metadata includes the `ProfilerConfig` that the profiler on the original node was invoked with; `ProfilerConfig` contains data such as whether the profiler was enabled, whether memory should be profiled, whether CUDA should be recorded, etc.
- On the receiving end, this message is deserialized. When parsing a message of this type, we enable the profiler on the thread that is satisfying the nested request. We then add callbacks that aggregate the `torch::autograd::profiler::Event`s corresponding to the operations that were profiled. After building the response message, we wrap it into a message that contains the profiling response, which carries metadata about the profiling run as well as the profiled events. We send this result back to the caller over the wire.
- When the caller receives the response, it unwraps it and reads the events that were profiled on the remote node. These events are then added to a separate list on the [`ProfilerThreadLocalState`](https://github.com/pytorch/pytorch/blob/a5e023f28ab3b0d7f39d6ab9487e409a6db94eb7/torch/csrc/autograd/profiler.cpp).
- When the profiler is disabled on the local node (generally by exiting the context manager), we run a consolidation and parsing process that spans C++ and Python in order to aggregate all of the profiled events. We parse remote events and display the worker ID that they ran on. When computing start and end CPU intervals, we use the `__start_profile` event on the remote node. For traces, we also offset the remote trace by an estimated delta, currently computed as the difference between the `__start_profile` events on the local and remote nodes; this approximates the time between when the profiler was enabled locally and when it was enabled remotely (note that this does not account for possible clock skew). A small sketch of this offset computation appears after this list.
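To make the wrapping/unwrapping concrete, here is a toy, self-contained Python sketch of the flow above. All names in it (`ProfilingRequest`, `ProfilingResponse`, `handle_request`, `send_with_profiling`) are illustrative stand-ins rather than actual APIs; the real implementation lives in C++ inside the RPC layer, and the "send" here is just a local call.

```python
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class ProfilerConfig:            # stand-in for the real ProfilerConfig
    enabled: bool = True
    profile_memory: bool = False
    use_cuda: bool = False

@dataclass
class ProfilingRequest:          # outgoing RPC message wrapped with profiling metadata
    payload: Any
    config: ProfilerConfig

@dataclass
class ProfilingResponse:         # response wrapped with the remotely profiled events
    payload: Any
    events: List[str] = field(default_factory=list)

def handle_request(request: ProfilingRequest) -> ProfilingResponse:
    # Callee side: enable the profiler per request.config, run the op, collect events.
    events = [f"remote op profiled with config={request.config}"]  # placeholder events
    return ProfilingResponse(payload=f"result of {request.payload}", events=events)

def send_with_profiling(payload: Any, config: ProfilerConfig, remote_events: List[str]) -> Any:
    # Caller side: wrap the message, "send" it, unwrap the response, and stash the
    # remote events so they can be merged into the local profiler output later.
    response = handle_request(ProfilingRequest(payload, config))
    remote_events.extend(response.events)
    return response.payload

remote_events: List[str] = []
result = send_with_profiling("torch.add(1, 1)", ProfilerConfig(profile_memory=True), remote_events)
print(result, remote_events)
```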
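And a minimal sketch of the trace-offset normalization from the last bullet, assuming hypothetical event dictionaries with start_us/end_us fields; only the idea of shifting remote timestamps by the difference between the two `__start_profile` events comes from the design above, everything else is made up for illustration.

```python
# Toy illustration of normalizing remote event times into the local timeline.
# local_start_us / remote_start_us stand in for the timestamps of the
# __start_profile events recorded on each node.
def shift_remote_events(remote_events, local_start_us, remote_start_us):
    # Estimated delta between when profiling started locally vs. remotely.
    # As noted above, this ignores clock skew between the two hosts.
    delta_us = local_start_us - remote_start_us
    return [
        {**evt, "start_us": evt["start_us"] + delta_us, "end_us": evt["end_us"] + delta_us}
        for evt in remote_events
    ]

remote = [{"name": "add", "start_us": 10.0, "end_us": 90.0}]
print(shift_remote_events(remote, local_start_us=1000.0, remote_start_us=5.0))
```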
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528 @jjlilley @osalpekar