feat(grpo): replace deprecated `use_transformers_paged` with transformers continuous batching #5765

sergiopaniego wants to merge 12 commits into …

Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
| "Using `use_transformers_continuous_batching` requires transformers>=5.8.0. " | ||
| "Please upgrade with `pip install --upgrade transformers`." | ||
| ) | ||
| from transformers.generation import ContinuousBatchingConfig |
Inlined import, maybe we can move it?
It's intentional, since it was only introduced in transformers 5.4.0 (a module-level import would break older versions).
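For context, here's a minimal sketch of the gate this hunk belongs to, assuming a `packaging`-based version check (the helper name is hypothetical; the PR's actual check may be shaped differently). The import stays inside the function so older transformers versions can still import the module:

```python
from packaging import version

import transformers

def _get_continuous_batching_config_cls():
    # Gate on the minimum version named in the error message above
    # (hypothetical helper; the PR's actual check may differ).
    if version.parse(transformers.__version__) < version.parse("5.8.0"):
        raise ImportError(
            "Using `use_transformers_continuous_batching` requires transformers>=5.8.0. "
            "Please upgrade with `pip install --upgrade transformers`."
        )
    # Inlined on purpose: the symbol only exists in transformers >= 5.4.0.
    from transformers.generation import ContinuousBatchingConfig

    return ContinuousBatchingConfig
```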
```python
    )
    from transformers.generation import ContinuousBatchingConfig

    cb_kwargs = dict(args.transformers_continuous_batching_config or {})
```
I wonder if there are some good defaults we can set? @sergiopaniego did you test different configs to get a sense of throughput, latency, and memory?
I would be interested in this as well!
Usually if no config is given, continuous batching adapts to the situation and sets smart default values for the parameters itself, but if you see something different please let me know!

BTW, compile is turned off by default because it adds a lot of warmup overhead. But if you re-use the continuous batching manager for several generations (set `persistent_manager=True`, which is `False` by default, so you don't need to re-set up the CB manager each time), it might be worth it. Let me know if you have any questions!
`max_memory_percent=0.5` is the only parameter TRL overrides (transformers defaults to 0.9). The lower value reserves more VRAM headroom for the backward pass, which matters on larger models. The rest (`allow_block_sharing`, `block_size`) inherit the transformers defaults directly. I'm running some benchmarks with the latest changes.
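To make the discussion concrete, here's a hedged sketch of overriding these knobs from the TRL side. The option names come from this PR; the values are illustrative rather than recommendations:

```python
from trl import GRPOConfig

args = GRPOConfig(
    output_dir="grpo-cb",
    use_transformers_continuous_batching=True,
    transformers_continuous_batching_config={
        # TRL's training-aware default is 0.5; raise it if generation,
        # not the backward pass, is your memory bottleneck.
        "max_memory_percent": 0.7,
        # allow_block_sharing and block_size inherit the transformers
        # defaults unless explicitly set here.
    },
)
```

The dict is forwarded verbatim via `cb_kwargs = dict(args.transformers_continuous_batching_config or {})`, as in the hunk above.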
```python
unwrap_model_for_generation(
    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
```
I wonder if this `unwrap_model_for_generation` yields a correct model in all cases (FSDP, DeepSpeed, TP, etc.) for continuous batching? Maybe a small test with FSDP can tell us if it's robust.
Thanks! I've tested it, and while it does yield the correct model, it surfaced another bug caused by using `with torch.inference_mode():`, so I've updated it.
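For readers hitting the same thing, a small self-contained illustration of the pitfall (not the PR's code): tensors created under `inference_mode` are permanently flagged as inference tensors and cannot be saved for backward later, while `no_grad` tensors can.

```python
import torch

w = torch.randn(2, 2, requires_grad=True)

with torch.inference_mode():
    y = w * 2  # y is an inference tensor

# out = y * w  # would raise: "Inference tensors cannot be saved for backward"

with torch.no_grad():
    z = w * 2  # plain no-grad tensor

out = z * w            # fine: z participates as a constant input
out.sum().backward()   # gradients flow to w as expected
```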
AmineDiro left a comment

Great work 👏🏼

Just wondering if the unwrapped model passed to continuous batching still works with FSDP etc.
Cursor Bugbot has reviewed your changes using default mode and found 2 potential issues.
Thanks for the PR. High-level for now: +1 on retiring `use_transformers_paged`.

On CB itself: the part that's a bit of a shame is that in sync GRPO the batch can't span training steps. CB's intra-batch refill helps (hence the 1.24× in your bench), but the bigger win, keeping the batch continuously full across the entire training run, is left on the table. That's the win async GRPO would unlock; I'm prototyping it in #5781.

Not a blocker: sync+CB still has a clear niche, and the speedup is real. But I'd love your thoughts on framing: for a user purely chasing performance, what's the sync+CB sweet spot vs. async+CB?

What does this PR do?
Replaces the deprecated `use_transformers_paged` with proper transformers continuous batching support, plus a new training example script. The old branch set `logprobs = None`, silently bypassing importance-sampling correction. This PR captures logprobs from `output.logprobs` and exposes `ContinuousBatchingConfig` for KV cache tuning.

Existing `use_transformers_paged=True` configs continue to work and forward to the new flag with a `FutureWarning` (a sketch of this forwarding appears below).

Requires `transformers>=5.8.0` (two training-mode bugs in `generate_batch()` were fixed in huggingface/transformers#45943). The 5.8.0 pin may need bumping if those changes land in a later release instead.
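A minimal sketch of that forwarding, assuming it lives in the config's `__post_init__` (the PR's actual implementation may differ):

```python
import warnings
from dataclasses import dataclass

@dataclass
class ConfigSketch:  # stand-in for GRPOConfig/RLOOConfig
    use_transformers_paged: bool = False
    use_transformers_continuous_batching: bool = False

    def __post_init__(self):
        if self.use_transformers_paged:
            warnings.warn(
                "`use_transformers_paged` is deprecated; use "
                "`use_transformers_continuous_batching` instead.",
                FutureWarning,
            )
            self.use_transformers_continuous_batching = True
```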
Usage

Use CB over the default `generate()` when N≥32 with variable completion lengths (e.g. math reasoning). Use vLLM when maximum throughput or multi-GPU tensor parallelism is the priority.
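A minimal end-to-end sketch based on the PR description; the flag comes from this PR, while the dataset, model, and reward function are illustrative:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer shorter completions.
    return [-float(len(c)) for c in completions]

args = GRPOConfig(
    output_dir="Llama-3.2-1B-GRPO-CB",
    use_transformers_continuous_batching=True,  # replaces use_transformers_paged
    num_generations=32,  # CB pays off at N>=32 per the benchmarks below
)

trainer = GRPOTrainer(
    model="meta-llama/Llama-3.2-1B-Instruct",
    args=args,
    reward_funcs=reward_len,
    train_dataset=dataset,
)
trainer.train()
```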
Performance

Benchmark (A100 80GB, Llama-3.2-1B-Instruct). Full script: https://gist.github.com/sergiopaniego/740a9708289e8f64cacd0d087d17d162
CB pulls ahead at N≥32 with variable completion lengths. At N=64 the VRAM delta inverts: default `generate()` eagerly allocates KV cache for all 64×2048 sequences (~40 GB) while CB pre-allocates a fixed fraction of free VRAM (~25 GB).

`RLOOTrainer` is also supported.

Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
@qgallouedec @AmineDiro @kashif @remi-or
Note

Medium Risk

Updates GRPO/RLOO generation code paths and config surface area, including a new transformers>=5.8.0 dependency gate, which could affect training behavior and GPU-only execution paths. Backward-compat via `use_transformers_paged` forwarding reduces rollout risk but still changes runtime defaults (e.g., the KV cache memory cap).

Overview

Adds a new `use_transformers_continuous_batching` option (and `transformers_continuous_batching_config`) to `GRPOConfig`/`RLOOConfig`; `use_transformers_paged` is now deprecated, emitting a warning and automatically forwarding to the new flag.

Updates `GRPOTrainer` and `RLOOTrainer` to generate via `model.generate_batch(..., continuous_batching_config=...)`, enforce `transformers>=5.8.0`, apply a training-aware default `max_memory_percent=0.5`, and explicitly reject multimodal processors for this path. A sketch of this call path follows below.

Adds documentation and an `examples/scripts/grpo_continuous_batching.py` training script, and updates tests to cover the new flag and include contract checks for the upstream continuous batching API.
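As referenced above, a hedged sketch of that generation call path; `generate_batch` and `continuous_batching_config` are named in this PR, while the input handling below is an assumption rather than verified API:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import ContinuousBatchingConfig

model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

# TRL's training-aware default: cap the KV cache at 50% of free VRAM to
# leave headroom for the backward pass (transformers alone defaults to 0.9).
cb_config = ContinuousBatchingConfig(max_memory_percent=0.5)

prompts = ["The capital of France is", "2 + 2 ="]
batch = [tokenizer(p).input_ids for p in prompts]  # assumed input format

outputs = model.generate_batch(batch, continuous_batching_config=cb_config)
# Per the PR description, outputs also carry `logprobs`, which GRPO uses for
# the importance-sampling correction the old paged path silently skipped.
```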