Skip to content

Commit

Permalink
Allow callbacks to be restored not just during training (#20403)
Browse files Browse the repository at this point in the history
* Allow callbacks to be restored not just during training

* add test case

* test test case failure

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test case

---------

Co-authored-by: Alan Chu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <[email protected]>
  • Loading branch information
4 people authored Nov 14, 2024
1 parent cd2bd3c commit c110f4f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,7 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
self.resume_start(checkpoint_path)
self.restore_model()
self.restore_datamodule()
if self.trainer.state.fn == TrainerFn.FITTING:
# restore callback states
self.restore_callbacks()
self.restore_callbacks()

def dump_checkpoint(self, weights_only: bool = False) -> dict:
"""Creating a model checkpoint dictionary object from various component states.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.migration.utils import _set_version
Expand Down Expand Up @@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path):
trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
model.load_state_dict.assert_called_once_with(ANY, strict=expected)


@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn):
    """Test that callbacks are properly restored in non-fit phases.

    A callback flag is set to ``True`` and written into a checkpoint after a short fit. A fresh trainer and a fresh
    callback then run ``validate``/``test``/``predict`` with ``ckpt_path`` pointing at that checkpoint. The flag on
    the fresh callback must become ``True`` solely as a result of that call.

    """

    class TestCallback(Callback):
        def __init__(self):
            self.restored = False

        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            # Read back the entry written by ``on_save_checkpoint`` below.
            if "callbacks" in checkpoint:
                callback_state = checkpoint["callbacks"][self.__class__.__name__]
                self.restored = callback_state["restored"]

        def state_dict(self):
            return {"restored": self.restored}

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Store the state under the bare class name so ``on_load_checkpoint`` can find it.
            checkpoint["callbacks"] = checkpoint.get("callbacks", {})
            checkpoint["callbacks"][self.__class__.__name__] = self.state_dict()

    # First create and train a model with the callback
    callback = TestCallback()
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
    trainer.fit(model)

    # Set the callback state to True before saving
    callback.restored = True
    ckpt_path = tmp_path / "checkpoint.ckpt"
    trainer.save_checkpoint(ckpt_path)

    # Now create new instances and test restoration
    new_callback = TestCallback()
    new_model = BoringModel()
    assert not new_callback.restored  # Should start False

    new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])

    # Run the evaluation phase (validate/test/predict).
    # Deliberately no manual ``resume_start``/``restore_callbacks`` call beforehand: restoring by hand would set
    # the flag before the entry point runs and the assertion below would pass even if the non-fit entry points
    # never restored callbacks themselves.
    fn = getattr(new_trainer, trainer_fn)
    fn(new_model, ckpt_path=ckpt_path)

    assert new_callback.restored  # Should be True after loading the checkpoint

0 comments on commit c110f4f

Please sign in to comment.