
Conversation

@frostedoyster
Collaborator

@frostedoyster frostedoyster commented Dec 26, 2025

This PR implements even load management across batches, based on the total number of atoms in each batch. The user specifies a lower and an upper bound on the number of atoms in a batch; training then proceeds as usual, but any batch with too few or too many atoms is skipped.

This is useful for two main reasons:

  • it avoids OOM errors caused by unlucky, oversized batches, which in turn allows much larger batch sizes and can improve training efficiency
  • it evens out the load across processes in distributed settings, where every process has to wait for the slowest batch at each training step

For PET, this complements the adaptive cutoff strategies to fully even out memory and computation usage across batches.
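As a rough illustration of the mechanism (a minimal sketch with hypothetical names, not the actual metatrain API), the check boils down to comparing a batch's total atom count against the two bounds:

```python
def batch_within_bounds(atom_counts, min_atoms=None, max_atoms=None):
    """Return True if the summed atom count of a candidate batch lies inside the bounds.

    `atom_counts` holds the number of atoms of each structure in the batch;
    a bound set to None is ignored.
    """
    n_atoms = sum(atom_counts)
    if min_atoms is not None and n_atoms < min_atoms:
        return False
    if max_atoms is not None and n_atoms > max_atoms:
        return False
    return True


# Example: with bounds [50, 400], a 512-atom batch would be skipped.
assert batch_within_bounds([20, 15, 30], min_atoms=50, max_atoms=400)
assert not batch_within_bounds([256, 256], min_atoms=50, max_atoms=400)
```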


📚 Documentation preview 📚: https://metatrain--990.org.readthedocs.build/en/990/

frostedoyster and others added 11 commits December 26, 2025 18:17
- Add min_atoms_per_batch and max_atoms_per_batch to BaseHypers schema
- Create CollateFnWithBatchBounds wrapper for batch validation
- Add batch bounds to DatasetInfo class
- Update train.py to pass batch bounds to DatasetInfo
- Update SOAP-BPNN trainer to use batch bounds when creating dataloaders
- Add comprehensive tests for batch bounds functionality
- Create example demonstrating batch bounds usage

Co-authored-by: frostedoyster <[email protected]>
- Update LLPR trainer to use batch bounds
- Update PET trainer to use batch bounds
- Update MACE trainer to use batch bounds
- Update FlashMD trainer to use batch bounds
- Update deprecated NanoPET trainer to use batch bounds
All trainers now support min_atoms_per_batch and max_atoms_per_batch configuration options

Co-authored-by: frostedoyster <[email protected]>
- Move import to module level in test_batch_bounds.py
- Clarify docstring on how RuntimeError integrates with PyTorch DataLoader
- Make it clear that batches are not automatically skipped but training is interrupted

Co-authored-by: frostedoyster <[email protected]>
- Remove batch bounds from global BaseHypers
- Add batch_atom_bounds to TrainerHypers for all architectures (SOAP-BPNN, LLPR, PET, MACE, FlashMD, NanoPET)
- Change API to use list format: batch_atom_bounds: [min, max]
- Return None for invalid batches instead of raising errors
- Update all trainers to skip None batches in training/validation loops
- Update tests to reflect new API and behavior
- Update example documentation

Co-authored-by: frostedoyster <[email protected]>
- Synchronize batch rejection across all processes in distributed mode
- Use all_reduce with MIN operation to ensure all processes agree on batch validity
- If any process has None batch, all processes skip that batch
- Applied to all trainers: SOAP-BPNN, LLPR, PET, MACE, FlashMD, NanoPET

Co-authored-by: frostedoyster <[email protected]>
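A sketch of the synchronization this commit describes (the helper name mirrors the one introduced in later commits, but the body is illustrative; it assumes `torch.distributed` is already initialized with a backend that accepts CPU tensors, e.g. gloo — with NCCL the flag tensor would live on the local GPU):

```python
import torch
import torch.distributed as dist


def should_skip_batch(batch, is_distributed):
    """Return True if *every* rank should skip this batch.

    Each rank reports 1 if its batch is valid and 0 if the collate function
    returned None; an all_reduce with MIN yields 0 on all ranks as soon as a
    single rank rejected its batch, so all ranks skip (or keep) it together.
    """
    ok = torch.tensor(1 if batch is not None else 0)
    if is_distributed:
        dist.all_reduce(ok, op=dist.ReduceOp.MIN)
    return ok.item() == 0
```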
- Add validation for batch_atom_bounds format (must be list of 2 elements)
- Create should_skip_batch_distributed utility function
- Apply utility to SOAP-BPNN and LLPR trainers
- Note: Remaining trainers (PET, MACE, FlashMD, NanoPET) still use inline code but are functional

Co-authored-by: frostedoyster <[email protected]>
- Remove CollateFnWithBatchBounds wrapper class
- Add batch_atom_bounds parameter directly to CollateFn.__init__
- CollateFn now returns None for batches outside bounds
- Update all trainers to pass batch_atom_bounds to CollateFn directly
- Use should_skip_batch_distributed helper in all training/validation loops
- Update tests to work with new API
- Minimal code changes per trainer (3-4 lines for collate_fn, 1-2 lines per loop)

Co-authored-by: frostedoyster <[email protected]>
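An illustrative collate wrapper along these lines (the real `CollateFn` in metatrain does more than this; the sketch only shows the bounds handling, and `sample.n_atoms` is an assumed attribute):

```python
class BoundedCollate:
    """Sketch of a collate function that rejects out-of-bounds batches.

    `batch_atom_bounds` follows the [min, max] list format from the trainer
    options; either entry may be None to disable that bound.
    """

    def __init__(self, batch_atom_bounds=(None, None)):
        self.min_atoms, self.max_atoms = batch_atom_bounds

    def __call__(self, samples):
        n_atoms = sum(sample.n_atoms for sample in samples)
        if self.min_atoms is not None and n_atoms < self.min_atoms:
            return None  # too few atoms: tell the trainer to skip this batch
        if self.max_atoms is not None and n_atoms > self.max_atoms:
            return None  # too many atoms: skip to avoid OOM / stragglers
        return samples  # a real collate function would assemble tensors here
```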
…ers[] instead of get()

- Renamed should_skip_batch_distributed to should_skip_batch (simpler name)
- Fixed SOAP-BPNN training loop to use helper function instead of manual code
- Changed all hypers.get("batch_atom_bounds", [None, None]) to hypers["batch_atom_bounds"]
- Defaults are now handled in documentation.py files as intended
- All trainers now consistently use the helper function everywhere

Co-authored-by: frostedoyster <[email protected]>
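Put together, the per-epoch loop pattern these commits describe looks roughly like this (a sketch reusing the `should_skip_batch` sketch above; `compute_loss` is a placeholder):

```python
def train_one_epoch(model, train_dataloader, optimizer, compute_loss, is_distributed):
    """Sketch of a training loop that skips out-of-bounds batches in sync."""
    for batch in train_dataloader:
        # The collate function returns None when a batch falls outside the atom
        # bounds; the helper makes every rank agree before any work is done.
        if should_skip_batch(batch, is_distributed):
            continue
        optimizer.zero_grad()
        loss = compute_loss(model, batch)
        loss.backward()
        optimizer.step()
```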
@frostedoyster frostedoyster changed the title from "Implement batch bounds to even out resource needs during training" to "Implement batch bounds to even out resource usage during training" Dec 26, 2025
@pfebrer
Contributor

pfebrer commented Jan 4, 2026

Hey, I have some questions:

  1. Why not skip the batch in the data loader itself, instead of creating a None batch and then checking for it in the trainer? Apart from the code being nicer, the current implementation risks dropping a lot of data in distributed environments.
  2. Does the data get shuffled at every epoch, or could this result in some structures never appearing in training because their batch is too large/too small?
  3. Would it make sense to implement load balancing by intervening in batch creation, instead of naively creating batches and then dropping them?

@frostedoyster
Collaborator Author

  1. If you skip just a single batch in one dataloader, the dataloaders fall out of sync: as far as I understand, the length of each dataloader is determined in advance and is the same for all parallel dataloaders. This is the standard way to skip batches in distributed GNN training, as far as I know.
  2. The data is shuffled at each epoch.
  3. Maybe, but the implementation would be much more difficult. When I say much, I mean much. As far as I know, this is the standard approach.

@pfebrer
Contributor

pfebrer commented Jan 5, 2026

  1. Ok, but that doesn't mean you can't synchronize the skipping inside the dataloaders/datasets, no?

If this is the way to go, could we at least keep track of how many batches were skipped and print it somewhere in the log file (e.g. as a percentage of the full dataset)?
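One possible way to do that bookkeeping (not part of this PR; a sketch reusing the `should_skip_batch` sketch above, with `train_step` as a placeholder for the normal forward/backward/step):

```python
import logging


def train_one_epoch_with_skip_logging(train_dataloader, is_distributed, train_step):
    """Count and log how many batches were skipped because of batch_atom_bounds."""
    skipped = 0
    total = 0
    for batch in train_dataloader:
        total += 1
        if should_skip_batch(batch, is_distributed):
            skipped += 1
            continue
        train_step(batch)
    if total > 0:
        logging.info(
            "Skipped %d/%d batches (%.1f%%) due to batch_atom_bounds",
            skipped, total, 100.0 * skipped / total,
        )
```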

@pfebrer
Contributor

pfebrer commented Jan 5, 2026

If we can synchronize the skipping at the beginning of the collate function, then, apart from the code being nicer (this is just an opinion), we would save a lot of wasted computation.
