Skip to content

Commit

Permalink
Fix LightningCLI failing when both module and data module save hyperp…
Browse files Browse the repository at this point in the history
…arameters (#20221)

* Fix LightningCLI failing when both module and data module save hyperparameters due to conflicting internal  parameter

* Update changelog pull link

* Only skip logging internal LightningCLI params

* Only skip logging internal LightningCLI params

* Only skip _class_path

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Luca Antiga <[email protected]>
  • Loading branch information
3 people authored Dec 11, 2024
1 parent 60289d7 commit a9125c2
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unreleased] - YYYY-MM-DD

### Added

### Changed

- Merging of hparams when logging now ignores parameter names that begin with underscore `_` ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))

### Removed

### Fixed

- Fix LightningCLI failing when both module and data module save hyperparameters due to conflicting internal `_class_path` parameter ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))


## [2.4.0] - 2024-08-06

### Added
Expand Down
7 changes: 7 additions & 0 deletions src/lightning/pytorch/loggers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
lightning_hparams = pl_module.hparams_initial
inconsistent_keys = []
for key in lightning_hparams.keys() & datamodule_hparams.keys():
if key == "_class_path":
# Skip LightningCLI's internal hparam
continue
lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
if (
type(lm_val) != type(dm_val)
Expand All @@ -88,6 +91,10 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
elif datamodule_log_hyperparams:
hparams_initial = trainer.datamodule.hparams_initial

# Don't log LightningCLI's internal hparam
if hparams_initial is not None:
hparams_initial = {k: v for k, v in hparams_initial.items() if k != "_class_path"}

for logger in trainer.loggers:
if hparams_initial is not None:
logger.log_hyperparams(hparams_initial)
Expand Down
23 changes: 23 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,29 @@ def test_lightning_cli_save_hyperparameters_untyped_module(cleandir):
assert model.kwargs == {"x": 1}


class TestDataSaveHparams(BoringDataModule):
def __init__(self, batch_size: int = 32, num_workers: int = 4):
super().__init__()
self.save_hyperparameters()
self.batch_size = batch_size
self.num_workers = num_workers


def test_lightning_cli_save_hyperparameters_merge(cleandir):
config = {
"model": {
"class_path": f"{__name__}.TestModelSaveHparams",
},
"data": {
"class_path": f"{__name__}.TestDataSaveHparams",
},
}
with mock.patch("sys.argv", ["any.py", "fit", f"--config={json.dumps(config)}", "--trainer.max_epochs=1"]):
cli = LightningCLI(auto_configure_optimizers=False)
assert set(cli.model.hparams) == {"optimizer", "scheduler", "activation", "_instantiator", "_class_path"}
assert set(cli.datamodule.hparams) == {"batch_size", "num_workers", "_instantiator", "_class_path"}


@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
class TestCLI(LightningCLI):
Expand Down

0 comments on commit a9125c2

Please sign in to comment.