[trainer] feat: Optimize tokenizer decode calls in validation #4407
base: main
Code Review
This pull request introduces a valuable optimization by making tokenizer decode operations conditional within the _validate method. The logic to skip decoding when val_data_dir is not set and log_val_generations is zero is sound and should reduce computational overhead during validation. The removal of the duplicate val_data_dir definition is also a good cleanup.
I've added a couple of suggestions to further improve performance and code consistency by using tokenizer.batch_decode instead of iterating with tokenizer.decode. This is more efficient and aligns with practices in other parts of the file.
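The skip condition described above can be sketched in isolation. This is a minimal illustration, not the PR's actual code: `needs_decoding`, `collect_texts`, and `fake_batch_decode` are hypothetical names, and the condition is assumed to be exactly "a dump directory is set, or at least one generation is to be logged".

```python
from typing import Callable, List, Optional, Sequence


def needs_decoding(val_data_dir: Optional[str], log_val_generations: int) -> bool:
    """Return True only when decoded text would actually be consumed."""
    return bool(val_data_dir) or log_val_generations > 0


def collect_texts(
    batches: Sequence[Sequence[Sequence[int]]],
    batch_decode: Callable[[Sequence[Sequence[int]]], List[str]],
    val_data_dir: Optional[str],
    log_val_generations: int,
) -> List[str]:
    """Decode every batch, or skip all decode calls when nothing needs the text."""
    texts: List[str] = []
    if not needs_decoding(val_data_dir, log_val_generations):
        return texts
    for ids in batches:
        texts.extend(batch_decode(ids))
    return texts


# Hypothetical stand-in for a tokenizer's batch decode; records each call.
decode_calls: List[int] = []


def fake_batch_decode(batch: Sequence[Sequence[int]]) -> List[str]:
    decode_calls.append(len(batch))
    return [" ".join(str(i) for i in ids) for ids in batch]


batches = [[[1, 2], [3]], [[4, 5, 6]]]

# Neither consumer is enabled, so no decode call is made.
skipped = collect_texts(batches, fake_batch_decode, None, 0)
calls_when_skipped = len(decode_calls)

# Logging is enabled, so each batch is decoded once.
logged = collect_texts(batches, fake_batch_decode, None, 2)
```

Computing the flag once, before the validation loop, keeps the check out of the per-batch path.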
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)
For better performance and consistency with other parts of the codebase (e.g., _log_rollout_data), it's recommended to use tokenizer.batch_decode instead of a list comprehension with tokenizer.decode. batch_decode is generally more efficient for decoding batches of sequences. You can also directly extend sample_inputs to make the code more concise.
Suggested change:
- input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
- sample_inputs.extend(input_texts)
+ sample_inputs.extend(self.tokenizer.batch_decode(input_ids, skip_special_tokens=True))
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)
Similar to the input decoding, using tokenizer.batch_decode here is more performant and consistent with other parts of the code. The intermediate output_texts variable can also be removed for conciseness.
Suggested change:
- output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
- sample_outputs.extend(output_texts)
+ sample_outputs.extend(self.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
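To show that the suggested form is a drop-in replacement, here is a small sketch comparing the two. `ToyTokenizer` is a hypothetical stand-in that only mirrors the `decode` / `batch_decode` call shape with a `skip_special_tokens` keyword, so the snippet runs without loading a real model; its vocabulary and token ids are made up.

```python
from typing import Dict, List, Sequence


class ToyTokenizer:
    """Hypothetical tokenizer exposing decode and batch_decode (illustration only)."""

    def __init__(self, vocab: Dict[int, str], special_ids: Sequence[int]) -> None:
        self.vocab = vocab
        self.special_ids = set(special_ids)

    def decode(self, ids: Sequence[int], skip_special_tokens: bool = False) -> str:
        # Drop special tokens (e.g. padding) when requested, then join the rest.
        kept = [i for i in ids if not (skip_special_tokens and i in self.special_ids)]
        return " ".join(self.vocab[i] for i in kept)

    def batch_decode(
        self, batch: Sequence[Sequence[int]], skip_special_tokens: bool = False
    ) -> List[str]:
        # One call for the whole batch; here it simply delegates to decode.
        return [self.decode(ids, skip_special_tokens=skip_special_tokens) for ids in batch]


tokenizer = ToyTokenizer({0: "<pad>", 1: "hello", 2: "world"}, special_ids=[0])
output_ids = [[1, 2, 0], [2, 1, 0]]

# Original form: per-sequence decode plus an intermediate variable.
sample_outputs_loop: List[str] = []
output_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs_loop.extend(output_texts)

# Suggested form: a single batch_decode call fed straight into extend.
sample_outputs_batch: List[str] = []
sample_outputs_batch.extend(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

assert sample_outputs_loop == sample_outputs_batch
```

The sketch only checks that both forms return the same strings. Whether `batch_decode` is actually faster depends on the tokenizer implementation, so it is worth measuring on the real tokenizer before relying on a speedup.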
What does this PR do?
Optimize the _validate method by conditionally executing tokenizer decode operations only when needed.
Changes:
- Check val_data_dir and generations_to_log to determine if decoding is required
- Make input_texts and output_texts decoding conditional
- Remove the duplicate val_data_dir definition
Performance Impact:
When val_data_dir is not set and log_val_generations is 0, tokenizer decode operations are skipped entirely, reducing unnecessary computational overhead during validation.
Checklist Before Starting
- Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
  - {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
  - For multiple modules, list them like [megatron, fsdp, doc]
  - {type} is in feat, fix, refactor, chore, test
  - For a breaking change, add [BREAKING] to the beginning of the title.
  - Example: [BREAKING][fsdp, megatron] feat: dynamic batching
Test
API and Usage Example
# Add code snippet or script demonstrating how to use this
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run the pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
- Request CI in the ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)