9 changes: 0 additions & 9 deletions tests/experimental/test_async_grpo_trainer.py
@@ -81,15 +81,6 @@ def update_model_version(self, version):
def stop(self):
pass

def pause(self):
pass

def resume(self):
pass

def send_weights(self, iterator):
pass


class TestAsyncGRPOTrainer(TrlTestCase):
def test_init_minimal(self):
91 changes: 66 additions & 25 deletions trl/experimental/async_grpo/async_grpo_trainer.py
@@ -18,7 +18,7 @@
import textwrap
import time
from collections import defaultdict
from collections.abc import Callable, Iterator
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Protocol

@@ -35,6 +35,7 @@

from .async_grpo_config import AsyncGRPOConfig
from .async_rollout_worker import AsyncRolloutWorker
from .weight_transfer import WeightTransferClient


logger = get_logger(__name__)
@@ -57,9 +58,6 @@ class RolloutWorkerProtocol(Protocol):

def start(self) -> None: ...
def stop(self) -> None: ...
def pause(self) -> None: ...
def resume(self) -> None: ...
def send_weights(self, iterator: Iterator[tuple[str, torch.Tensor]]) -> None: ...
def update_model_version(self, version: int) -> None: ...
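
# Because `RolloutWorkerProtocol` is a structural `typing.Protocol`, anything with the three
# remaining methods conforms; pause/resume/send_weights now belong to the weight transfer
# client instead. A minimal conforming stub for illustration (note the trainer still reads
# `rollout_buffer` off the real worker, so injected workers own their queue):
class _ExampleStubWorker:
    def start(self) -> None: ...
    def stop(self) -> None: ...
    def update_model_version(self, version: int) -> None: ...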


@@ -77,6 +75,37 @@ def on_step_end(self, _args, state, _control, **_kwargs):
self.fn()


class _InitialWeightSyncCallback(TrainerCallback):
"""Idempotent: NCCL group setup + cold weight sync to vLLM on train begin."""

def __init__(self, trainer: "AsyncGRPOTrainer"):
self._trainer = trainer
self._fired = False

def on_train_begin(self, _args, _state, _control, **_kwargs):
if self._fired:
return
self._fired = True
if self._trainer.accelerator.is_main_process and self._trainer.weight_transfer is not None:
self._trainer.weight_transfer.init_weight_transfer()
self._trainer._sync_weight()


class _StartRolloutWorkerCallback(TrainerCallback):
"""Idempotent: starts the rollout worker. Must be registered AFTER `_InitialWeightSyncCallback`."""

def __init__(self, trainer: "AsyncGRPOTrainer"):
self._trainer = trainer
self._fired = False

def on_train_begin(self, _args, _state, _control, **_kwargs):
if self._fired:
return
self._fired = True
if self._trainer.accelerator.is_main_process and self._trainer.rollout_worker is not None:
self._trainer.rollout_worker.start()
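
# Ordering note (illustrative, not part of this diff): `Trainer.add_callback` appends to the
# callback list and events fire in list order, so `_InitialWeightSyncCallback` must be added
# before `_StartRolloutWorkerCallback` for the cold weight sync to finish before generation
# starts. A minimal sketch of that guarantee; `_ExampleRecord` and the commented `trainer`
# usage below are hypothetical:
from transformers import TrainerCallback

class _ExampleRecord(TrainerCallback):
    def __init__(self, tag: str, events: list):
        self.tag = tag
        self.events = events

    def on_train_begin(self, _args, _state, _control, **_kwargs):
        self.events.append(self.tag)

# events: list = []
# trainer.add_callback(_ExampleRecord("sync", events))
# trainer.add_callback(_ExampleRecord("start", events))
# trainer.train()  # fires on_train_begin in add order -> events == ["sync", "start"]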


class RolloutQueueDataset(torch.utils.data.IterableDataset):
def __init__(self, rollout_queue, model_version_fn, max_staleness=3, timeout=120.0):
self.queue = rollout_queue
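
# The rest of `RolloutQueueDataset` is collapsed in this diff view. As a mental model only
# (an assumption, not the PR's implementation), a staleness-bounded consumer of such a
# queue could look like this; the `rollout["model_version"]` key is hypothetical:
import queue

def _example_rollout_iter(q, model_version_fn, max_staleness=3, timeout=120.0):
    while True:
        try:
            rollout = q.get(timeout=timeout)  # block until the rollout worker yields a batch
        except queue.Empty:
            return  # nothing arrived within `timeout`: end the stream
        # Drop batches whose generating policy lags the trainer by more than `max_staleness`.
        if model_version_fn() - rollout["model_version"] <= max_staleness:
            yield rollout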
@@ -352,7 +381,10 @@ def __init__(

if rollout_worker is not None:
# Use the injected worker (e.g. a stub in tests). The queue is owned by the worker.
# Weight transfer is also expected to be wired by the test fixture (or left as None
# if the stub doesn't sync to a real vLLM).
self.rollout_worker = rollout_worker
self.weight_transfer = None
else:
# Collect weight metadata once — names/dtypes/shapes are fixed for the lifetime of training.
# DTensor.shape returns the global shape without triggering any all-gather.
@@ -363,6 +395,17 @@
weight_names.append(name)
weight_dtype_names.append(str(param.dtype).split(".")[-1])
weight_shapes.append(list(param.shape))
self.weight_transfer = WeightTransferClient(
vllm_server_url=self.args.vllm_server_base_url,
server_timeout=self.args.vllm_server_timeout,
weight_update_info={
"names": weight_names,
"dtype_names": weight_dtype_names,
"shapes": weight_shapes,
"packed": True,
"is_checkpoint_format": True,
},
)
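
# Illustrative aside, not this PR's wire protocol: registering names/dtype_names/shapes once
# up front is what allows each later transfer to be a bare packed byte stream that the
# receiver can slice apart without per-tensor headers. A self-contained sketch of such an
# unpack step:
import math

import torch

def _example_unpack(buf: bytes, names, dtype_names, shapes):
    out, offset = {}, 0
    for name, dtype_name, shape in zip(names, dtype_names, shapes):
        dtype = getattr(torch, dtype_name)  # e.g. "bfloat16" -> torch.bfloat16
        nbytes = math.prod(shape) * torch.empty((), dtype=dtype).element_size()
        # torch.frombuffer requires a writable buffer, hence the bytearray copy
        tensor = torch.frombuffer(bytearray(buf[offset : offset + nbytes]), dtype=dtype)
        out[name] = tensor.view(shape)
        offset += nbytes
    return out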
self.rollout_worker = AsyncRolloutWorker(
model_name=model_name,
dataset=train_dataset,
@@ -377,21 +420,21 @@ def __init__(
max_tokens=self.args.max_completion_length,
temperature=self.args.temperature,
request_timeout=self.args.request_timeout,
server_timeout=self.args.vllm_server_timeout,
chat_template_kwargs=self.args.chat_template_kwargs,
max_tool_calling_iterations=self.args.max_tool_calling_iterations,
log_completions=self.args.log_completions,
num_completions_to_print=self.args.num_completions_to_print,
weight_names=weight_names,
weight_dtype_names=weight_dtype_names,
weight_shapes=weight_shapes,
)
# TODO(@aminediro): decide whether this is returned by the worker or a common API that is passed to the worker later.
self.rollout_queue = self.rollout_worker.rollout_buffer
else:
self.rollout_queue = None
self.rollout_worker = None
self.weight_transfer = None

# Add callbacks
# Add callbacks. Registration order matters: weight sync first, then worker start.
self.add_callback(_InitialWeightSyncCallback(self))
self.add_callback(_StartRolloutWorkerCallback(self))
self.add_callback(StepIntervalCallback(self._sync_weight, self.args.weight_sync_steps))

def get_train_dataloader(self) -> DataLoader:
@@ -581,17 +624,17 @@ def _streaming_iter(self):
def _sync_weight(self):
t0 = time.time()
logger.info("Weight sync: pausing vLLM...")
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.pause()
if self.accelerator.is_main_process and self.weight_transfer:
self.weight_transfer.pause()
t_pause = time.time()
logger.info(f"Weight sync: pause took {t_pause - t0:.1f}s, waiting for all ranks...")

self.accelerator.wait_for_everyone()
t_barrier = time.time()

logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)")
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.send_weights(self._streaming_iter())
if self.accelerator.is_main_process and self.weight_transfer:
self.weight_transfer.send_weights(self._streaming_iter())
else:
# Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2.
for _ in self._streaming_iter():
@@ -601,24 +644,22 @@ def _streaming_iter(self):
self.accelerator.wait_for_everyone()

logger.info(f"Weight sync: resuming vLLM... (transfer took {t_transfer - t_barrier:.1f}s)")
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.resume()
if self.accelerator.is_main_process:
if self.weight_transfer:
self.weight_transfer.resume()
self.model_version += 1
self.rollout_worker.update_model_version(self.model_version)
if self.rollout_worker:
self.rollout_worker.update_model_version(self.model_version)
weight_sync_time_s = time.time() - t0
self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s)
logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s")

def _inner_training_loop(self, *args, **kwargs):
# Start the rollout worker here (not in __init__) so that checkpoint loading in Trainer.train()
# has already restored the model weights. The sequence is: start worker thread → wait for NCCL
# init → sync weights to vLLM → begin generation. This ensures vLLM always uses the current
# policy before producing any samples (matters for resumed runs, harmless for fresh ones).
self._sync_weight()
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.start()
try:
return super()._inner_training_loop(*args, **kwargs)
finally:
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.stop()
if self.accelerator.is_main_process:
if self.rollout_worker:
self.rollout_worker.stop()
if self.weight_transfer:
self.weight_transfer.destroy()
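
# With pause/resume/send_weights moved off the rollout worker, a test that injects a stub
# worker (see the test change above) can leave `weight_transfer` as None or wire in a no-op
# stand-in exposing the client surface this file uses. Illustrative sketch only;
# `_ExampleNoOpTransfer` is hypothetical:
class _ExampleNoOpTransfer:
    def init_weight_transfer(self) -> None: ...
    def pause(self) -> None: ...
    def send_weights(self, iterator) -> None:
        for _ in iterator:  # drain so FSDP collectives still run on every rank
            pass
    def resume(self) -> None: ...
    def destroy(self) -> None: ...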