Resuming Training gives CheckpointMismatchError #1300
Comments
Hi @arushi-08, the checkpoint files are "just" normal torch archives, i.e., you can load them via … The checksum was calculated from the string representations of the model and the optimizer, cf. pykeen/src/pykeen/training/training_loop.py, lines 203 to 209 in d1222b7.
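A minimal sketch of why a changed configuration is detected while different weights are not. It assumes the checksum is an MD5 digest of `str(model) + str(optimizer)`; the exact hashing in pykeen's `training_loop.py` may differ, and `config_checksum` and `build` are hypothetical helpers with a toy embedding table standing in for a KGE model.

```python
import hashlib

import torch


def config_checksum(model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> str:
    # Assumed scheme: hash the string representations of model and optimizer.
    text = str(model) + str(optimizer)
    return hashlib.md5(text.encode("utf-8")).hexdigest()


def build(embedding_dim: int):
    # Hypothetical toy stand-in for a KGE model: a single embedding table.
    model = torch.nn.Embedding(num_embeddings=14, embedding_dim=embedding_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    return model, optimizer


a = config_checksum(*build(8))
b = config_checksum(*build(8))   # new random weights, same configuration
c = config_checksum(*build(16))  # different embedding_dim
print(a == b)  # True: str() reflects the configuration, not the weight values
print(a == c)  # False: "Embedding(14, 16)" differs from "Embedding(14, 8)"
```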
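I would suggest that you load the checkpoint file, overwrite its checksum, and save it again:

```python
import torch
# path/new_path: old and new checkpoint files; checksum: the value your current setup expects
d = torch.load(path, weights_only=False)  # trusted file; torch>=2.6 defaults to weights_only=True
d["checksum"] = checksum
torch.save(d, new_path)
```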
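A self-contained sketch of the same load, patch, and save round trip on a toy checkpoint, assuming only that a checkpoint is a plain dict with a `"checksum"` entry. The file names and both checksum strings are hypothetical placeholders. This only bypasses the check; the saved model and optimizer states must still fit the configuration you resume with.

```python
import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "old-checkpoint.pt")
    new_path = os.path.join(tmp, "patched-checkpoint.pt")
    # Toy stand-in for a real checkpoint: a plain dict holding a stored checksum.
    torch.save({"epoch": 10, "checksum": "old-checksum"}, path)

    checksum = "new-checksum"  # hypothetical checksum of the current configuration
    d = torch.load(path, weights_only=False)
    d["checksum"] = checksum
    torch.save(d, new_path)

    # Reload to confirm the patched value was written.
    patched = torch.load(new_path, weights_only=False)

print(patched)  # {'epoch': 10, 'checksum': 'new-checksum'}
```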
Hi, I was having the same error. I believe the problem occurs when using a scheduler object from PyTorch. We can observe in the constructor whenever … Basically, this makes … I believe the issue can be solved by moving the checksum comparison to the end of the method.
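One mechanism consistent with this description (an assumption, since the code the comment points to was not preserved): constructing a PyTorch learning-rate scheduler adds an `initial_lr` entry to each of the optimizer's parameter groups, and `str(optimizer)` prints every parameter-group setting. A checksum computed from the optimizer's string before the scheduler exists therefore differs from one computed afterwards. The toy `torch.nn.Linear` model and `StepLR` scheduler are arbitrary choices for illustration.

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
before = str(optimizer)  # string representation before any scheduler exists

# Constructing the scheduler mutates optimizer.param_groups (adds "initial_lr").
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
after = str(optimizer)

print("initial_lr" in before)  # False
print("initial_lr" in after)   # True
print(before == after)         # False
```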
@pablo-sanchez-sony, would you mind opening a PR with the changes you suggest?
Sure!
I am facing this checkpoint mismatch error in the same training loop for the RotatE KGE model. My training script is:

```python
from pykeen.hpo import hpo_pipeline

# training, testing, validation: the dataset splits (TriplesFactory objects) defined elsewhere
result = hpo_pipeline(
    study_name='rotate_hpo',
    training=training,
    testing=testing,
    validation=validation,
    pruner="MedianPruner",
    sampler="tpe",
    model='RotatE',
    model_kwargs={
        "random_seed": 42,
    },
    model_kwargs_ranges=dict(
        embedding_dim=dict(type=int, low=100, high=300, q=100),
    ),
    negative_sampler_kwargs_ranges=dict(
        num_negs_per_pos=dict(type=int, low=1, high=100),
    ),
    stopper='early',
    n_trials=30,
    training_loop="sLCWA",
    training_kwargs=dict(
        num_epochs=500,
        checkpoint_name='rotate-checkpoint.pt',
        checkpoint_frequency=10,
    ),
    evaluator_kwargs={"filtered": True, "batch_size": 128},
)
```

Kindly suggest how to resolve this: I am not explicitly trying to resume training; rather, `hpo_pipeline` itself is reloading from the checkpoint.
When setting a `checkpoint_name`, it seems to be used for all trials, so the second trial thinks it is a continuation of the first one, but the model hyperparameters do not match.
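To make the collision concrete, here is a simplified simulation, not pykeen's actual code: each trial hashes its configuration, and if a file already exists at the fixed checkpoint path, it treats the run as a resume, compares the stored checksum, and raises on mismatch. `run_trial`, `config_checksum`, and the local `CheckpointMismatchError` class are hypothetical stand-ins, and the "checkpoint" is a small JSON file rather than a torch archive.

```python
import hashlib
import json
import os
import tempfile


class CheckpointMismatchError(RuntimeError):
    """Hypothetical stand-in for pykeen's error of the same name."""


def config_checksum(config: dict) -> str:
    # Hash a canonical string form of the trial's hyperparameters.
    return hashlib.md5(json.dumps(config, sort_keys=True).encode("utf-8")).hexdigest()


def run_trial(config: dict, checkpoint_path: str) -> str:
    checksum = config_checksum(config)
    if os.path.exists(checkpoint_path):
        # An existing file is treated as "resume this run".
        with open(checkpoint_path) as f:
            saved = json.load(f)
        if saved["checksum"] != checksum:
            raise CheckpointMismatchError("checkpoint was written by a different configuration")
        return "resumed"
    # No checkpoint yet: start fresh and write one.
    with open(checkpoint_path, "w") as f:
        json.dump({"checksum": checksum}, f)
    return "started"


with tempfile.TemporaryDirectory() as tmp:
    # One fixed path shared by every trial, like a fixed checkpoint_name.
    shared_path = os.path.join(tmp, "rotate-checkpoint.json")
    first = run_trial({"embedding_dim": 8}, shared_path)
    try:
        second = run_trial({"embedding_dim": 16}, shared_path)
    except CheckpointMismatchError as error:
        second = f"failed: {error}"

print(first)   # started
print(second)  # failed: checkpoint was written by a different configuration
```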
Here is a smaller script that reproduces the error:

```python
from pykeen.hpo import hpo_pipeline

result = hpo_pipeline(
    study_name="rotate_hpo",
    dataset="nations",
    model="RotatE",
    model_kwargs_ranges=dict(
        embedding_dim=dict(type=int, low=8, high=24, q=8),
    ),
    stopper="early",
    n_trials=2,
    training_loop="sLCWA",
    training_kwargs=dict(
        num_epochs=2,
        checkpoint_name="rotate-checkpoint.pt",
        checkpoint_frequency=1,
    ),
)
```
@arushi-08, what is your use case for providing a checkpoint name? Do you want to save each trial's model? If yes, we have an explicit …
I have opened a small PR (#1324) to fail fast on the first trial with an error message about how to fix it 🙂
@pablo-sanchez-sony, would this resolve your issue, too?
…uration (#1324): When providing a `checkpoint_name`, only the second trial fails, since it finds an existing checkpoint and tries to continue training, but the model configuration has likely changed. This PR checks the configuration up front and directly raises an error with a descriptive message. cf. #1300
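A minimal sketch of the kind of fail-fast check this PR describes, written as a hypothetical helper (`check_training_kwargs_for_hpo` is not pykeen's API, and the PR's actual logic may differ): inspect the training keyword arguments before the first trial starts and reject a fixed `checkpoint_name`, instead of letting the second trial run into the mismatch.

```python
from typing import Any, Mapping, Optional


def check_training_kwargs_for_hpo(training_kwargs: Optional[Mapping[str, Any]]) -> None:
    """Hypothetical pre-flight check: reject a checkpoint name shared by all trials."""
    if training_kwargs and training_kwargs.get("checkpoint_name") is not None:
        raise ValueError(
            "checkpoint_name is shared by all HPO trials, so later trials would try to "
            "resume an earlier trial's checkpoint with a different configuration. "
            "Remove checkpoint_name from training_kwargs."
        )


check_training_kwargs_for_hpo({"num_epochs": 2})  # passes silently
try:
    check_training_kwargs_for_hpo({"num_epochs": 2, "checkpoint_name": "rotate-checkpoint.pt"})
except ValueError as error:
    print("rejected:", error)
```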
I want to resume training my model from a checkpoint file (*.pt), but I am facing a `pykeen.training.training_loop.CheckpointMismatchError` error. Full stack trace:
I have realised that the issue is a checksum mismatch, i.e., the checkpoint file has a different configuration, cf. pykeen/src/pykeen/training/training_loop.py, lines 1182 to 1188 in d1222b7.
However, I am not sure how to load the same configuration as the one stored in the checkpoint file. I believe this is covered in the "Word of Caution and Possible Errors" documentation section (https://pykeen.readthedocs.io/en/stable/tutorial/checkpoints.html#word-of-caution-and-possible-errors), but it is still unclear to me what the next steps are.
How do we resume training from the previous checkpoint?