-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feature] add basic multi-backends support (#236)
* add basic multi-backends support * fix lint error * move the backend support check outside the function * add nonecontext * selective import bitsandbytes * selective import BitsAndBytesConfig * fix lint error * fix circular import * format with black * fix ci errors --------- Co-authored-by: yezhem <[email protected]>
- Loading branch information
1 parent
2b8923a
commit 9f8c70d
Showing
14 changed files
with
426 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from .common import BasicBackend | ||
from .cpu import CPUBackend | ||
from .cuda import CUDABackend | ||
from .mps import MPSBackend | ||
|
||
# Process-wide backend instance; stays None until get_backend() first builds it.
_backend: Optional[BasicBackend] = None


# Lookup table from the upper-cased MLORA_BACKEND_TYPE value to the backend
# class that _init_backend() instantiates for it.
backend_dict = dict(
    CUDA=CUDABackend,
    MPS=MPSBackend,
    CPU=CPUBackend,
)
|
||
|
||
def _init_backend() -> BasicBackend:
    """Build the backend instance to use for this process.

    When the ``MLORA_BACKEND_TYPE`` environment variable is set, its
    upper-cased value selects the class from ``backend_dict``. Otherwise the
    backend is auto-detected: CUDA first, then MPS, then CPU as the fallback.

    Returns:
        A freshly constructed backend object.

    Raises:
        ValueError: ``MLORA_BACKEND_TYPE`` is set to a value that is not a
            key of ``backend_dict``.
    """
    requested = os.getenv("MLORA_BACKEND_TYPE")

    if requested is None:
        # No explicit choice: probe the accelerators in priority order.
        if torch.cuda.is_available():
            return CUDABackend()
        if torch.backends.mps.is_available():
            return MPSBackend()
        return CPUBackend()

    # Explicit choice: matching is case-insensitive via upper-casing.
    requested = requested.upper()
    backend_cls = backend_dict.get(requested)
    if backend_cls is None:
        raise ValueError(f"Assigning unknown backend type {requested}")
    return backend_cls()
|
||
|
||
def get_backend() -> BasicBackend:
    """Return the shared backend, creating it on the first call.

    The instance is stored in the module-level ``_backend`` variable, so every
    later call returns the same object.
    """
    global _backend
    if _backend is not None:
        return _backend

    _backend = _init_backend()
    return _backend
|
||
|
||
# Names exported by `from <package> import *`.
__all__ = [
    # backend classes
    "BasicBackend",
    "CUDABackend",
    "MPSBackend",
    "CPUBackend",
    # accessor for the shared backend instance
    "get_backend",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import logging | ||
import random | ||
|
||
import torch | ||
|
||
from mlora.utils import NoneContexts | ||
|
||
|
||
class BasicBackend:
    """Base class for compute backends.

    The defaults describe a backend with no usable device: it reports itself
    as unavailable and uninitialized, and most hooks do nothing. Subclasses
    override the methods that apply to their device.
    """

    def name(self) -> str:
        """Return the human-readable backend name used in log messages."""
        return "none"

    def device_name(self) -> str:
        """Return the device type string handed to torch APIs."""
        return "null"

    def is_available(self) -> bool:
        """Report whether the backend's device can be used (never, here)."""
        return False

    def is_initialized(self) -> bool:
        """Report whether the backend has been initialized (never, here)."""
        return False

    def is_bf16_supported(self) -> bool:
        """Report whether bfloat16 is supported (never, here)."""
        return False

    def manual_seed(self, seed: int):
        """Seed Python's `random` module and torch's default generator."""
        for seeder in (random.seed, torch.manual_seed):
            seeder(seed)

    def empty_cache(self):
        """Release cached device memory; nothing to do in the base class."""
        return None

    def use_deterministic_algorithms(self, mode: bool):
        """Forward `mode` to `torch.use_deterministic_algorithms`."""
        torch.use_deterministic_algorithms(mode)

    def allow_tf32(self, mode: bool):
        """Toggle TF32 usage; nothing to do in the base class."""
        return None

    def set_rng_state(self, device, state):
        """Restore a device RNG state; nothing to do in the base class."""
        return None

    def get_rng_state(self, device):
        """Fetch a device RNG state; the base class has none and returns None."""
        return None

    def fork_rng(self, rng_devices: list):
        """Return torch's `fork_rng` context manager for the given devices.

        The device type passed to torch is whatever `device_name()` returns.
        """
        device_type = self.device_name()
        return torch.random.fork_rng(devices=rng_devices, device_type=device_type)

    def autocast(self, **kwargs):
        """Return a `NoneContexts` instance; `kwargs` are ignored.

        NOTE(review): `NoneContexts` comes from `mlora.utils` and is presumably
        a no-op context manager -- confirm against its definition.
        """
        return NoneContexts()

    def check_available(self):
        """Log the backend's status and return whether it is ready to use.

        Returns:
            True only when both `is_available()` and `is_initialized()` hold;
            the first failing check is logged as an error and yields False.
        """
        if not self.is_available():
            logging.error(f"{self.name()} not available.")
            ready = False
        elif not self.is_initialized():
            logging.error(f"{self.name()} not initialized.")
            ready = False
        else:
            logging.info(f"{self.name()} initialized successfully.")
            ready = True
        return ready
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import contextlib | ||
import logging | ||
|
||
import torch | ||
|
||
from .common import BasicBackend | ||
|
||
# Cached result of the CPU bfloat16 probe; None until is_bf16_supported()
# has run once.
_cpu_bf16_supported = None


class CPUBackend(BasicBackend):
    """Backend that runs on the host CPU."""

    def __init__(self) -> None:
        super().__init__()

    def name(self) -> str:
        """Return the human-readable backend name used in log messages."""
        return "CPU"

    def device_name(self) -> str:
        """Return the torch device type string for this backend."""
        return "cpu"

    def is_available(self) -> bool:
        """The CPU backend always reports itself as available."""
        return True

    def is_initialized(self) -> bool:
        # NOTE(review): reports False even though check_available() below
        # always succeeds -- confirm no caller relies on this value.
        return False

    def is_bf16_supported(self) -> bool:
        """Probe once whether a bfloat16 tensor can be created on the CPU.

        The result is cached in the module-level `_cpu_bf16_supported`.
        """
        # TODO: change to official implementation
        global _cpu_bf16_supported
        if _cpu_bf16_supported is None:
            try:
                torch.ones(5, dtype=torch.bfloat16, device="cpu")
            except (TypeError, RuntimeError):
                # torch may report an unsupported dtype with either exception
                # type; both mean "not supported" rather than a crash.
                _cpu_bf16_supported = False
            else:
                _cpu_bf16_supported = True

        return _cpu_bf16_supported

    def allow_tf32(self, mode: bool):
        """Reject attempts to enable TF32; disabling it is a no-op.

        Raises:
            AssertionError: `mode` is truthy.
        """
        assert not mode, "Enabling tf32 for CPU."

    def set_rng_state(self, device: int, state: torch.Tensor):
        """Always raises: per-device RNG state is not handled for the CPU.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError("Can not setting rng state for CPU.")

    def get_rng_state(self, device: int):
        """Always raises: per-device RNG state is not handled for the CPU.

        Raises:
            RuntimeError: always.
        """
        # The original message was copy-pasted from set_rng_state and said
        # "setting"; it now names the operation that actually failed.
        raise RuntimeError("Can not get rng state for CPU.")

    @contextlib.contextmanager
    def fork_rng(self, rng_devices: list):
        """Context manager that restores torch's CPU RNG state on exit.

        Args:
            rng_devices: must be empty; there are no accelerator devices.

        Raises:
            AssertionError: `rng_devices` is not empty.
        """
        # TODO: change to official implementation
        assert len(rng_devices) == 0
        cpu_rng_state = torch.get_rng_state()
        try:
            yield
        finally:
            # Restore even when the body raised.
            torch.set_rng_state(cpu_rng_state)

    def check_available(self):
        """Log success and return True; no checks are performed for the CPU."""
        logging.info(f"{self.name()} initialized successfully.")
        return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import torch | ||
|
||
from .common import BasicBackend | ||
|
||
|
||
class CUDABackend(BasicBackend):
    """Backend that runs on NVIDIA GPUs through `torch.cuda`."""

    def __init__(self) -> None:
        super().__init__()
        # torch.cuda.init() raises on a machine without usable CUDA. Guarding
        # it lets the object be constructed anyway, so is_available() and
        # check_available() can report the problem instead of the constructor
        # crashing.
        if torch.cuda.is_available():
            torch.cuda.init()

    def name(self) -> str:
        """Return the human-readable backend name used in log messages."""
        return "NVIDIA CUDA"

    def device_name(self) -> str:
        """Return the torch device type string for this backend."""
        return "cuda"

    def is_available(self) -> bool:
        """Report `torch.cuda.is_available()`."""
        return torch.cuda.is_available()

    def is_initialized(self) -> bool:
        """Report `torch.cuda.is_initialized()`."""
        return torch.cuda.is_initialized()

    def is_bf16_supported(self) -> bool:
        """Report `torch.cuda.is_bf16_supported()`."""
        return torch.cuda.is_bf16_supported()

    def manual_seed(self, seed: int):
        """Seed the generators covered by the base class, then all CUDA devices."""
        super().manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def empty_cache(self):
        """Call `torch.cuda.empty_cache()`."""
        torch.cuda.empty_cache()

    def use_deterministic_algorithms(self, mode: bool):
        """Switch cuDNN between benchmark and deterministic behaviour.

        Only the two cuDNN flags are set; this override does not call
        `torch.use_deterministic_algorithms`.
        """
        torch.backends.cudnn.benchmark = not mode
        torch.backends.cudnn.deterministic = mode

    def allow_tf32(self, mode: bool):
        """Set the TF32 flags for both cuDNN and CUDA matmul to `mode`."""
        torch.backends.cudnn.allow_tf32 = mode
        torch.backends.cuda.matmul.allow_tf32 = mode

    def set_rng_state(self, device, state):
        """Set the CUDA RNG state while `device` is the current device."""
        with torch.cuda.device(device):
            return torch.cuda.set_rng_state(state)

    def get_rng_state(self, device):
        """Return the CUDA RNG state read while `device` is the current device."""
        with torch.cuda.device(device):
            return torch.cuda.get_rng_state()

    def autocast(self, **kwargs):
        """Return `torch.cuda.amp.autocast` built with the given `kwargs`."""
        return torch.cuda.amp.autocast(**kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import contextlib | ||
|
||
import torch | ||
|
||
from .common import BasicBackend | ||
|
||
# Cached result of the MPS bfloat16 probe; None until is_bf16_supported()
# has run once.
_mps_bf16_supported = None


class MPSBackend(BasicBackend):
    """Backend that runs on Apple GPUs through `torch.mps`."""

    def __init__(self) -> None:
        super().__init__()

    def name(self) -> str:
        """Return the human-readable backend name used in log messages."""
        return "APPLE MPS"

    def device_name(self) -> str:
        """Return the torch device type string for this backend."""
        return "mps"

    def is_available(self) -> bool:
        """Report `torch.backends.mps.is_available()`."""
        return torch.backends.mps.is_available()

    def is_initialized(self) -> bool:
        """Report initialization as "not in a bad fork" (private torch API)."""
        # TODO: change to official implementation
        return not torch.mps._is_in_bad_fork()

    def is_bf16_supported(self) -> bool:
        """Probe once whether a bfloat16 tensor can be created on MPS.

        The result is cached in the module-level `_mps_bf16_supported`.
        """
        # TODO: change to official implementation
        global _mps_bf16_supported
        if _mps_bf16_supported is None:
            try:
                torch.ones(5, dtype=torch.bfloat16, device="mps")
            except (TypeError, RuntimeError):
                # Besides TypeError for the dtype, torch raises RuntimeError
                # when a tensor cannot be placed on "mps" at all (for example
                # a build without MPS); treat both as "not supported".
                _mps_bf16_supported = False
            else:
                _mps_bf16_supported = True

        return _mps_bf16_supported

    def manual_seed(self, seed: int):
        """Seed the generators covered by the base class, then the MPS one."""
        super().manual_seed(seed)
        torch.mps.manual_seed(seed)

    def empty_cache(self):
        """Call `torch.mps.empty_cache()`."""
        torch.mps.empty_cache()

    def allow_tf32(self, mode: bool):
        """Reject attempts to enable TF32; disabling it is a no-op.

        Raises:
            AssertionError: `mode` is truthy.
        """
        assert not mode, "Enabling tf32 for MPS devices."

    def set_rng_state(self, device: int, state: torch.Tensor):
        """Set the MPS RNG state; only device index 0 is accepted.

        Raises:
            AssertionError: `device` is not 0.
        """
        assert device == 0
        return torch.mps.set_rng_state(state)

    def get_rng_state(self, device: int):
        """Return the MPS RNG state; only device index 0 is accepted.

        Raises:
            AssertionError: `device` is not 0.
        """
        assert device == 0
        return torch.mps.get_rng_state()

    @contextlib.contextmanager
    def fork_rng(self, rng_devices: list):
        """Context manager that restores the CPU and MPS RNG states on exit.

        Args:
            rng_devices: must be exactly `[0]`.

        Raises:
            AssertionError: `rng_devices` is anything other than `[0]`.
        """
        # TODO: change to official implementation
        assert len(rng_devices) == 1 and rng_devices[0] == 0
        cpu_rng_state = torch.get_rng_state()
        device_rng_states = torch.mps.get_rng_state()
        try:
            yield
        finally:
            # Restore both generators even when the body raised.
            torch.set_rng_state(cpu_rng_state)
            torch.mps.set_rng_state(device_rng_states)

    def autocast(self, **kwargs):
        """Return `torch.cuda.amp.autocast` built with the given `kwargs`.

        NOTE(review): this reuses the CUDA autocast context on MPS as a
        compatibility shim -- confirm what it actually does on MPS hosts.
        """
        # TODO: change to official implementation
        # running with compatible mode
        return torch.cuda.amp.autocast(**kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .loader import load_model | ||
|
||
# Names exported by `from <package> import *`: only the loader entry point.
__all__ = ["load_model"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.