Fix the handling of num_epochs for restart and finetune contexts
#845
base: main
Conversation
cscs_ci run
cscs-ci run
cscs-ci run
cscs-ci run
May I add a point... there's also a problem when you restart a finetuning run. The training restarts correctly from the model checkpoint, but the LR resets and goes back to the initial one... for example, this was an interrupted finetuning run: And if I restart it with:
I think I have an idea how to fix it, let me add some changes
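A minimal sketch of the pattern being discussed, assuming a standard PyTorch optimizer and LR scheduler; the function names and checkpoint keys here are placeholders, not metatrain's actual checkpointing API. Saving and restoring the scheduler state alongside the model is what keeps the learning rate from jumping back to its initial value on restart:

```python
import torch


def save_checkpoint(path, model, optimizer, scheduler, epoch):
    # Store the scheduler state next to the model and optimizer states,
    # so a restarted run can continue the LR schedule where it left off.
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer, scheduler):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Without this line the scheduler starts from scratch and the LR jumps
    # back to its initial value when the run is restarted:
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    return checkpoint["epoch"]
```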
cscs-ci run
cscs-ci run
cscs-ci run
abmazitov left a comment
LGTM
cscs-ci run
cscs-ci run
cscs-ci run
cscs-ci run
cscs-ci run
cscs-ci run
cscs-ci run
Maybe we need a 5-minute chat in person about this one @abmazitov @frostedoyster to sort it out 😅
@jwa7 I think @abmazitov and I would rather keep typing "cscs-ci run" until they pass
cscs-ci run
The test error for the distributed tests looks a lot like what I had to fix for #922, but it is weird that it did not show up there. It might just be a checkpoint update that is not doing what it needs to. The tox test error looks relevant (trying to use matmul with int/float).
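For context, the kind of dtype mismatch mentioned above can be reproduced with a generic PyTorch snippet; this is only an illustration, not the failing test itself:

```python
import torch

a = torch.arange(4, dtype=torch.int64).reshape(2, 2)  # integer tensor
b = torch.rand(2, 2)                                   # float32 tensor

# torch.matmul(a, b) raises a RuntimeError because the dtypes differ;
# casting the integer tensor to the float dtype first avoids it:
out = torch.matmul(a.to(b.dtype), b)
print(out.dtype)  # torch.float32
```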
The way that `num_epochs` is handled is currently broken. Say, for example, I run training for 1000 epochs, but it runs out of time after 600 epochs. Restarting from the checkpoint at epoch 600 and keeping `num_epochs = 1000` results in a total of 1600 epochs being run. This also messes up the cosine scheduling, which 'resets' after 1000 epochs and starts increasing again. This PR fixes these issues, depending on the context (a minimal sketch follows the list below):
- Restart context: the `epoch` attribute of the `Trainer`, when loaded from checkpoint, is the one corresponding to the checkpoint epoch, i.e. 600, and the training runs up to the max epoch number, i.e. 1000, then stops.
- Finetune context: the `epoch` attribute of the `Trainer` resets to 0 when loaded from checkpoint, and training runs from zero -> `num_epochs`.
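A minimal sketch of the intended behaviour, using hypothetical names (`resume_training`, `train_one_epoch`, the context strings) rather than metatrain's actual `Trainer` API:

```python
def resume_training(trainer, checkpoint, context, num_epochs):
    if context == "restart":
        # Checkpoint saved at epoch 600 with num_epochs = 1000: continue with
        # epochs 600..999 and then stop, instead of running 1000 more epochs
        # (which would also restart the cosine schedule).
        trainer.epoch = checkpoint["epoch"]
    elif context == "finetune":
        # Fresh run on top of the checkpointed weights: epochs 0..num_epochs-1,
        # with the learning-rate schedule starting from scratch.
        trainer.epoch = 0
    else:
        raise ValueError(f"unknown context: {context!r}")

    for epoch in range(trainer.epoch, num_epochs):
        trainer.train_one_epoch(epoch)
        trainer.epoch = epoch + 1
```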
Contributor (creator of pull-request) checklist
Maintainer/Reviewer checklist
📚 Documentation preview 📚: https://metatrain--845.org.readthedocs.build/en/845/