Skip to content

Commit

Permalink
Make plugin type check more flexible (Fabric) (#20452)
Browse files Browse the repository at this point in the history
* make plugin type check more flexible

* Change signature and make the equivalent changes to Fabric connector

---------

Co-authored-by: Jianing Yang <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
3 people authored Nov 26, 2024
1 parent 13b74f7 commit be608fa
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import os
from collections import Counter
from collections.abc import Iterable
from typing import Any, Optional, Union, cast

import torch
Expand Down Expand Up @@ -102,7 +103,7 @@ def __init__(
devices: Union[list[int], str, int] = "auto",
num_nodes: int = 1,
precision: Optional[_PRECISION_INPUT] = None,
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None,
plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None,
) -> None:
# These arguments can be set through environment variables set by the CLI
accelerator = self._argument_from_env("accelerator", accelerator, default="auto")
Expand Down Expand Up @@ -165,7 +166,7 @@ def _check_config_and_set_final_flags(
strategy: Union[str, Strategy],
accelerator: Union[str, Accelerator],
precision: Optional[_PRECISION_INPUT],
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]],
plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]],
) -> None:
"""This method checks:
Expand All @@ -180,7 +181,7 @@ def _check_config_and_set_final_flags(
"""
if plugins is not None:
plugins = [plugins] if not isinstance(plugins, list) else plugins
plugins = [plugins] if not isinstance(plugins, Iterable) else plugins

if isinstance(strategy, str):
strategy = strategy.lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
num_nodes: int = 1,
accelerator: Union[str, Accelerator] = "auto",
strategy: Union[str, Strategy] = "auto",
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None,
plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None,
precision: Optional[_PRECISION_INPUT] = None,
sync_batchnorm: bool = False,
benchmark: Optional[bool] = None,
Expand Down Expand Up @@ -167,7 +167,7 @@ def _check_config_and_set_final_flags(
strategy: Union[str, Strategy],
accelerator: Union[str, Accelerator],
precision: Optional[_PRECISION_INPUT],
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]],
plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]],
sync_batchnorm: bool,
) -> None:
"""This method checks:
Expand Down

0 comments on commit be608fa

Please sign in to comment.