
[RFC] Integrate autograd profiler with torch.distributed.rpc #39675

@rohan-varma

🚀 Feature

The profiler is a useful tool for gaining insight into the operations run inside a model, and it is commonly used to diagnose performance issues and optimize models. This RFC describes how the profiler will be integrated with torch.distributed.rpc so that RPC-based training jobs can benefit from it. In particular, the main motivation is to be able to collect, from the calling node, the profile of operations that run under RPC on other nodes.

Example use case

```python
import torch
import torch.distributed.rpc as rpc

# Assume that RPC is initialized (e.g. via rpc.init_rpc) across worker1 and worker2.
# On worker1:
with torch.autograd.profiler.profile(profile_memory=True) as p:
    rpc.rpc_sync("worker2", torch.add, args=(torch.tensor(1), torch.tensor(1)))
# Should show that torch.add ran on the remote node, in addition to the local profiling output
print(p.key_averages().table())
# Should show torch.add in the trace and indicate that it ran on a remote node
p.export_chrome_trace("/tmp/dist_trace.json")
```

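The printed table will look like the following. (This sample comes from an rpc_async call of a user-defined function, udf_with_torch_ops, from worker1 to worker2, rather than from the torch.add call above.)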

```
----------------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                                                  Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  Node ID
----------------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
rpc_async#udf_with_torch_ops(worker1 -> worker2)      0.00%            0.000us          0                4.549ms          4.549ms          1                1
ones                                                  8.31%            29.587us         20.72%           73.806us         36.903us         2                2
empty                                                 16.42%           58.487us         16.42%           58.487us         9.748us          6                2
fill_                                                 5.77%            20.557us         6.78%            24.136us         12.068us         2                2
is_complex                                            1.00%            3.579us          1.00%            3.579us          1.789us          2                2
add                                                   18.33%           65.315us         22.42%           79.883us         79.883us         1                2
mul                                                   13.40%           47.737us         15.98%           56.943us         56.943us         1                2
relu                                                  7.34%            26.154us         15.34%           54.637us         54.637us         1                2
threshold                                             5.77%            20.561us         8.00%            28.483us         28.483us         1                2
sigmoid                                               11.06%           39.392us         25.54%           90.965us         90.965us         1                2
sigmoid_out                                           10.46%           37.260us         12.59%           44.865us         44.865us         1                2
resize_                                               2.13%            7.605us          2.13%            7.605us          7.605us          1                2
----------------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
```
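Each row in this table aggregates events per operator and per node; for example, `empty` ran 6 times on node 2. The following is a minimal sketch of that grouping under stated assumptions: `Event`, `Row`, and `average_by_name_and_node` are hypothetical simplified types (name, node ID, CPU total only), and this is not the profiler's actual `key_averages` implementation.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Event:
    # Hypothetical, simplified profiled event: operator name, node it ran on, CPU time.
    name: str
    node_id: int
    cpu_total_us: float


@dataclass
class Row:
    # One aggregated table row, keyed by (name, node_id).
    name: str
    node_id: int
    calls: int = 0
    cpu_total_us: float = 0.0

    @property
    def cpu_avg_us(self) -> float:
        return self.cpu_total_us / self.calls


def average_by_name_and_node(events: List[Event]) -> List[Row]:
    """Group events by (name, node_id); rows keep first-seen order."""
    rows: Dict[Tuple[str, int], Row] = {}
    for e in events:
        key = (e.name, e.node_id)
        if key not in rows:
            rows[key] = Row(e.name, e.node_id)
        rows[key].calls += 1
        rows[key].cpu_total_us += e.cpu_total_us
    return list(rows.values())


# The same op on different nodes stays in separate rows.
events = [Event("ones", 2, 30.0), Event("ones", 2, 44.0),
          Event("add", 2, 80.0), Event("add", 1, 10.0)]
for row in average_by_name_and_node(events):
    print(row.name, row.node_id, row.calls, row.cpu_total_us, row.cpu_avg_us)
# ones 2 2 74.0 37.0
# add 2 1 80.0 80.0
# add 1 1 10.0 10.0
```

Keying on the node as well as the name is what keeps a locally run `add` and a remotely run `add` from being merged into one row.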

The trace would look like:
[Figure: Chrome trace of the profiled RPC (dist_trace_correct_offset)]
(Note that the events on node 2 are offset by an approximate time delta, whose computation is described in step 6 below.)
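As a minimal sketch of the offset estimated in step 6 below, the helper here first makes each remote timestamp relative to the remote `__start_profile` event and then shifts it by the difference between the remote and local `__start_profile` timestamps. The tuple-based `Event` record and the function name `offset_remote_events` are hypothetical; like the approach in step 6, the sketch does not correct for clock skew.

```python
from typing import List, Tuple

# Hypothetical event record: (name, start_us, end_us) in the recording node's clock.
Event = Tuple[str, float, float]


def offset_remote_events(remote_events: List[Event],
                         local_start_us: float,
                         remote_start_us: float) -> List[Event]:
    """Place remote events on the caller's trace timeline.

    Times are made relative to the remote __start_profile event, then shifted
    by delta = remote_start_us - local_start_us, the estimated gap between
    enabling the profiler locally and remotely. The results are relative to
    the local __start_profile event. No clock-skew correction is applied.
    """
    delta = remote_start_us - local_start_us
    shifted = []
    for name, start, end in remote_events:
        shifted.append((name,
                        (start - remote_start_us) + delta,
                        (end - remote_start_us) + delta))
    return shifted


# Local profiler started at t=1000us; the remote one at t=1250us.
print(offset_remote_events([("add", 1300.0, 1380.0)], 1000.0, 1250.0))
# [('add', 300.0, 380.0)]
```

If the two clocks disagree, the error in `delta` carries directly into every shifted remote event, which is why the estimate is only approximate.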

Design and Implementation

Similarly to how autograd information is propagated in RPC, we will wrap internal RPC messages with profiling information when RPCs are invoked within the profiler context. This wrapping and unwrapping of messages is transparent to the user, who will not have to change their code or their use of the profiler at all. The flow is as follows:

  1. RPC APIs detect whether the profiler is enabled on the current thread (through torch::autograd::profiler::profilerEnabled()).
  2. If the profiler is enabled, we profile the current RPC request using the profiler's record_function API. We also attach callbacks to the future corresponding to the RPC, so that the profiler's callbacks run when the RPC completes (for example, to close the record_function scope).
  3. After building the internal message to send over the wire via RPC, we wrap it in another message that carries the profiling metadata. For example, this metadata includes the ProfilerConfig that the profiler on the calling node was invoked with. ProfilerConfig holds settings such as whether the profiler is enabled, whether to profile memory, and whether to record CUDA events.
  4. On the receiving end, this message is deserialized. When parsing a message of this type, we enable the profiler, with the caller's ProfilerConfig, on the thread that handles the nested request. We then add callbacks that aggregate the torch::autograd::profiler::Event objects for the operations that were profiled. After building the response message, we wrap it in a message that carries the profiling response, which contains metadata about the profiling run as well as the profiled events. We send this result back to the caller over the wire.
  5. When the caller receives the response, it unwraps it and reads the events that were profiled on the remote node. We then add these events to a separate list on the [ProfilerThreadLocalState](https://github.com/pytorch/pytorch/blob/a5e023f28ab3b0d7f39d6ab9487e409a6db94eb7/torch/csrc/autograd/profiler.cpp).
  6. When the profiler is disabled on the calling node (generally by exiting the context manager), we run a consolidation and parsing process that spans C++ and Python to aggregate all of the profiled events. We parse remote events and display the ID of the worker they ran on. When computing CPU start and end intervals for remote events, we use the __start_profile event recorded on the remote node as the reference point. For traces, we also offset the remote events by an estimated delta, currently computed as the difference between the timestamps of the __start_profile events on the local and remote nodes. This delta approximates the time between when the profiler was enabled locally and when it was enabled remotely (note that it does not account for possible clock skew between nodes).
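Steps 1 to 5 above can be simulated end to end in pure Python. This is a minimal sketch, not the actual C++ implementation: `ProfilerConfig`'s fields, `Event`, `ProfilingRequest`, `ProfilingResponse`, `ProfilerState`, `handle_profiling_request`, and this `rpc_sync` are hypothetical stand-ins (the last one is not `torch.distributed.rpc.rpc_sync`), the network hop is replaced by a direct function call, and timestamps come from `time.perf_counter_ns`.

```python
import time
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple


def now_us() -> float:
    return time.perf_counter_ns() / 1000.0


@dataclass
class ProfilerConfig:
    # Hypothetical mirror of the C++ ProfilerConfig; field names are assumptions.
    enabled: bool = True
    profile_memory: bool = False
    record_cuda: bool = False


@dataclass
class Event:
    name: str
    node_id: int
    start_us: float
    end_us: float


@dataclass
class ProfilingRequest:
    # Step 3: the original message wrapped together with the caller's config.
    config: ProfilerConfig
    wrapped_message: Tuple[Callable[..., Any], tuple]


@dataclass
class ProfilingResponse:
    # Step 4: the original result wrapped together with the remote events.
    events: List[Event]
    wrapped_result: Any


@dataclass
class ProfilerState:
    # Hypothetical stand-in for ProfilerThreadLocalState on the caller.
    config: ProfilerConfig
    local_events: List[Event] = field(default_factory=list)
    remote_events: List[Event] = field(default_factory=list)


def handle_profiling_request(req: ProfilingRequest, node_id: int) -> ProfilingResponse:
    # Step 4 (callee): a real callee would enable the profiler with req.config;
    # here we just record a __start_profile marker and one event for the call.
    t0 = now_us()
    events = [Event("__start_profile", node_id, t0, t0)]
    fn, args = req.wrapped_message
    start = now_us()
    result = fn(*args)
    events.append(Event(fn.__name__, node_id, start, now_us()))
    return ProfilingResponse(events, result)


def rpc_sync(state: Optional[ProfilerState], dst: str, fn: Callable[..., Any],
             args: tuple, src: str = "worker1",
             src_node_id: int = 1, dst_node_id: int = 2) -> Any:
    # Caller side; the network hop is simulated by a direct function call.
    # Step 1: take the profiling path only if the profiler is enabled.
    if state is None or not state.config.enabled:
        return fn(*args)
    start = now_us()  # Step 2: start a local range covering the whole RPC.
    request = ProfilingRequest(state.config, (fn, args))  # Step 3: wrap.
    response = handle_profiling_request(request, dst_node_id)
    state.remote_events.extend(response.events)  # Step 5: keep remote events apart.
    state.local_events.append(
        Event(f"rpc_sync#{fn.__name__}({src} -> {dst})", src_node_id, start, now_us()))
    return response.wrapped_result


def add(a, b):
    return a + b


state = ProfilerState(ProfilerConfig(profile_memory=True))
print(rpc_sync(state, "worker2", add, (1, 1)))  # 2
print([e.name for e in state.remote_events])    # ['__start_profile', 'add']
print([e.name for e in state.local_events])     # ['rpc_sync#add(worker1 -> worker2)']
```

Keeping remote events in their own list, rather than mixing them into the local ones, is what lets step 6 treat them differently, for example by rebasing their times on the remote __start_profile event.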

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528 @jjlilley @osalpekar


Labels: module: rpc, oncall: profiler, triaged
