feat(grpo): replace deprecated use_transformers_paged with transformers continuous batching #5765

Open · sergiopaniego wants to merge 12 commits into main from grpo-continuous-batching

Conversation

sergiopaniego (Member) commented May 13, 2026

What does this PR do?

Replaces the deprecated use_transformers_paged with proper transformers continuous batching support, plus a new training example script. The old branch set logprobs = None, silently bypassing the importance-sampling correction. This PR captures logprobs from output.logprobs and exposes ContinuousBatchingConfig for KV cache tuning.
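
As a rough illustration, here is a minimal sketch of the logprob capture this enables. The per-request output shape and the .logprobs field follow this PR's description; exact names in transformers' generate_batch() output may differ:

import torch

def gather_completion_logprobs(batch_outputs):
    # Collect the sampling logprobs GRPO needs for importance-sampling
    # correction; the old paged path dropped these by setting logprobs = None.
    all_logprobs = []
    for output in batch_outputs:
        # Assumed: one sequence of per-token logprobs per request.
        all_logprobs.append(torch.tensor(output.logprobs))
    return all_logprobs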

Existing use_transformers_paged=True configs continue to work and forward to the new flag with a FutureWarning.
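
For illustration, a minimal sketch of what such forwarding typically looks like (field names follow this PR; the actual implementation in GRPOConfig may differ):

import warnings
from dataclasses import dataclass

@dataclass
class ConfigSketch:
    use_transformers_paged: bool = False                # deprecated
    use_transformers_continuous_batching: bool = False  # replacement

    def __post_init__(self):
        if self.use_transformers_paged:
            warnings.warn(
                "`use_transformers_paged` is deprecated; use "
                "`use_transformers_continuous_batching` instead.",
                FutureWarning,
            )
            # Forward so existing configs keep working unchanged.
            self.use_transformers_continuous_batching = True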

Requires transformers>=5.8.0 (two training-mode bugs in generate_batch() were fixed in huggingface/transformers#45943). The 5.8.0 pin may need bumping if those fixes only land in a follow-up release.

Usage

GRPOConfig(
    ...
    use_transformers_continuous_batching=True,
    transformers_continuous_batching_config={
        "use_cuda_graph": False,
        "max_memory_percent": 0.4,  # leave ~60% free for training
    },
    ...
)

Use CB over the default generate() when N≥32 with variable completion lengths (e.g. math reasoning). Use vLLM when maximum throughput or multi-GPU tensor parallelism is the priority.

Performance

Benchmark (A100 80GB, Llama-3.2-1B-Instruct). Full script: https://gist.github.com/sergiopaniego/740a9708289e8f64cacd0d087d17d162

| Scenario | Gens | Max tokens | Default generate() | CB | Speedup | VRAM delta |
| --- | --- | --- | --- | --- | --- | --- |
| GSM8K | 8 | 2048 | 18.24 ± 12.81 s | 18.16 ± 8.12 s | 1.00x | +17.86 GB |
| GSM8K | 32 | 2048 | 37.06 ± 8.63 s | 29.91 ± 10.09 s | 1.24x | +7.39 GB |
| GSM8K | 64 | 2048 | 41.12 ± 12.69 s | 32.84 ± 7.00 s | 1.25x | -16.66 GB |

CB pulls ahead at N≥32 with variable completion lengths. At N=64 the VRAM delta inverts: default generate() eagerly allocates KV cache for all 64×2048 sequences (~40 GB) while CB pre-allocates a fixed fraction of free VRAM (~25 GB).
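
A back-of-envelope sketch of why the eager path scales with N x max_tokens. The layer/head numbers below are assumptions for a Llama-3.2-1B-style config, and this counts completion KV only; real allocations also hold prompt tokens, logits buffers, and allocator overhead, so observed figures are larger:

def kv_cache_bytes(num_seqs, seq_len, num_layers=16, num_kv_heads=8,
                   head_dim=64, bytes_per_elem=2):
    # 2x for keys and values, per layer, per KV head, per head dim (fp16).
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    return num_seqs * seq_len * per_token

print(kv_cache_bytes(64, 2048) / 1e9, "GB")  # grows linearly in both factors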

[Figure: cb_benchmark_results]

RLOOTrainer is also supported.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

@qgallouedec @AmineDiro @kashif @remi-or


Note

Medium Risk
Updates GRPO/RLOO generation code paths and config surface area, including a new transformers>=5.8.0 dependency gate, which could affect training behavior and GPU-only execution paths. Backward-compat via use_transformers_paged forwarding reduces rollout risk but still changes runtime defaults (e.g., KV cache memory cap).

Overview
Adds a new use_transformers_continuous_batching option (and transformers_continuous_batching_config) to GRPOConfig/RLOOConfig; use_transformers_paged is now deprecated and warns, automatically forwarding to the new flag.

Updates GRPOTrainer and RLOOTrainer to generate via model.generate_batch(..., continuous_batching_config=...), enforce transformers>=5.8.0, apply a training-aware default max_memory_percent=0.5, and explicitly reject multimodal processors for this path.

Adds documentation and an examples/scripts/grpo_continuous_batching.py training script, and updates tests to cover the new flag and include contract checks for the upstream continuous batching API.

Reviewed by Cursor Bugbot for commit a966d8e. Bugbot is set up for automated code reviews on this repo.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread: trl/trainer/grpo_trainer.py (outdated)
"Using `use_transformers_continuous_batching` requires transformers>=5.8.0. "
"Please upgrade with `pip install --upgrade transformers`."
)
from transformers.generation import ContinuousBatchingConfig
Member:

Inlined import, maybe we can move it?

Member Author:

It's intentional, since ContinuousBatchingConfig was introduced in transformers 5.4.0.
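
For context, a sketch of the full guard around the snippet quoted above, assuming a packaging-based version check (the actual check in the trainer may differ):

from packaging import version
import transformers

if version.parse(transformers.__version__) < version.parse("5.8.0"):
    raise ImportError(
        "Using `use_transformers_continuous_batching` requires transformers>=5.8.0. "
        "Please upgrade with `pip install --upgrade transformers`."
    )
# Imported lazily: ContinuousBatchingConfig only exists in newer releases.
from transformers.generation import ContinuousBatchingConfig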

)
from transformers.generation import ContinuousBatchingConfig

cb_kwargs = dict(args.transformers_continuous_batching_config or {})
Member:

I wonder if there are some good defaults we can set? @sergiopaniego, did you test different configs to get a sense of throughput, latency, and memory?

Collaborator:

I would be interested in this as well!
Usually, if no config is given, continuous batching adapts to the situation and sets smart default values for the parameters itself, but if you see something different please let me know!
BTW, compile is turned off by default because it adds a lot of warmup overhead. But if you reuse the continuous batching manager for several generations (you have to set persistent_manager=True so you don't need to re-set up the CB manager; it is False by default), it might be worth it. Let me know if you have any questions!
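
Based on that comment, a hedged sketch of what reusing the manager could look like through this PR's config surface. The persistent_manager and use_cuda_graph names come from this thread; verify they are accepted by your transformers version:

GRPOConfig(
    ...,
    use_transformers_continuous_batching=True,
    transformers_continuous_batching_config={
        "persistent_manager": True,  # keep the CB manager alive across generations
        "use_cuda_graph": True,      # warmup cost amortizes once the manager is reused
    },
)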

Member Author:

max_memory_percent=0.5 is the only parameter TRL overrides (transformers defaults to 0.9). The lower value reserves more VRAM headroom for the backward pass, which matters on larger models. The rest (allow_block_sharing, block_size) inherit the transformers defaults directly. Running some benchmarks with the latest changes:

[Figure: cb_config_sweep]
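
A self-contained sketch of that training-aware default merge, mirroring the cb_kwargs line quoted earlier in this thread (the helper name is illustrative; the trainer's actual code may differ):

def build_cb_kwargs(user_config):
    # Start from the user-provided dict, if any.
    cb_kwargs = dict(user_config or {})
    # transformers defaults max_memory_percent to 0.9; TRL lowers it to 0.5
    # so the backward pass keeps enough VRAM headroom.
    cb_kwargs.setdefault("max_memory_percent", 0.5)
    return cb_kwargs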

Comment on lines 1385 to 1386
unwrap_model_for_generation(
    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
Member:

I wonder if this unwrap_model_for_generation yields a correct model in all cases (FSDP, DeepSpeed, TP, etc.) for continuous batching? Maybe a small test with FSDP can tell us if it's robust.

Member Author:

Thanks! I've tested it, and it does yield the correct model. Testing surfaced another bug caused by wrapping generation in torch.inference_mode(), so I've updated it.
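
For intuition on that bug class (the actual fix in this PR may differ): tensors created under torch.inference_mode() can never re-enter autograd, while torch.no_grad() tensors can:

import torch

x = torch.randn(2)

with torch.inference_mode():
    y = x * 2
# y.requires_grad_(True)  # would raise: inference tensors can't enter autograd

with torch.no_grad():
    z = x * 2
z.requires_grad_(True)  # fine: a no_grad tensor can re-enter autograd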

@AmineDiro (Member) left a comment:

Great work 👏🏼

Just wondering if the unwrapped model passed to continuous batching still works with FSDP, etc.

@cursor (Bot) left a comment:

Cursor Bugbot has reviewed your changes using default mode and found 2 potential issues.

Reviewed by Cursor Bugbot for commit b2b41ed.

sergiopaniego (Member Author) commented May 14, 2026

The transformers version pin needs to be updated from 5.8.0 once huggingface/transformers#45943 (already merged) lands in a released version.

@qgallouedec (Member) commented:

Thanks for the PR. High-level for now:

+1 on retiring use_transformers_paged.

On CB itself: the part that's a bit of a shame is that in sync GRPO the batch can't span training steps — CB's intra-batch refill helps (hence the 1.24× in your bench), but the bigger win, keeping the batch continuously full across the entire training run, is left on the table. That's the win async GRPO would unlock; I'm prototyping it in #5781.

Not a blocker: sync+CB still has a clear niche, and the speedup is real. But I'd love your thoughts on framing: for a user purely chasing performance, what's the sync+CB sweet spot vs. async+CB?
