From be608fa355b835b9b0727df2f5476f0a1d90bc59 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 26 Nov 2024 17:19:22 +0100 Subject: [PATCH] Make plugin type check more flexible (Fabric) (#20452) * make plugin type check more flexible * Change signature and make the equivalent changes to Fabric connector --------- Co-authored-by: Jianing Yang Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/fabric/connector.py | 7 ++++--- .../pytorch/trainer/connectors/accelerator_connector.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 9161d5f1bd6c2..0ade7f69c3629 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -13,6 +13,7 @@ # limitations under the License. import os from collections import Counter +from collections.abc import Iterable from typing import Any, Optional, Union, cast import torch @@ -102,7 +103,7 @@ def __init__( devices: Union[list[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None, ) -> None: # These arguments can be set through environment variables set by the CLI accelerator = self._argument_from_env("accelerator", accelerator, default="auto") @@ -165,7 +166,7 @@ def _check_config_and_set_final_flags( strategy: Union[str, Strategy], accelerator: Union[str, Accelerator], precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]], + plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]], ) -> None: """This method checks: @@ -180,7 +181,7 @@ def _check_config_and_set_final_flags( """ if plugins is not None: - plugins = [plugins] if not isinstance(plugins, list) else plugins + plugins = [plugins] if not isinstance(plugins, Iterable) else plugins if isinstance(strategy, str): strategy = strategy.lower() diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index b892fcc3290e4..40ee0eef4de33 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -79,7 +79,7 @@ def __init__( num_nodes: int = 1, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None, precision: Optional[_PRECISION_INPUT] = None, sync_batchnorm: bool = False, benchmark: Optional[bool] = None, @@ -167,7 +167,7 @@ def _check_config_and_set_final_flags( strategy: Union[str, Strategy], accelerator: Union[str, Accelerator], precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]], + plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]], sync_batchnorm: bool, ) -> None: """This method checks: