[fsdp] feat: integrate PrefixGrouper for GRPO training acceleration #4368
Conversation
Code Review
This pull request successfully integrates PrefixGrouper to accelerate GRPO training, which is a significant performance improvement. The changes are well-structured, with new utilities for PrefixGrouper, configuration options, and integration into the FSDP worker. The inclusion of example scripts and documentation is also very helpful. The code correctly handles incompatibilities with other features like balance_batch. My review identifies one critical edge case that could lead to a runtime error.
```python
group_sizes = []
cur = 1
for i in range(1, bs):
    if uids[i] == uids[i - 1]:
        cur += 1
    else:
        group_sizes.append(cur)
        cur = 1
group_sizes.append(cur)
```
The current logic for calculating `group_sizes` does not handle the case where the batch size `bs` is 0. If `bs` is 0, `group_sizes` will incorrectly be `[1]`, which will lead to an `IndexError` downstream when trying to select from an empty `prompts` tensor. This can happen if an empty micro-batch is processed, causing a crash. The logic should be guarded to handle an empty batch gracefully.
Suggested change:

```python
group_sizes = []
if bs > 0:
    cur = 1
    for i in range(1, bs):
        if uids[i] == uids[i - 1]:
            cur += 1
        else:
            group_sizes.append(cur)
            cur = 1
    group_sizes.append(cur)
```
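An equivalent, arguably more idiomatic formulation that returns `[]` for an empty batch automatically (an editor-suggested alternative, not code from the PR; it assumes `uids` is an indexable sequence):

```python
from itertools import groupby

# Run-length encode consecutive equal uids; an empty input yields [].
group_sizes = [len(list(run)) for _, run in groupby(uids[:bs])]
```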
```python
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)

# PrefixGrouper requires data to be sorted by uid, which is incompatible with balance_batch
if self.config.actor_rollout_ref.actor.get("use_prefix_grouper", False) and self.config.trainer.get(
```
Sequence balancing is important for balancing the total number of tokens in micro-batches across all DP groups; I'm a bit concerned that this incompatibility will diminish the benefits of prefix_grouper.
Thanks for raising this! You're right that the incompatibility could hurt prefix_grouper's benefits. I've been thinking about a few ways to solve this:
Option 1: Group-level balancing
Instead of balancing individual samples, we balance entire uid groups. So if uid=1 has 4 samples, they all stay together and get assigned to the same rank.
The good:
- Prefix sharing still works since same-uid samples stay on the same rank
- When you have many groups (e.g., batch_size=64, n=4 → 16 groups), the balancing is still pretty fine-grained
The bad:
- Needs `batch_size % (world_size * n) == 0` to work. For example, with world_size=8 and n=4, you need batch_size to be a multiple of 32
- Balancing is coarser than sample-level, so you might see slightly worse load distribution
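For concreteness, here is a minimal sketch of what group-level balancing could look like (the function name and the greedy longest-processing-time heuristic are my own, not code from this PR):

```python
from collections import defaultdict

def balance_groups_across_ranks(uids, seq_lens, world_size):
    """Assign whole uid groups to ranks, roughly balancing total token counts.

    Same-uid samples always land on the same rank, so PrefixGrouper can
    still share the prefix computation within each group.
    """
    groups = defaultdict(list)
    for idx, uid in enumerate(uids):
        groups[uid].append(idx)
    cost = {uid: sum(seq_lens[i] for i in idxs) for uid, idxs in groups.items()}

    # Greedy: hand the heaviest remaining group to the lightest rank.
    rank_indices = [[] for _ in range(world_size)]
    rank_load = [0] * world_size
    for uid in sorted(cost, key=cost.get, reverse=True):
        r = min(range(world_size), key=rank_load.__getitem__)
        rank_indices[r].extend(groups[uid])
        rank_load[r] += cost[uid]
    return rank_indices
```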
Option 2: Balance first, then sort within each rank
Let balance_batch do its thing, then sort by uid within each rank before the forward pass.
The good:
- Easy to implement, no batch_size constraints
The bad:
- Same-uid samples get scattered across ranks. If uid=1 has 8 samples and world_size=4, each rank only gets 2 of them
- This basically kills the prefix sharing benefit - you're computing the same prefix 4 times instead of once
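(For what it's worth, the within-rank sort in Option 2 really is a one-liner; toy sketch assuming `uids` is a numpy array:)

```python
import numpy as np

uids = np.array(["b", "a", "b", "a"])
# Stable argsort groups same-uid samples contiguously while preserving
# their relative order within each group.
order = np.argsort(uids, kind="stable")  # -> [1, 3, 0, 2]
```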
Option 3: Keep them mutually exclusive (current state)
Just document that you can't use both together and let users pick.
The good:
- Simple
The bad:
- Users have to choose between two useful features
My take
I think Option 1 is worth implementing. The constraint isn't trivial, but it's manageable in practice with proper batch_size configuration. And unlike Option 2, it actually preserves the prefix sharing benefits.
Happy to implement this if it sounds reasonable, or we can explore other ideas if you have thoughts on a better approach.
```python
    return attn_output, attn_weights


class Qwen3Attention(nn.Module):
```
Is there a uniform way to monkey-patch transformers for all models? It's too verbose for users to patch each model independently.
For example, we have a uniform monkey patch for DeepSpeed Ulysses:
https://github.com/volcengine/verl/blob/main/verl/models/transformers/monkey_patch.py
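To illustrate the pattern being suggested, a minimal sketch of a model-agnostic patch (the helper name, the `pg_attention_forward` hook, and the class-name heuristic are hypothetical, not verl's actual monkey-patch API):

```python
import types

def apply_prefix_grouper_patch(model, pg_attention_forward):
    """Swap the forward of every *Attention submodule for a PG-aware one,
    instead of hand-editing each architecture's modeling_*.py.

    `pg_attention_forward` must have the signature of a bound method,
    i.e. its first parameter is the attention module itself.
    """
    for module in model.modules():
        if type(module).__name__.endswith("Attention"):
            module.forward = types.MethodType(pg_attention_forward, module)
```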
```python
)


def pg_forward(
```
Is the prefix_grouper forward pass compatible with Ulysses sequence parallel?
Currently, prefix_grouper is not compatible with Ulysses sequence parallel.
PrefixGrouper relies on computing the full prefix once and reusing it for multiple suffixes. When the sequence is split across devices, we can't efficiently share the prefix computation across grouped samples.
The choice between them depends on your workload:
- Ulysses SP: Better for very long sequences where sequence length is the bottleneck
- PrefixGrouper: Better when you have larger n (e.g., 8+) and high prefix/suffix ratio (long prompts, short responses)
What does this PR do?
Integrate PrefixGrouper into verl's FSDP worker to accelerate GRPO training by reducing redundant prefix computations.
In GRPO training, each prompt is copied `G` times (`rollout.n`), leading to redundant self-attention computation on shared prefixes. PrefixGrouper decomposes this into prefix self-attention + suffix concat-attention, significantly reducing computation and memory usage.

Key changes:
- Add a `use_prefix_grouper` config option in `ActorConfig`
- Add a PG forward path in `DataParallelPPOActor._forward_micro_batch`
- Add utilities in `verl/trainer/ppo/prefix_grouper_utils.py`
- Add examples under `examples/prefix_grouper_examples/`

Test
Benchmark Results (Qwen3-4B, 4×H800, `rollout.n=4`):

[Table comparing `old_log_prob`, `update_actor`, and full `step` timings with and without PrefixGrouper; the numeric values did not survive extraction.]

As context length increases, the speedup becomes more pronounced.
API and Usage Example
```bash
# Run example script
bash examples/prefix_grouper_examples/run_qwen3_pg.sh
```

Design & Code Changes
High-level Design:
PrefixGrouper optimizes GRPO training by avoiding redundant computation on shared prefixes. When `rollout.n > 1`, multiple responses share the same prompt, but standard attention computes the prefix `n` times. PrefixGrouper decomposes this into prefix self-attention plus suffix concat-attention.
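A crude back-of-the-envelope comparison makes the saving concrete (the quadratic attention-cost model and the example lengths below are my own illustration, not the PR's benchmark):

```python
P, S, n = 4096, 512, 8  # shared-prompt length, response length, rollout.n

standard = n * (P + S) ** 2         # prefix attention recomputed n times
grouped = P**2 + n * S * (P + S)    # prefix once + n suffix concat-attentions

print(f"standard={standard:,} grouped={grouped:,} ratio~{standard / grouped:.1f}x")
# -> ratio ~ 4.8x; the savings grow with n and with the prefix/suffix ratio
```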
Code Changes:
- `verl/workers/config/actor.py`: Add `use_prefix_grouper: bool = False` config option
- `verl/trainer/config/actor/actor.yaml`: Add `use_prefix_grouper: false` default config
- `verl/workers/actor/dp_actor.py`: (1) Add `self.use_prefix_grouper` and `self.use_dynamic_bsz` attributes in `__init__`; (2) Add PG forward path in `_forward_micro_batch` with lazy import and incompatibility checks; (3) Select extra keys (`prompts`, `response_mask`, `uid`) for PG in `compute_log_prob`; (4) Select extra keys (`prompts`, `uid`) for PG in `update_policy`
- `verl/trainer/ppo/prefix_grouper_utils.py`: Add `build_position_ids_for_prefix_grouper()` for position encoding, `build_pg_from_micro_batch()` to construct PrefixGrouper from a micro batch, and `pg_forward()` to execute the PG-optimized forward pass (see the sketch below)
- `verl/workers/fsdp_workers.py`: Pass the `use_prefix_grouper` config from actor to ref policy in `init_model` to ensure both use the same forward path
- `verl/trainer/ppo/ray_trainer.py`: Add `ValueError` check for `use_prefix_grouper + balance_batch` incompatibility at initialization
- `examples/prefix_grouper_examples/`: Add `README.md` documentation, `run_qwen3_prefix_grouper.sh` example script, and `qwen3/modeling_qwen3.py` modified model supporting PrefixGrouper
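Position-id construction is the subtle part of the grouped layout: each response must see positions as if it immediately followed the shared prompt, or rotary embeddings break. A hypothetical sketch (not the PR's actual `build_position_ids_for_prefix_grouper()` implementation):

```python
import torch

def build_position_ids(prefix_len: int, suffix_lens: list[int]) -> torch.Tensor:
    """Positions for the grouped layout [prefix, suffix_1, ..., suffix_n]:
    the shared prefix gets [0, prefix_len); every suffix restarts from
    prefix_len as if it directly followed the prefix."""
    pieces = [torch.arange(prefix_len)]
    for s in suffix_lens:
        pieces.append(torch.arange(prefix_len, prefix_len + s))
    return torch.cat(pieces)

# build_position_ids(4, [2, 3]) -> tensor([0, 1, 2, 3, 4, 5, 4, 5, 6])
```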
Limitations

Incompatible with:
- `trainer.balance_batch=True` (reorders data, breaks uid grouping)
- `use_dynamic_bsz=True`
- `use_remove_padding=True` (Flash Attention V2 variable length)
- `use_fused_kernels=True`
- `use_ulysses_sp=True` (Ulysses sequence parallelism)

The model must accept a `prefix_grouper` argument in its `forward` method.

Checklist Before Submitting
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Once the PR is ready for CI, send a request in the `ci-request` channel.