
ModelCheckpoint Doesn't Delete Old Best Checkpoints When Resuming Training #18687

Open
danielzeng-gt opened this issue Oct 2, 2023 · 3 comments
Labels
bug (Something isn't working) · callback: model checkpoint · repro needed (The issue is missing a reproducible example) · ver: 1.9.x

Comments

danielzeng-gt commented Oct 2, 2023

Bug description

Description:
When using ModelCheckpoint with save_top_k=1 and monitor='val_loss' during a single training run, the behavior is as expected: only one 'best_val_confidence-epoch...' checkpoint is retained.

However, in the context of cloud-based training where instances may be preempted or restarted from a checkpoint:

  • The training resumes from a checkpoint labeled "last.ckpt", which was initially created by a different ModelCheckpoint.
  • There aren't any explicit warnings indicating that the ModelCheckpoint state was restored incorrectly.
  • Post-resumption, ModelCheckpoint creates a new checkpoint but fails to delete the old one. Thus, if there's a single preemption/restart during the training run, we end up with two 'best_val_loss' checkpoints.

Note that we read and write checkpoints with fsspec, which allows them to be written to and loaded directly from Google Cloud Storage (GCS).
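
For context, a minimal sketch of how such a GCS-backed checkpoint directory can be inspected with fsspec (this assumes gcsfs is installed; the bucket and path below are illustrative, not taken from the actual setup):

    import fsspec

    # Hypothetical GCS checkpoint directory; the real path is not given in the report.
    ckpt_dir = "gs://my-bucket/experiments/run-01/checkpoints"

    # fsspec resolves the gs:// protocol to a GCSFileSystem (via gcsfs).
    fs, dir_path = fsspec.core.url_to_fs(ckpt_dir)

    # List whatever checkpoint files are currently stored there.
    for path in fs.ls(dir_path):
        print(path)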

Code Details:

Two ModelCheckpoint callbacks are currently in use:

  1. The first is for saving the latest checkpoint:

    last_ckpt_callback = ModelCheckpoint(
        save_top_k=-1,
        save_last=True,
        dirpath=self.checkpoint_dir,
    )
    last_ckpt_callback.CHECKPOINT_NAME_LAST = _CHECKPOINT_NAME_LAST
  2. The second is for saving the best validation loss checkpoint:

    best_val_loss_ckpt_callback = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        auto_insert_metric_name=False,
        filename='best_val_confidence-epoch{epoch}-val_loss{val_loss:.4e}',
        dirpath=self.checkpoint_dir,
    )

Environment:

  • Lightning Component: ModelCheckpoint object
  • PyTorch Lightning Version: 1.9.2
  • PyTorch Version: 1.13.0
  • Python Version: 3.10.12
  • OS: Linux
  • CUDA/cuDNN version: Build cuda_11.6.r11.6/compiler.31057947_0
  • GPU models: Nvidia A100
  • How you installed Lightning: Conda
  • Cloud: Running on GCP Cluster

What version are you seeing the problem on?

v1.9

How to reproduce the bug

1. Set up a training loop on the cloud with the aforementioned `ModelCheckpoint` callbacks (a minimal sketch of these steps follows below).
2. Intentionally interrupt the training to simulate preemption.
3. Resume the training from the "last.ckpt".
4. Post-resumption, inspect the stored checkpoints. There should be two 'best_val_loss' checkpoints instead of one.

**Expected behavior**: Only one 'best_val_confidence-epoch...' checkpoint should remain after resumption.

**Actual behavior**: Multiple 'best_val_confidence-epoch...' checkpoints are observed after training preemption and resumption.
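
For reference, a minimal sketch of such a repro: the toy model, local checkpoint directory, and back-to-back Trainer runs below are illustrative stand-ins (the real setup runs on GCS via fsspec with preemptible instances), not the original training code.

    import os
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint


    class ToyModel(pl.LightningModule):
        # Hypothetical stand-in for the real model.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)


    def make_loader():
        return DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)


    def make_callbacks(ckpt_dir):
        # Same two callbacks as described above.
        last_cb = ModelCheckpoint(save_top_k=-1, save_last=True, dirpath=ckpt_dir)
        best_cb = ModelCheckpoint(
            monitor='val_loss',
            mode='min',
            save_top_k=1,
            auto_insert_metric_name=False,
            filename='best_val_confidence-epoch{epoch}-val_loss{val_loss:.4e}',
            dirpath=ckpt_dir,
        )
        return [last_cb, best_cb]


    ckpt_dir = "checkpoints"  # local stand-in; the report uses a gs:// directory via fsspec

    # Run 1: train for one epoch, then stop (simulating a preemption).
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, callbacks=make_callbacks(ckpt_dir))
    trainer.fit(ToyModel(), make_loader(), make_loader())

    # Run 2: a fresh Trainer resumes from last.ckpt, as a restarted instance would.
    trainer = pl.Trainer(max_epochs=2, limit_train_batches=2, callbacks=make_callbacks(ckpt_dir))
    trainer.fit(ToyModel(), make_loader(), make_loader(), ckpt_path=os.path.join(ckpt_dir, "last.ckpt"))

    # Expected: a single best_val_confidence-* file; the report observes two.
    print(sorted(os.listdir(ckpt_dir)))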

Error messages and logs

No response


More info

No response

cc @carmocca @awaelchli

@danielzeng-gt danielzeng-gt added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Oct 2, 2023
@awaelchli awaelchli added callback: model checkpoint and removed needs triage Waiting to be triaged by maintainers labels Oct 2, 2023
awaelchli (Contributor) commented Oct 2, 2023

@danielzeng-gt Thanks for submitting the issue.

I read your description multiple times but I don't understand the problem. Can you try to formulate it with an example? Is it related to #17912?

danielzeng-gt (Author) commented Oct 2, 2023

Hey Adrian, thanks for the prompt response!
I looked at #17912 and it doesn't seem to be related.

I generated an example with GPT-4 and reviewed it; it describes the problem accurately. Please let me know if it's still confusing:

Example:

Suppose Alice is training a neural network to classify images of cats and dogs on a cloud-based preemptible instance. She's interested in keeping two kinds of checkpoints:

  1. The latest checkpoint, irrespective of its performance on validation data.
  2. The checkpoint with the best validation loss.

To achieve this, Alice uses two ModelCheckpoint callbacks as described.

Training Run 1:

  1. Alice starts her training.
  2. After epoch 1, the validation loss is 0.5. The system saves:
    • last.ckpt (The latest checkpoint)
    • best_val_confidence-epoch1-val_loss0.5e (The best checkpoint based on validation loss)
  3. Suddenly, the preemptible instance is terminated.

Training Resumption:

  1. Alice's setup detects the preemption and decides to restart the training from the last checkpoint.
  2. It loads last.ckpt and continues training.
  3. After epoch 2, the validation loss improves to 0.4. The system now tries to save:
    • A new last.ckpt (Replacing the older one)
    • best_val_confidence-epoch2-val_loss0.4e (A new best checkpoint)

Expected Behavior:
Since Alice specified save_top_k=1 for the best validation loss checkpoint, she expects to find only one such checkpoint in her directory, i.e., best_val_confidence-epoch2-val_loss0.4e.

Actual Behavior:
Alice finds two best validation loss checkpoints:

  • best_val_confidence-epoch1-val_loss0.5e
  • best_val_confidence-epoch2-val_loss0.4e

This indicates that the ModelCheckpoint callback did not delete the older "best" checkpoint upon resumption, leading to multiple "best" checkpoints being saved.

Implication:
This behavior is problematic, especially if Alice trains for many epochs and faces multiple preemptions. Over time she would accumulate multiple "best" checkpoints, making it hard to identify the genuine best one.

Conclusion:

The bug seems to arise from a state restoration issue in the ModelCheckpoint callback when resuming training from a checkpoint. It fails to remember its previous "best" state and does not delete older checkpoints as it should.
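
One way to check what the callback considers its "best" state after a restart is to inspect the callback states stored inside last.ckpt; a hedged sketch (the local path is illustrative, and the exact keys under "callbacks" depend on each callback's state_key):

    import torch

    # Illustrative local path; in the reported setup this would be a gs:// URL read via fsspec.
    ckpt = torch.load("checkpoints/last.ckpt", map_location="cpu")

    # Lightning stores per-callback state under the "callbacks" key of the checkpoint.
    for state_key, state in ckpt.get("callbacks", {}).items():
        print(state_key)
        # ModelCheckpoint state includes fields such as best_model_path,
        # best_model_score and best_k_models.
        if isinstance(state, dict):
            for k in ("best_model_path", "best_model_score", "best_k_models"):
                if k in state:
                    print(" ", k, "=", state[k])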

@danielzeng-gt danielzeng-gt changed the title ModelCheckpoint Doesn't Overwrite Old Checkpoints When Resuming Training ModelCheckpoint Doesn't Delete Old Best Checkpoints When Resuming Training Oct 13, 2023
@awaelchli awaelchli added the repro needed The issue is missing a reproducible example label Jan 29, 2024
leng-yue (Contributor) commented May 7, 2024

I ran into the same issue. I understand that fixing it might be a breaking change; could we add an option to handle it?
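
Until such an option exists, one possible workaround is to prune stale "best" files against the callback's current best_model_path, e.g. right after trainer.fit(...) returns. This is a hedged sketch, not an official Lightning API: `prune_stale_best` is a hypothetical helper and the filename prefix is illustrative.

    import fsspec

    def prune_stale_best(best_cb, prefix="best_val_confidence-"):
        # best_cb is the val_loss ModelCheckpoint; fsspec handles both local
        # directories and gs:// URLs (via gcsfs) the same way.
        fs, dir_path = fsspec.core.url_to_fs(best_cb.dirpath)
        keep = best_cb.best_model_path.rsplit("/", 1)[-1]  # filename of the current best
        for path in fs.ls(dir_path):
            name = path.rsplit("/", 1)[-1]
            if name.startswith(prefix) and name != keep:
                fs.rm(path)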
