We have seen multiple users hit the problem of different DDP instances having different numbers of input batches. [e.g.] As a result, DDP instances processing more input batches will hang, as the peer with the fewest input batches will not join those additional allreduce operations.
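For concreteness, a minimal repro sketch of how the hang arises, assuming two gloo ranks where rank 0 has one more batch than rank 1 (the script and its names are illustrative, not taken from the linked reports):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def run_worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(10, 10))
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # Uneven inputs: rank 0 sees 3 batches, rank 1 only 2.
    num_batches = 3 if rank == 0 else 2
    for _ in range(num_batches):
        opt.zero_grad()
        # Rank 0 blocks (or errors) in the gradient allreduce of its third
        # backward, because rank 1 never joins that collective.
        model(torch.randn(4, 10)).sum().backward()
        opt.step()

if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2)
```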
Applications can address this with something like:
```python
for batch in get_batch():
    # 1 if this rank has run out of input, 0 otherwise.
    x = torch.tensor(int(not has_next()))
    # Launch the check asynchronously so it overlaps with compute.
    op = all_reduce(x, async_op=True)
    ddp(batch).sum().backward()
    opt.step()
    op.wait()
    # After the SUM allreduce, x > 0 means some rank has no more data.
    if x.item() > 0:
        break
```

As the allreduce is async and can overlap with the forward + backward + optimizer computation, the extra overhead should be fine. The question is whether we should implement this as a helper API in DDP, or whether we can leave it out, since the application-side solution is simple enough.
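If we did decide to offer a helper, one possible shape (purely a sketch; the class name, method names, and placement are assumptions, not an existing or proposed DDP API) would be to package the flag-plus-async-allreduce pattern above:

```python
import torch
import torch.distributed as dist

class UnevenInputGuard:
    """Wraps the 'is anyone out of data?' check so the allreduce still
    overlaps with the forward + backward + optimizer work."""

    def start(self, has_next: bool) -> None:
        # 1 if this rank has exhausted its input, 0 otherwise.
        self._flag = torch.tensor(int(not has_next))
        self._work = dist.all_reduce(self._flag, async_op=True)

    def should_stop(self) -> bool:
        # Called after opt.step(); by then the allreduce has had the whole
        # iteration to complete in the background.
        self._work.wait()
        return self._flag.item() > 0
```

The training loop would call guard.start(has_next()) right after fetching a batch and break when guard.should_stop() returns True, which is essentially the application-side loop above with the bookkeeping hidden.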
@pritamdamania87 mentioned that this solution might not be sufficient for applications that cannot support a has_next() API. In that case, we would need to support something like the following:
```python
for data in iterator:
    loss = ddp_model(data)
    loss.backward()
    optim.step()

# This writes EOF to the store and aborts the allreduce in the other DDP instances.
ddp_model.mark_end_of_data()
```

This approach is more versatile, but it needs to expose the ProcessGroup's internal Store to DDP and also requires implementing abort for all ProcessGroup backends.
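Roughly, the second option could look like the sketch below, assuming DDP is handed the ProcessGroup's Store and each backend grows an abort() method; the key name, helper names, and ProcessGroup.abort() itself are all hypothetical pieces that do not exist today:

```python
import torch.distributed as dist

_EOF_KEY = "ddp_end_of_data"  # placeholder key name

def mark_end_of_data(store: dist.Store) -> None:
    # Record in the shared store that this rank has exhausted its input.
    # Store.add() creates the key if needed and atomically increments it.
    store.add(_EOF_KEY, 1)

def maybe_abort(store: dist.Store, process_group) -> None:
    # What the other DDP instances would do (e.g. from the reducer) before
    # waiting on a gradient allreduce: if any rank has signalled EOF, abort
    # the pending collectives. ProcessGroup.abort() is the per-backend piece
    # that would still need to be implemented.
    if store.add(_EOF_KEY, 0) > 0:
        process_group.abort()
```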
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar