
Conversation

@abmazitov (Contributor) commented Dec 12, 2025

This PR adds an option to choose the Muon optimizer while training the PET model

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • [ ] Issue referenced (for PRs that solve an issue)?

Maintainer/Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?
  • GPU tests passed (maintainer comment: "cscs-ci run")?

📚 Documentation preview 📚: https://metatrain--977.org.readthedocs.build/en/977/

@abmazitov changed the title from "Added an experimetal Muon optimizer to PET" to "Added an experimental Muon optimizer to PET" on Dec 12, 2025
@jwa7 (Member) left a comment


Looking great! I'll add the min_lr as discussed, but in the meantime just a comment on the creation of parameter groups :)

@frostedoyster (Collaborator) commented Dec 13, 2025

Are we sure the minimum learning rate is worth a new hyperparameter? New hyperparameters make the code harder to maintain and are confusing for users when they see a very large default hyperparameter file. I would rather 1) keep the code as is without a minimum learning rate, at the cost of a tiny inefficiency or 2) choose a minimum learning rate ourselves (e.g., 1/1000 of the initial learning rate) without it being configurable

pyproject.toml Outdated
Comment on lines 71 to 73
pet = [
"torch >= 2.9.1",
]
@frostedoyster (Collaborator) Dec 13, 2025

This might be too aggressive at the moment. For example, the torch version is fixed on many HPC clusters and users can't change it unless they're ready to perform a custom installation. The correct pattern would be to raise an error in case the user chooses Muon and their torch is too old for it

@abmazitov (Contributor, author)

I agree, but the main issue is that Muon is only available in torch 2.9.1 and higher. I'm also sceptical about raising the requirement. Maybe we can add a check for the Muon optimizer selection and suggest updating torch manually if an old version is detected and one really wants to use Muon.

Member

Yes, I think this is a good idea: remove the constraint and raise an error if Muon is requested.
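
A minimal sketch of the kind of check discussed here, assuming it lives wherever the optimizer is instantiated (the function name, constant, and error message are hypothetical, not the PR's actual implementation):

from packaging.version import Version

import torch

MIN_TORCH_FOR_MUON = "2.9.1"

def check_muon_requirements() -> None:
    # Instead of pinning torch >= 2.9.1 for every PET user, only raise
    # when Muon is actually selected and the installed torch is too old.
    installed = Version(torch.__version__.split("+")[0])
    if installed < Version(MIN_TORCH_FOR_MUON):
        raise RuntimeError(
            f"The Muon optimizer requires torch >= {MIN_TORCH_FOR_MUON}, "
            f"but torch {torch.__version__} is installed. "
            "Please upgrade torch manually to use Muon."
        )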

@abmazitov (Contributor, author)

Are we sure the minimum learning rate is worth a new hyperparameter? New hyperparameters make the code harder to maintain and are confusing for users when they see a very large default hyperparameter file. I would rather 1) keep the code as is without a minimum learning rate, at the cost of a tiny inefficiency or 2) choose a minimum learning rate ourselves (e.g., 1/1000 of the initial learning rate) without it being configurable

I tend to agree with that. Maybe we can even just hardcode it to be 1e-7

@jwa7 (Member) commented Dec 13, 2025

It's difficult to know whether it should be hardcoded as a ratio relative to the base/maximum learning rate, or as an absolute value. In different training runs I've seen benefit in being able to change it, though I agree that having an extra parameter might be over-engineering

        muon_params.append(p)
    else:
        adam_params.append(p)
adam_group = dict(params=adam_params, use_muon=False)
Member

I think we want two separate learning rates for the Adam and Muon parameter groups.

If you look at the example from the README of https://github.com/KellerJordan/Muon:

from muon import MuonWithAuxAdam
hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
nonhidden_params = [*model.head.parameters(), *model.embed.parameters()]
param_groups = [
    dict(params=hidden_weights, use_muon=True,
         lr=0.02, weight_decay=0.01),
    dict(params=hidden_gains_biases+nonhidden_params, use_muon=False,
         lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01),
]
optimizer = MuonWithAuxAdam(param_groups)

the Adam LR is more what we'd normally expect but the Muon one can be pushed much higher.

@abmazitov (Contributor, author) Dec 15, 2025

I'm a bit sceptical about setting the LR values like this. I mean, they would be highly architecture-dependent, right? At the same time, @sirmarcel has tested Muon for PET and noticed that it works nicely even with a common LR of ~1e-3 for both Adam and Muon parameters

Member

Hmm, ok. I noticed in my tests that I could push the Muon LR even to 1e-1 and it was still stable, but as soon as the Adam LR went above 1e-3, training diverged. But again, an extra hyperparameter is more complexity, so let's keep it simple and use a single LR, as you say, for now
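
For reference, a minimal sketch of the single-shared-LR variant agreed on here, following the MuonWithAuxAdam pattern from the README quoted above (the toy model and the concrete LR value are illustrative, not the PR's actual code):

import torch
from muon import MuonWithAuxAdam

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.GELU(), torch.nn.Linear(16, 1)
)

# Same split as in the PR snippet above: matrices (ndim >= 2) go to Muon,
# gains/biases and other 1-D tensors go to the auxiliary Adam.
muon_params = [p for p in model.parameters() if p.ndim >= 2]
adam_params = [p for p in model.parameters() if p.ndim < 2]

lr = 1e-3  # single learning rate shared by both groups
param_groups = [
    dict(params=muon_params, use_muon=True, lr=lr, weight_decay=0.01),
    dict(params=adam_params, use_muon=False, lr=lr,
         betas=(0.9, 0.95), weight_decay=0.01),
]
optimizer = MuonWithAuxAdam(param_groups)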

@abmazitov (Contributor, author) commented Dec 15, 2025

It's difficult to know whether it should be hardcoded as a ratio relative to the base/maximum learning rate, or as an absolute value. In different training runs I've seen benefit in being able to change it, though I agree that having an extra parameter might be over-engineering

Optionally, we can add a hard-coded minimum LR which is always equal to the initial LR * 1e-3. So for an initial LR of 1e-4, the min_lr will be 1e-7, and so on.

@jwa7 jwa7 marked this pull request as ready for review December 16, 2025 09:53
Contributor

Can this be regenerated with different hyperparameters?

Member

Sure - just the PET one or the others too?

Contributor

It would be great if you could do it for all the newly generated checkpoints!

@jwa7 (Member) commented Dec 16, 2025

@abmazitov @frostedoyster I've hardcoded the minimum LR ratio to 1e-4. If we use high LRs with Muon such as 1e-3, training will finish with LR 1e-7. I think this is a reasonable balance of being low but not too low. You both ok with this?

@abmazitov (Contributor, author)

@abmazitov @frostedoyster I've hardcoded the minimum LR ratio to 1e-4. If we use high LRs with Muon such as 1e-3, training will finish with LR 1e-7. I think this is a reasonable balance of being low but not too low. You both ok with this?

I think it should be good
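
A minimal sketch of what the hardcoded minimum-LR ratio amounts to, assuming a cosine-annealing schedule purely for illustration (the PR's actual scheduler, optimizer, and variable names may differ):

import torch

MIN_LR_RATIO = 1e-4  # hardcoded ratio agreed on above

initial_lr = 1e-3  # a high, Muon-friendly base LR
# Toy optimizer just to attach the scheduler to.
params = [torch.nn.Parameter(torch.zeros(4, 4))]
optimizer = torch.optim.AdamW(params, lr=initial_lr)

# Decays from initial_lr down to initial_lr * MIN_LR_RATIO,
# i.e. 1e-3 -> 1e-7 in this example.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100, eta_min=initial_lr * MIN_LR_RATIO
)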

@jwa7 (Member) commented Dec 16, 2025

cscs-ci run

@jwa7 jwa7 requested a review from frostedoyster December 16, 2025 11:40