16 changes: 11 additions & 5 deletions verl/trainer/ppo/ray_trainer.py
@@ -532,6 +532,11 @@ def _validate(self):
         data_source_lst = []
         reward_extra_infos_dict: dict[str, list] = defaultdict(list)

+        # Check if we need to decode texts (only needed for logging/dumping)
+        val_data_dir = self.config.trainer.get("validation_data_dir", None)
+        generations_to_log = self.config.trainer.get("log_val_generations", 0)
+        need_decode = val_data_dir is not None or generations_to_log > 0
+
         # Lists to collect samples for the table
         sample_inputs = []
         sample_outputs = []
@@ -560,8 +565,9 @@ def _validate(self):
             # Store original inputs
             input_ids = test_batch.batch["input_ids"]
             # TODO: Can we keep special tokens except for padding tokens?
-            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
-            sample_inputs.extend(input_texts)
+            if need_decode:
+                input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
+                sample_inputs.extend(input_texts)
Comment on lines +569 to +570 (Contributor)
Severity: high

For better performance and consistency with other parts of the codebase (e.g., `_log_rollout_data`), it's recommended to use `tokenizer.batch_decode` instead of a list comprehension with `tokenizer.decode`. `batch_decode` is generally more efficient for decoding batches of sequences. You can also directly extend `sample_inputs` to make the code more concise.

Suggested change:
-                input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
-                sample_inputs.extend(input_texts)
+                sample_inputs.extend(self.tokenizer.batch_decode(input_ids, skip_special_tokens=True))
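A quick way to sanity-check this suggestion (not part of the PR): for a standard HuggingFace tokenizer, `batch_decode` returns the same strings as a per-sequence `decode` loop, so the substitution only changes performance and conciseness. A minimal sketch, assuming `transformers` is installed; the `gpt2` checkpoint is an arbitrary example:

    # Equivalence check: per-sequence decode loop vs. batch_decode.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    batch = tokenizer(["hello world", "batch decode demo"])["input_ids"]

    looped = [tokenizer.decode(ids, skip_special_tokens=True) for ids in batch]
    batched = tokenizer.batch_decode(batch, skip_special_tokens=True)
    assert looped == batched  # same strings, fewer Python-level calls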

             sample_uids.extend(test_batch.non_tensor_batch["uid"])

             ground_truths = [

@@ -599,8 +605,9 @@ def _validate(self):

             # Store generated outputs
             output_ids = test_output_gen_batch.batch["responses"]
-            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
-            sample_outputs.extend(output_texts)
+            if need_decode:
+                output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
+                sample_outputs.extend(output_texts)
Comment on lines +609 to +610 (Contributor)
Severity: high

Similar to the input decoding, using `tokenizer.batch_decode` here is more performant and consistent with other parts of the code. The intermediate `output_texts` variable can also be removed for conciseness.

Suggested change:
-                output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
-                sample_outputs.extend(output_texts)
+                sample_outputs.extend(self.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
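The equivalence sketch after the first comment applies unchanged here: `batch_decode` decodes each sequence in the batch independently, so decoding the `responses` tensor in one call yields the same strings as the loop.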


             test_batch = test_batch.union(test_output_gen_batch)
             test_batch.meta_info["validate"] = True

@@ -627,7 +634,6 @@ def _validate(self):
         self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

         # dump generations
-        val_data_dir = self.config.trainer.get("validation_data_dir", None)
         if val_data_dir:
             self._dump_generations(
                 inputs=sample_inputs,
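Taken together, the diff computes the validation-logging configuration once, up front, and gates all tokenizer decoding on it, so the decode cost is only paid when the texts are actually consumed. A standalone sketch of the gating predicate, assuming the trainer config behaves like a dict (names follow the diff; illustrative, not verbatim trainer code):

    # Sketch: decode validation texts only when they will be logged or dumped.
    # The trainer config is modeled as a plain dict for illustration.
    def need_decode(trainer_cfg: dict) -> bool:
        val_data_dir = trainer_cfg.get("validation_data_dir", None)
        generations_to_log = trainer_cfg.get("log_val_generations", 0)
        return val_data_dir is not None or generations_to_log > 0

    assert need_decode({"validation_data_dir": "/tmp/val"})  # dumping requested
    assert need_decode({"log_val_generations": 4})           # logging requested
    assert not need_decode({})                               # neither: skip decoding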