
Support uneven DDP inputs #33148

@mrshenli

Description


We have seen multiple users hit the problem of different DDP instances having different numbers of input batches [e.g.]. As a result, the DDP instances processing more input batches will hang, since the peers with fewer input batches never join those additional allreduce operations.
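
For illustration, here is a minimal sketch of the failure mode, assuming the gloo backend and two CPU processes (run_worker and the port number are made up for this example): the rank with more batches launches allreduces during its extra backward passes that the other rank never joins, so it typically hangs there.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def run_worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(10, 10))
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # Rank 0 gets 3 batches, rank 1 gets 2: rank 0's third backward pass
    # waits on gradient allreduces that rank 1 never joins.
    num_batches = 3 if rank == 0 else 2
    for _ in range(num_batches):
        opt.zero_grad()
        model(torch.randn(20, 10)).sum().backward()
        opt.step()

if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2)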

Applications can address this with something like:

import torch
from torch.distributed import all_reduce

for batch in get_batch():
    # 1 if this rank has no more input batches, 0 otherwise.
    x = torch.tensor(int(not has_next()))
    # Launch the allreduce asynchronously so it overlaps with the
    # forward + backward + optimizer step below.
    op = all_reduce(x, async_op=True)
    ddp(batch).sum().backward()
    opt.step()
    op.wait()
    # The sum is > 0 as soon as any rank has run out of data.
    if x.item() > 0:
        break

As the allreduce is asynchronous and can overlap with the forward + backward + optimizer computation, the extra overhead should be negligible. The question is whether we should implement this as a helper API in DDP, or leave it out since the application-side solution is simple enough.
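
If DDP did ship a helper, it could be as thin as a wrapper around the data iterator that runs the same asynchronous allreduce once per step. The sketch below is purely illustrative; ddp_even_batches is a made-up name, and the flag tensor is kept on CPU, which assumes a gloo-style backend.

import torch
import torch.distributed as dist

def ddp_even_batches(iterator):
    # Yield batches until any rank runs out of data, using one small async
    # allreduce per step so all ranks agree on when to stop.
    it = iter(iterator)
    batch = next(it, None)
    while batch is not None:
        nxt = next(it, None)
        # 1 means "this rank has no next batch"; the summed flag is > 0 as
        # soon as at least one rank is about to run dry.
        flag = torch.tensor([0 if nxt is not None else 1])
        work = dist.all_reduce(flag, async_op=True)
        yield batch  # forward/backward/step overlap with the allreduce
        work.wait()
        if flag.item() > 0:
            break
        batch = nxt

The training loop itself would stay unchanged apart from the wrapper, e.g. for batch in ddp_even_batches(loader): ....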

@pritamdamania87 mentioned that this solution might not be sufficient for applications that cannot support a has_next() API. In that case, we would need to support something like the following:

for data in iterator:
    loss = ddp_model(data)
    loss.backward()
    optim.step()

# This writes EOF to the store and aborts allreduces in other DDP instances.
ddp_model.mark_end_of_data()

This approach is more versatile, but it requires exposing the ProcessGroup's internal Store to DDP and implementing abort for all ProcessGroup backends.
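
As a rough feel for the signaling half only, the fragment below writes the EOF marker through a TCPStore that the application creates itself, separate from the one backing the ProcessGroup; mark_end_of_data, the key name, and the surrounding rank/world_size/iterator/ddp_model/optim variables are all assumed from the training setup. The comments call out why polling alone is not enough and abort support is still needed.

import torch.distributed as dist

# Hypothetical: a small side-channel store shared by all ranks
# (rank 0 acts as the server).
store = dist.TCPStore("127.0.0.1", 29501, world_size, rank == 0)

def mark_end_of_data():
    # Called once a rank's data iterator is exhausted.
    store.add("ddp_eof", 1)

def end_of_data_seen():
    # add(key, 0) reads the counter without changing it.
    return store.add("ddp_eof", 0) > 0

for data in iterator:
    # Checking between iterations is not sufficient on its own: a rank that
    # is already blocked inside an allreduce launched by backward() cannot
    # observe the flag, which is exactly why ProcessGroup abort is needed.
    if end_of_data_seen():
        break
    loss = ddp_model(data)
    loss.backward()
    optim.step()
mark_end_of_data()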

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar

Labels: oncall: distributed, triaged
