
[RFC] Distributed optimizer with TorchScript support #46883

@wanchaol

Description

Motivation

PyTorch provides a broad set of optimizers for training, and they have been used repeatedly as part of the Python API. However, users often want to use multithreaded training instead of multiprocess training, as it provides better resource utilization and efficiency in the context of large-scale distributed training (e.g. Distributed Model Parallel) or any RPC-based training application. Users could not do this with the distributed optimizer before, because achieving it requires getting rid of the Python Global Interpreter Lock (GIL) limitation.

New DistributedOptimizer with TorchScript support

To make the Distributed Optimizer work with TorchScript, we will refactor the existing optimizers to expose a functional API, and then have the Distributed Optimizer use that functional API to gain TorchScript support.

Functional Optimizer

We have introduced the functional optimizer concept in torch.optim, which allows the computation and the state management to be separated. This makes it easier to make the optimizers TorchScript compatible, and unlocks the opportunity for the distributed optimizer to use them in order to be GIL-free.
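
As a minimal sketch of what "computation separated from state" means (the name sgd_step and its signature are illustrative, not the actual torch.optim API), the functional form takes the parameters, gradients, optimizer state, and hyperparameters as explicit arguments, so the function itself is stateless and can be compiled by TorchScript:

import torch
from typing import List

# Hypothetical functional SGD kernel: parameters, gradients, momentum
# buffers and hyperparameters are all passed in explicitly, so the
# function holds no state of its own and is TorchScript-compilable.
@torch.jit.script
def sgd_step(params: List[torch.Tensor],
             grads: List[torch.Tensor],
             momentum_buffers: List[torch.Tensor],
             lr: float,
             momentum: float) -> None:
    for i in range(len(params)):
        d_p = grads[i]
        if momentum != 0.0:
            buf = momentum_buffers[i]
            buf.mul_(momentum).add_(d_p)
            d_p = buf
        params[i].add_(d_p, alpha=-lr)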

DistributedOptimizer

In the Distributed Optimizer, we maintain a separate set of functional optimizers that consist of state + computation, where the computation part uses the shared functional API introduced above.
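
As an illustration (the class name and fields here are assumptions; the actual torch.distributed.optim functional classes may differ), such a functional optimizer is a small TorchScript-compatible class that owns the parameter list and optimizer state, and delegates the math to a shared kernel like the sgd_step sketch above:

import torch
from typing import List

# Hypothetical TorchScript class wrapping state + computation. Gradients
# come from distributed autograd rather than param.grad, so step() takes
# them explicitly and forwards everything to the shared functional kernel.
@torch.jit.script
class FunctionalSGD(object):
    def __init__(self, params: List[torch.Tensor], lr: float, momentum: float):
        self.params = params
        self.lr = lr
        self.momentum = momentum
        self.momentum_buffers = [torch.zeros_like(p) for p in params]

    def step(self, gradients: List[torch.Tensor]) -> None:
        sgd_step(self.params, gradients, self.momentum_buffers,
                 self.lr, self.momentum)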

It’s OK for the distributed optimizer to stay in Python. We would like to keep the API we expose to the user, but use the functional optimizer underneath. The trick is that we maintain a map for optim_class like the one below:

{
    torch.optim.Adagrad: torch.distributed.optim.FunctionalAdagrad,
    torch.optim.SGD: torch.distributed.optim.FunctionalSGD,
    ...
}

In the DistributedOptimizer initialization, we simply swap the OSS optimizer for the functional optimizer we exposed, and use it to initialize the _LocalOptimizer (and compile it). A rough sketch of the change:

from collections import defaultdict

import torch.distributed.rpc as rpc


class DistributedOptimizer:
    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        # Swap the user-facing optimizer class for its functional counterpart.
        functional_optimizer_class = optim_table.get(optimizer_class, None)
        if functional_optimizer_class is None:
            raise ValueError(
                "Optimizer " + str(optimizer_class) + " not supported")
        # Emit a warning/log noting the switch from the OSS optimizer to the
        # functional optimizer.

        # Group parameter RRefs by the worker that owns them.
        per_worker_params_rref = defaultdict(list)
        for param in params_rref:
            per_worker_params_rref[param.owner()].append(param)

        # Create one local optimizer per worker, passing the functional
        # optimizer class instead of the original optimizer class.
        remote_optim_futs = []
        for worker, param_rrefs in per_worker_params_rref.items():
            remote_optim_rref_fut = rpc.rpc_async(
                worker,
                _new_local_optimizer,
                args=(functional_optimizer_class, param_rrefs) + args,
                kwargs=kwargs,
            )
            remote_optim_futs.append(remote_optim_rref_fut)

        self.remote_optimizers = _wait_for_all(remote_optim_futs)
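
On the worker side, a minimal sketch of the helpers assumed in the snippet above (_new_local_optimizer and a _LocalOptimizer wrapper; the names follow the snippet but the bodies here are assumptions) would construct the functional optimizer over the locally owned parameters and, in step(), fetch the gradients from the distributed autograd context and run the TorchScript step:

import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc


class _LocalOptimizer:
    # Hypothetical worker-side wrapper around a functional optimizer.
    def __init__(self, functional_optim_class, local_params_rref, *args, **kwargs):
        self._local_params = [rref.local_value() for rref in local_params_rref]
        # The functional optimizer is TorchScript-compatible, so its step()
        # does not need to hold the Python GIL.
        self.optim = functional_optim_class(self._local_params, *args, **kwargs)

    def step(self, autograd_ctx_id):
        # Gradients for this context live in distributed autograd rather than
        # in param.grad, so fetch them explicitly and pass them along.
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
        grads = [all_local_grads[p] for p in self._local_params]
        self.optim.step(grads)


def _new_local_optimizer(functional_optim_class, local_params_rref, *args, **kwargs):
    # Return an RRef so the caller can invoke step() remotely later.
    return rpc.RRef(
        _LocalOptimizer(functional_optim_class, local_params_rref, *args, **kwargs))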

Note we will need to refactor all optimizers in torch.optim to follow the functional API, and then register the functional optimizers one by one to make all of them available in Distributed Optimizer.

Usage

The new distributed optimizer has exactly the same interface (APIs) as before: we expose the same Python API, but do all the heavy lifting under the hood. Upon DistributedOptimizer construction, it tries to find the corresponding functional optimizer, constructs the local optimizers, and automatically turns the optimizer on each worker into TorchScript to make it GIL-free. Example usage is exactly the same as with the Python API:

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

with dist_autograd.context() as context_id:
  # Forward pass.
  rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
  rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
  loss = rref1.to_here() + rref2.to_here()

  # Backward pass.
  dist_autograd.backward(context_id, [loss.sum()])

  # Optimizer, pass in optim.Adagrad, DistributedOptimizer will
  # automatically convert/compile it to TorchScript (GIL-free)
  dist_optim = DistributedOptimizer(
     optim.Adagrad,
     [rref1, rref2],
     lr=0.05,
  )
  dist_optim.step(context_id)

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @jjlilley @osalpekar @jiayisuse @gmagogsfm @xush6528 @agolynski
