From 60289d772f16362e7c75578a90801a2c1d7dffe4 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 11 Dec 2024 12:17:08 +0100 Subject: [PATCH] Force hook standalone tests to single device (#20491) * Force hook standalone tests to single device * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/tests_pytorch/models/test_hooks.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 685bd6c0bdaef..1a8aeb4b297a9 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -61,7 +61,7 @@ def on_before_zero_grad(self, optimizer): model = CurrentTestModel() - trainer = Trainer(default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2) + trainer = Trainer(devices=1, default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2) assert model.on_before_zero_grad_called == 0 trainer.fit(model) assert max_steps == model.on_before_zero_grad_called @@ -406,7 +406,7 @@ def prepare_data(self): ... @pytest.mark.parametrize( "kwargs", [ - {}, + {"devices": 1}, # these precision plugins modify the optimization flow, so testing them explicitly pytest.param({"accelerator": "gpu", "devices": 1, "precision": "16-mixed"}, marks=RunIf(min_cuda_gpus=1)), pytest.param( @@ -528,6 +528,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path): # initial training to get a checkpoint model = BoringModel() trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, @@ -543,6 +544,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path): callback = HookedCallback(called) # already performed 1 step, resume and do 2 more trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=2, limit_train_batches=2, @@ -605,6 +607,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path): # initial training to get a checkpoint model = BoringModel() trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_steps=1, limit_val_batches=0, @@ -624,6 +627,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path): train_batches = 2 steps_after_reload = 1 + train_batches trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_steps=steps_after_reload, limit_val_batches=0, @@ -690,6 +694,7 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat assert is_overridden(f"on_{noun}_model_train", model) == override_on_x_model_train callback = HookedCallback(called) trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=1, limit_val_batches=batches, @@ -731,7 +736,11 @@ def test_trainer_model_hook_system_predict(tmp_path): callback = HookedCallback(called) batches = 2 trainer = Trainer( - default_root_dir=tmp_path, limit_predict_batches=batches, enable_progress_bar=False, callbacks=[callback] + devices=1, + default_root_dir=tmp_path, + limit_predict_batches=batches, + enable_progress_bar=False, + callbacks=[callback], ) trainer.predict(model) expected = [ @@ -797,7 +806,7 @@ def predict_dataloader(self): model = CustomBoringModel() - trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=5) + trainer = Trainer(devices=1, default_root_dir=tmp_path, fast_dev_run=5) trainer.fit(model) trainer.test(model) @@ -812,6 +821,7 @@ def test_trainer_datamodule_hook_system(tmp_path): model = BoringModel() batches = 2 trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=1, limit_train_batches=batches, @@ -887,7 +897,7 @@ class CustomHookedModel(HookedModel): assert is_overridden("configure_model", model) == override_configure_model datamodule = CustomHookedDataModule(ldm_called) - trainer = Trainer() + trainer = Trainer(devices=1) trainer.strategy.connect(model) trainer._data_connector.attach_data(model, datamodule=datamodule) ckpt_path = str(tmp_path / "file.ckpt") @@ -960,6 +970,7 @@ def predict_step(self, *args, **kwargs): model = MixedTrainModeModule() trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=1, val_check_interval=1,