Commit
[feature] add basic multi-backends support (#236)
* add basic multi-backends support

* fix lint error

* move the backend support check outside the function

* add nonecontext

* selective import bitsandbytes

* selective import BitsAndBytesConfig

* fix lint error

* fix circular import

* format with black

* fix ci errors

---------

Co-authored-by: yezhem <[email protected]>
mikecovlee and yezhengmao1 authored Jul 7, 2024
1 parent 2b8923a commit 9f8c70d
Showing 14 changed files with 426 additions and 31 deletions.
50 changes: 50 additions & 0 deletions mlora/backends/__init__.py
@@ -0,0 +1,50 @@
import os
from typing import Optional

import torch

from .common import BasicBackend
from .cpu import CPUBackend
from .cuda import CUDABackend
from .mps import MPSBackend

_backend: Optional[BasicBackend] = None


backend_dict = {
    "CUDA": CUDABackend,
    "MPS": MPSBackend,
    "CPU": CPUBackend,
}


def _init_backend() -> BasicBackend:
    env = os.getenv("MLORA_BACKEND_TYPE")
    if env is not None:
        env = env.upper()
        if env not in backend_dict:
            raise ValueError(f"Unknown backend type: {env}")
        return backend_dict[env]()
    elif torch.cuda.is_available():
        return CUDABackend()
    elif torch.backends.mps.is_available():
        return MPSBackend()
    else:
        return CPUBackend()


def get_backend() -> BasicBackend:
    global _backend
    if _backend is None:
        _backend = _init_backend()

    return _backend


__all__ = [
    "BasicBackend",
    "CUDABackend",
    "MPSBackend",
    "CPUBackend",
    "get_backend",
]
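The new package exposes a single factory: get_backend() lazily constructs one of the three backends, preferring an explicit MLORA_BACKEND_TYPE environment variable over auto-detection (CUDA, then MPS, then CPU). A rough usage sketch follows; the actual call sites are not part of this commit, and the loop body below is purely illustrative.

# Illustrative only: how downstream code might drive the new backend API.
# Set MLORA_BACKEND_TYPE=CUDA|MPS|CPU in the environment to force a backend;
# otherwise CUDA, then MPS, then CPU is auto-detected.
import torch

from mlora.backends import get_backend

backend = get_backend()      # returns CUDABackend, MPSBackend, or CPUBackend
backend.manual_seed(42)      # seeds Python's random and torch (plus the device, if any)

if backend.check_available():
    with backend.autocast(dtype=torch.float16):
        x = torch.ones(4, 4, device=backend.device_name())
        y = x @ x            # device-agnostic mixed-precision region
    backend.empty_cache()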
60 changes: 60 additions & 0 deletions mlora/backends/common.py
@@ -0,0 +1,60 @@
import logging
import random

import torch

from mlora.utils import NoneContexts


class BasicBackend:
    def name(self) -> str:
        return "none"

    def device_name(self) -> str:
        return "null"

    def is_available(self) -> bool:
        return False

    def is_initialized(self) -> bool:
        return False

    def is_bf16_supported(self) -> bool:
        return False

    def manual_seed(self, seed: int):
        random.seed(seed)
        torch.manual_seed(seed)

    def empty_cache(self):
        pass

    def use_deterministic_algorithms(self, mode: bool):
        torch.use_deterministic_algorithms(mode)

    def allow_tf32(self, mode: bool):
        pass

    def set_rng_state(self, device, state):
        pass

    def get_rng_state(self, device):
        pass

    def fork_rng(self, rng_devices: list):
        return torch.random.fork_rng(
            devices=rng_devices, device_type=self.device_name()
        )

    def autocast(self, **kwargs):
        return NoneContexts()

    def check_available(self):
        if not self.is_available():
            logging.error(f"{self.name()} not available.")
            return False
        if not self.is_initialized():
            logging.error(f"{self.name()} not initialized.")
            return False
        logging.info(f"{self.name()} initialized successfully.")
        return True
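BasicBackend.autocast() falls back to NoneContexts from mlora.utils, which matches the commit's "add nonecontext" note, but its definition is not among the hunks shown here. A minimal no-op context manager along the following lines would satisfy that usage; this is an assumption about its shape, not the actual mlora.utils code.

# Assumed shape of mlora.utils.NoneContexts (definition not shown in this diff):
# a do-nothing context manager so backends without real autocast support can
# still be used in a `with backend.autocast(...):` block.
class NoneContexts:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False  # never suppress exceptions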
60 changes: 60 additions & 0 deletions mlora/backends/cpu.py
@@ -0,0 +1,60 @@
import contextlib
import logging

import torch

from .common import BasicBackend

_cpu_bf16_supported = None


class CPUBackend(BasicBackend):
    def __init__(self) -> None:
        super().__init__()

    def name(self) -> str:
        return "CPU"

    def device_name(self) -> str:
        return "cpu"

    def is_available(self) -> bool:
        return True

    def is_initialized(self) -> bool:
        return False

    def is_bf16_supported(self) -> bool:
        # TODO: change to official implementation
        global _cpu_bf16_supported
        if _cpu_bf16_supported is None:
            try:
                torch.ones(5, dtype=torch.bfloat16, device="cpu")
                _cpu_bf16_supported = True
            except TypeError:
                _cpu_bf16_supported = False

        return _cpu_bf16_supported

    def allow_tf32(self, mode: bool):
        assert not mode, "tf32 is not supported on CPU."

    def set_rng_state(self, device: int, state: torch.Tensor):
        raise RuntimeError("Cannot set RNG state for CPU.")

    def get_rng_state(self, device: int):
        raise RuntimeError("Cannot get RNG state for CPU.")

    @contextlib.contextmanager
    def fork_rng(self, rng_devices: list):
        # TODO: change to official implementation
        assert len(rng_devices) == 0
        cpu_rng_state = torch.get_rng_state()
        try:
            yield
        finally:
            torch.set_rng_state(cpu_rng_state)

    def check_available(self):
        logging.info(f"{self.name()} initialized successfully.")
        return True
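The CPU override of fork_rng only snapshots and restores the CPU generator, and insists on an empty device list. A small hypothetical demonstration of the restore behaviour, not taken from the commit itself:

# Illustration (not part of the commit): random draws repeat after the block
# because the CPU generator state is restored on exit.
import torch

from mlora.backends.cpu import CPUBackend

backend = CPUBackend()
with backend.fork_rng(rng_devices=[]):   # the CPU override asserts an empty list
    a = torch.rand(3)
b = torch.rand(3)                         # equal to `a`: state was rolled back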
50 changes: 50 additions & 0 deletions mlora/backends/cuda.py
@@ -0,0 +1,50 @@
import torch

from .common import BasicBackend


class CUDABackend(BasicBackend):
    def __init__(self) -> None:
        super().__init__()
        torch.cuda.init()

    def name(self) -> str:
        return "NVIDIA CUDA"

    def device_name(self) -> str:
        return "cuda"

    def is_available(self) -> bool:
        return torch.cuda.is_available()

    def is_initialized(self) -> bool:
        return torch.cuda.is_initialized()

    def is_bf16_supported(self) -> bool:
        return torch.cuda.is_bf16_supported()

    def manual_seed(self, seed: int):
        super().manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def empty_cache(self):
        torch.cuda.empty_cache()

    def use_deterministic_algorithms(self, mode: bool):
        torch.backends.cudnn.benchmark = not mode
        torch.backends.cudnn.deterministic = mode

    def allow_tf32(self, mode: bool):
        torch.backends.cudnn.allow_tf32 = mode
        torch.backends.cuda.matmul.allow_tf32 = mode

    def set_rng_state(self, device, state):
        with torch.cuda.device(device):
            return torch.cuda.set_rng_state(state)

    def get_rng_state(self, device):
        with torch.cuda.device(device):
            return torch.cuda.get_rng_state()

    def autocast(self, **kwargs):
        return torch.cuda.amp.autocast(**kwargs)
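A hedged sketch of how the CUDA backend's knobs compose for a reproducible run; the method names come from the file above, while the surrounding workflow is assumed for illustration.

# Sketch only: reproducibility settings routed through CUDABackend.
from mlora.backends.cuda import CUDABackend

backend = CUDABackend()
backend.manual_seed(1234)                    # CPU seed plus manual_seed_all on CUDA
backend.use_deterministic_algorithms(True)   # cudnn.deterministic=True, benchmark=False
backend.allow_tf32(False)                    # keep fp32 matmuls exact

state = backend.get_rng_state(0)             # snapshot device 0's generator
# ... code that consumes random numbers on device 0 ...
backend.set_rng_state(0, state)              # restore the snapshot later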
72 changes: 72 additions & 0 deletions mlora/backends/mps.py
@@ -0,0 +1,72 @@
import contextlib

import torch

from .common import BasicBackend

_mps_bf16_supported = None


class MPSBackend(BasicBackend):
    def __init__(self) -> None:
        super().__init__()

    def name(self) -> str:
        return "APPLE MPS"

    def device_name(self) -> str:
        return "mps"

    def is_available(self) -> bool:
        return torch.backends.mps.is_available()

    def is_initialized(self) -> bool:
        # TODO: change to official implementation
        return not torch.mps._is_in_bad_fork()

    def is_bf16_supported(self) -> bool:
        # TODO: change to official implementation
        global _mps_bf16_supported
        if _mps_bf16_supported is None:
            try:
                torch.ones(5, dtype=torch.bfloat16, device="mps")
                _mps_bf16_supported = True
            except TypeError:
                _mps_bf16_supported = False

        return _mps_bf16_supported

    def manual_seed(self, seed: int):
        super().manual_seed(seed)
        torch.mps.manual_seed(seed)

    def empty_cache(self):
        torch.mps.empty_cache()

    def allow_tf32(self, mode: bool):
        assert not mode, "tf32 is not supported on MPS devices."

    def set_rng_state(self, device: int, state: torch.Tensor):
        assert device == 0
        return torch.mps.set_rng_state(state)

    def get_rng_state(self, device: int):
        assert device == 0
        return torch.mps.get_rng_state()

    @contextlib.contextmanager
    def fork_rng(self, rng_devices: list):
        # TODO: change to official implementation
        assert len(rng_devices) == 1 and rng_devices[0] == 0
        cpu_rng_state = torch.get_rng_state()
        device_rng_state = torch.mps.get_rng_state()
        try:
            yield
        finally:
            torch.set_rng_state(cpu_rng_state)
            torch.mps.set_rng_state(device_rng_state)

    def autocast(self, **kwargs):
        # TODO: change to official implementation
        # running in compatibility mode
        return torch.cuda.amp.autocast(**kwargs)
5 changes: 5 additions & 0 deletions mlora/model/__init__.py
@@ -0,0 +1,5 @@
from .loader import load_model

__all__ = [
    "load_model",
]
8 changes: 6 additions & 2 deletions mlora/model/llm/model_llama.py
@@ -9,6 +9,12 @@
from mlora.model.checkpoint import CheckpointRecomputeFunction
from mlora.model.modules import AdapterModel, Decoder, Embedding, OutputLayer, RMSNorm
from mlora.profiler import nvtx_wrapper, set_backward_tracepoint
from mlora.utils import is_package_available

if is_package_available("bitsandbytes"):
    from transformers import BitsAndBytesConfig
else:
    from mlora.utils import BitsAndBytesConfig

from .model_llm import LLMModel

@@ -218,8 +224,6 @@ def create_device_map() -> str | Dict[str, str]:
load_4bit = precision in ["nf4", "fp4"]
load_8bit = precision == "int8"

from transformers import BitsAndBytesConfig

additional_load_args["torch_dtype"] = torch.float32
additional_load_args["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=load_4bit,
File renamed without changes.
11 changes: 8 additions & 3 deletions mlora/model/modules/linear.py
@@ -1,11 +1,16 @@
from typing import Callable, List, MutableMapping, Optional, Tuple

import bitsandbytes
import torch
import torch.nn.functional as F

from mlora.model.args import ModelData
from mlora.profiler import nvtx_range, set_backward_tracepoint
from mlora.utils import is_package_available

if is_package_available("bitsandbytes"):
    from bitsandbytes.nn import Linear4bit, Linear8bitLt
else:
    from mlora.utils import Linear8bitLt, Linear4bit

from .adapter import Adapter
from .dora import DoRA
@@ -20,8 +25,8 @@ def __init__(self, weight: torch.nn.Module):
        super().__init__()

        if not isinstance(weight, torch.nn.Linear):
            assert isinstance(weight, bitsandbytes.nn.Linear8bitLt) or isinstance(
                weight, bitsandbytes.nn.Linear4bit
            assert isinstance(weight, Linear8bitLt) or isinstance(
                weight, Linear4bit
            ), f"error type - {type(weight)}."
        else:
            weight.requires_grad_(False)
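Both model_llama.py and linear.py now guard their bitsandbytes imports behind is_package_available and fall back to stand-ins exported from mlora.utils; neither helper's body appears in the hunks above. A conventional implementation would look roughly like the following — the importlib probe and the placeholder classes are assumptions, shown only to make the fallback path concrete.

# Assumed shape of the mlora.utils helpers referenced above (not shown in this diff).
import importlib.util


def is_package_available(name: str) -> bool:
    # True when the package can be imported in the current environment.
    return importlib.util.find_spec(name) is not None


class Linear8bitLt:  # placeholder so isinstance() checks fail cleanly
    pass


class Linear4bit:  # placeholder counterpart for 4-bit layers
    pass


class BitsAndBytesConfig:  # placeholder standing in for transformers.BitsAndBytesConfig
    def __init__(self, **kwargs) -> None:
        self.__dict__.update(kwargs)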
