
Conversation

@kevssim (Contributor) commented Dec 1, 2025

What does this PR do?

Integrate PrefixGrouper into verl's FSDP worker to accelerate GRPO training by reducing redundant prefix computations.

In GRPO training, each prompt is copied G times (G = rollout.n), leading to redundant self-attention computation over the shared prefixes. PrefixGrouper decomposes this into prefix self-attention plus suffix concat-attention, significantly reducing computation and memory usage.

Key changes:

  • Add use_prefix_grouper config option in ActorConfig
  • Implement PG forward path in DataParallelPPOActor._forward_micro_batch
  • Add utility functions in verl/trainer/ppo/prefix_grouper_utils.py
  • Add example scripts and documentation in examples/prefix_grouper_examples/

Test

Benchmark Results (Qwen3-4B, 4×H800, rollout.n=4):

| Context Length | Metric | PG | No PG | Speedup |
| --- | --- | --- | --- | --- |
| 4K | old_log_prob | 1.31s | 1.70s | 1.30x |
| 4K | update_actor | 4.80s | 6.07s | 1.26x |
| 4K | step | 17.08s | 19.40s | 1.14x |
| 8K | old_log_prob | 1.69s | 2.63s | 1.56x |
| 8K | update_actor | 5.98s | 10.18s | 1.70x |
| 8K | step | 19.48s | 24.71s | 1.27x |

(figure: timing_comparison_combined, PG vs. no-PG timing comparison)

As context length increases, the speedup becomes more pronounced.

API and Usage Example

```bash
# Enable PrefixGrouper in the training config
actor_rollout_ref.actor.use_prefix_grouper=True
trainer.balance_batch=False                       # required: PG is incompatible with balance_batch
actor_rollout_ref.model.use_remove_padding=False  # required: PG is incompatible with remove_padding

# Run the example script
bash examples/prefix_grouper_examples/run_qwen3_prefix_grouper.sh
```

Design & Code Changes

High-level Design:

PrefixGrouper optimizes GRPO training by avoiding redundant computation on shared prefixes. When rollout.n > 1, multiple responses share the same prompt, but standard attention computes the prefix n times. PrefixGrouper decomposes this into:

  1. Prefix self-attention: Compute once per unique prompt
  2. Suffix concat-attention: Each response attends to the shared prefix output
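To make the decomposition concrete, here is a minimal PyTorch sketch of the idea (illustrative only; the function name, shapes, and masking below are my assumptions and do not reflect the actual PrefixGrouper library API):

```python
import torch
import torch.nn.functional as F

def grouped_attention(prefix_qkv, suffix_qkvs):
    """prefix_qkv: (q, k, v) for the shared prompt, each [1, heads, P, dim].
    suffix_qkvs: one (q, k, v) tuple per response, each [1, heads, S, dim]."""
    pq, pk, pv = prefix_qkv
    # 1) Prefix self-attention: computed once per unique prompt.
    prefix_out = F.scaled_dot_product_attention(pq, pk, pv, is_causal=True)
    P = pk.shape[2]
    suffix_outs = []
    for sq, sk, sv in suffix_qkvs:
        S = sq.shape[2]
        # 2) Suffix concat-attention: each response attends to the shared
        #    prefix KV plus its own KV, so the prefix is never recomputed.
        k = torch.cat([pk, sk], dim=2)
        v = torch.cat([pv, sv], dim=2)
        # Causal mask offset by the prefix length: suffix position i sees
        # all P prefix tokens and suffix tokens j <= i (True = attend).
        mask = torch.ones(S, P + S, dtype=torch.bool, device=sq.device)
        mask[:, P:] = torch.tril(torch.ones(S, S, dtype=torch.bool, device=sq.device))
        suffix_outs.append(F.scaled_dot_product_attention(sq, k, v, attn_mask=mask))
    return prefix_out, suffix_outs
```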


Code Changes:

| File | Change |
| --- | --- |
| verl/workers/config/actor.py | Add `use_prefix_grouper: bool = False` config option |
| verl/trainer/config/actor/actor.yaml | Add `use_prefix_grouper: false` default config |
| verl/workers/actor/dp_actor.py | (1) Add `self.use_prefix_grouper` and `self.use_dynamic_bsz` attributes in `__init__`; (2) add PG forward path in `_forward_micro_batch` with lazy import and incompatibility checks; (3) select extra keys (prompts, response_mask, uid) for PG in `compute_log_prob`; (4) select extra keys (prompts, uid) for PG in `update_policy` |
| verl/trainer/ppo/prefix_grouper_utils.py | New file with `build_position_ids_for_prefix_grouper()` for position encoding, `build_pg_from_micro_batch()` to construct a PrefixGrouper from a micro batch, and `pg_forward()` to execute the PG-optimized forward pass |
| verl/workers/fsdp_workers.py | Sync `use_prefix_grouper` config from actor to ref policy in `init_model` so both use the same forward path |
| verl/trainer/ppo/ray_trainer.py | Add `ValueError` check for the `use_prefix_grouper` + `balance_batch` incompatibility at initialization |
| examples/prefix_grouper_examples/ | New directory with README.md documentation, run_qwen3_prefix_grouper.sh example script, and qwen3/modeling_qwen3.py, a modified model supporting PrefixGrouper |
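For intuition about the position-encoding helper above: under PG, suffix positions must continue from the end of the shared prefix, as if each response were concatenated directly after it. A hypothetical sketch (not the actual signature of `build_position_ids_for_prefix_grouper()`):

```python
import torch

def build_position_ids_sketch(prefix_len, suffix_lens):
    """The prefix gets positions 0..P-1 exactly once; each suffix continues
    from P, as if concatenated directly after the shared prefix."""
    prefix_pos = torch.arange(prefix_len)
    suffix_pos = [torch.arange(prefix_len, prefix_len + s) for s in suffix_lens]
    return prefix_pos, suffix_pos

# e.g. prefix_len=4 with two responses of lengths 2 and 3:
# prefix -> [0, 1, 2, 3]; suffixes -> [4, 5] and [4, 5, 6]
```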

Limitations

  • FSDP worker only: Megatron worker is not supported yet
  • Incompatible configurations:
    • trainer.balance_batch=True (reorders data, breaks uid grouping)
    • use_dynamic_bsz=True
    • use_remove_padding=True (Flash Attention V2 variable length)
    • use_fused_kernels=True
    • use_ulysses_sp=True (Ulysses sequence parallelism)
  • Model modification required: The model must accept prefix_grouper argument in its forward method
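Regarding the last point, the required model change is roughly a signature extension like the following (hypothetical sketch; the modified modeling file in examples/prefix_grouper_examples/ is the actual reference):

```python
import torch.nn as nn

class Qwen3ForCausalLM(nn.Module):
    def forward(self, input_ids, attention_mask=None, position_ids=None,
                prefix_grouper=None, **kwargs):
        # When prefix_grouper is not None, the attention layers route through
        # the prefix/suffix decomposition instead of standard self-attention.
        ...
```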

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation. (Added examples/prefix_grouper_examples/README.md)
  • Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: PrefixGrouper requires modified model files and specific hardware setup, tested manually with benchmark results above.
  • Once your PR is ready for CI, send a message in the ci-request channel.

@CLAassistant commented Dec 1, 2025

CLA assistant check
All committers have signed the CLA.

@gemini-code-assist (bot) left a comment

Code Review

This pull request successfully integrates PrefixGrouper to accelerate GRPO training, which is a significant performance improvement. The changes are well-structured, with new utilities for PrefixGrouper, configuration options, and integration into the FSDP worker. The inclusion of example scripts and documentation is also very helpful. The code correctly handles incompatibilities with other features like balance_batch. My review identifies one critical edge case that could lead to a runtime error.

Comment on lines 44 to 52

```python
group_sizes = []
cur = 1
for i in range(1, bs):
    if uids[i] == uids[i-1]:
        cur += 1
    else:
        group_sizes.append(cur)
        cur = 1
group_sizes.append(cur)
```

critical

The current logic for calculating group_sizes does not handle the case where the batch size bs is 0. If bs is 0, group_sizes will incorrectly be [1], which will lead to an IndexError downstream when trying to select from an empty prompts tensor. This can happen if an empty micro-batch is processed, causing a crash. The logic should be guarded to handle an empty batch gracefully.

Suggested change

```diff
-group_sizes = []
-cur = 1
-for i in range(1, bs):
-    if uids[i] == uids[i-1]:
-        cur += 1
-    else:
-        group_sizes.append(cur)
-        cur = 1
-group_sizes.append(cur)
+group_sizes = []
+if bs > 0:
+    cur = 1
+    for i in range(1, bs):
+        if uids[i] == uids[i - 1]:
+            cur += 1
+        else:
+            group_sizes.append(cur)
+            cur = 1
+    group_sizes.append(cur)
```
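For what it's worth, an equivalent formulation (a hypothetical alternative, not part of the suggestion) using itertools.groupby handles the empty batch naturally, since grouping an empty sequence yields no groups:

```python
from itertools import groupby

group_sizes = [sum(1 for _ in g) for _, g in groupby(uids)]
```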

@kevssim marked this pull request as ready for review December 2, 2025 01:45
```python
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)

# PrefixGrouper requires data to be sorted by uid, which is incompatible with balance_batch
if self.config.actor_rollout_ref.actor.get("use_prefix_grouper", False) and self.config.trainer.get(
```
Collaborator commented:

Sequence balancing is important for balancing the total number of tokens in micro-batches across all DP groups; I'm a bit concerned that this incompatibility will diminish the benefits of prefix_grouper.

Contributor (Author) replied:

Thanks for raising this! You're right that the incompatibility could hurt prefix_grouper's benefits. I've been thinking about a few ways to solve this:

Option 1: Group-level balancing

Instead of balancing individual samples, we balance entire uid groups (see the sketch below). So if uid=1 has 4 samples, they all stay together and get assigned to the same rank.

The good:

  • Prefix sharing still works since same-uid samples stay on the same rank
  • When you have many groups (e.g., batch_size=64, n=4 → 16 groups), the balancing is still pretty fine-grained

The bad:

  • Needs batch_size % (world_size * n) == 0 to work. For example, with world_size=8 and n=4, you need batch_size to be a multiple of 32
  • Balancing is coarser than sample-level, so you might see slightly worse load distribution
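A rough sketch of what group-level balancing could look like (hypothetical; a real implementation would also need to enforce an equal number of groups per rank so DP batch sizes match, which is where the divisibility constraint comes from):

```python
def balance_groups(group_token_counts, world_size):
    """Greedily assign whole uid groups to ranks: largest group first,
    each to the currently lightest rank. Returns group indices per rank."""
    ranks = [[] for _ in range(world_size)]
    loads = [0] * world_size
    for gid in sorted(range(len(group_token_counts)),
                      key=lambda g: -group_token_counts[g]):
        r = loads.index(min(loads))       # lightest rank so far
        ranks[r].append(gid)
        loads[r] += group_token_counts[gid]
    return ranks
```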

Option 2: Balance first, then sort within each rank

Let balance_batch do its thing, then sort by uid within each rank before the forward pass.

The good:

  • Easy to implement, no batch_size constraints

The bad:

  • Same-uid samples get scattered across ranks. If uid=1 has 8 samples and world_size=4, each rank only gets 2 of them
  • This basically kills the prefix-sharing benefit: you're computing the same prefix 4 times instead of once

Option 3: Keep them mutually exclusive (current state)

Just document that you can't use both together and let users pick.

The good:

  • Simple

The bad:

  • Users have to choose between two useful features

My take

I think Option 1 is worth implementing. The constraint isn't trivial, but it's manageable in practice with proper batch_size configuration. And unlike Option 2, it actually preserves the prefix sharing benefits.

Happy to implement this if it sounds reasonable, or we can explore other ideas if you have thoughts on a better approach.

```python
        return attn_output, attn_weights


class Qwen3Attention(nn.Module):
```
Collaborator commented:

Is there a uniform way to monkey patch transformers for all models? It's too verbose for users to patch each model independently.

For example, we have a uniform monkey patch for DeepSpeed Ulysses:
https://github.com/volcengine/verl/blob/main/verl/models/transformers/monkey_patch.py
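For illustration, a uniform patch in that spirit might look like this (a hypothetical sketch only; the PG attention path itself is elided):

```python
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention

_orig_forward = Qwen3Attention.forward

def _patched_forward(self, *args, prefix_grouper=None, **kwargs):
    # Fall back to the stock implementation when PG is disabled.
    if prefix_grouper is None:
        return _orig_forward(self, *args, **kwargs)
    # Otherwise route q/k/v through the prefix/suffix decomposition here.
    raise NotImplementedError("PG attention path elided in this sketch")

Qwen3Attention.forward = _patched_forward
```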

```python
)


def pg_forward(
```
Collaborator commented:

Is the prefix_grouper forward pass compatible with Ulysses sequence parallelism?

Contributor (Author) replied:

Currently, prefix_grouper is not compatible with Ulysses sequence parallelism.

PrefixGrouper relies on computing the full prefix once and reusing it for multiple suffixes. When the sequence is split across devices, we can't efficiently share the prefix computation across grouped samples.

The choice between them depends on your workload:

  • Ulysses SP: Better for very long sequences where sequence length is the bottleneck
  • PrefixGrouper: Better when you have larger n (e.g., 8+) and high prefix/suffix ratio (long prompts, short responses)
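As a rough back-of-envelope (my own estimate with assumed toy sizes, not a benchmark from this PR), attention cost scales with query_len × key_len, which shows why large n and a high prefix/suffix ratio favor PG:

```python
# Assumed toy values: prompt length P, response length S, group size n.
P, S, n = 4096, 512, 8
standard = n * (P + S) ** 2        # n full sequences, prefix recomputed n times
pg = P ** 2 + n * S * (P + S)      # prefix attention once + n suffix concat-attentions
# Ignores causal-mask constant factors, which roughly cancel in the ratio.
print(f"approx attention-cost ratio: {standard / pg:.2f}x")  # ~4.76x here
```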
