feat: support LoRA adapters for vllm runtime (#774)
**Reason for Change**:

Support LoRA adapters for the vLLM runtime: adapters found under the KAITO adapters directory (--kaito-adapters-dir, default /mnt/adapter) are registered with vLLM as named LoRA modules at server startup.

Signed-off-by: jerryzhuang <[email protected]>
zhuangqh authored Dec 11, 2024
1 parent b099c66 commit 83f25cd
Showing 2 changed files with 69 additions and 21 deletions.
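Not part of the diff: a minimal launch sketch, mirroring the test fixture further down, of how the new KAITO-specific flag is meant to be used. Each immediate subdirectory of the adapters directory becomes one LoRA module named after that subdirectory. The model name, port, and adapter paths below are illustrative.

# Illustrative launch only; /mnt/adapter is the committed default, everything else is an example.
import subprocess

args = [
    "python3", "inference_api.py",
    "--model", "facebook/opt-125m",            # base model
    "--port", "5000",
    "--kaito-adapters-dir", "/mnt/adapter",    # e.g. /mnt/adapter/mylora1, /mnt/adapter/mylora2
]
process = subprocess.Popen(args)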
57 changes: 40 additions & 17 deletions presets/workspace/inference/vllm/inference_api.py
@@ -3,12 +3,13 @@
 import logging
 import gc
 import os
-from typing import Callable
+from typing import Callable, Optional, List

 import uvloop
 import torch
 from vllm.utils import FlexibleArgumentParser
 import vllm.entrypoints.openai.api_server as api_server
+from vllm.entrypoints.openai.serving_engine import LoRAModulePath
 from vllm.engine.llm_engine import (LLMEngine, EngineArgs, EngineConfig)
 from vllm.executor.executor_base import ExecutorBase

@@ -26,12 +27,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     port = 5000 + local_rank  # Adjust port based on local rank

     server_default_args = {
-        "disable-frontend-multiprocessing": False,
-        "port": port
+        "disable_frontend_multiprocessing": False,
+        "port": port,
     }
     parser.set_defaults(**server_default_args)

-    # See https://docs.vllm.ai/en/latest/models/engine_args.html for more args
+    # See https://docs.vllm.ai/en/stable/models/engine_args.html for more args
     engine_default_args = {
         "model": "/workspace/vllm/weights",
         "cpu_offload_gb": 0,
@@ -42,9 +43,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     }
     parser.set_defaults(**engine_default_args)

+    # KAITO only args
+    # They should start with "kaito-" prefix to avoid conflict with vllm args
+    parser.add_argument("--kaito-adapters-dir", type=str, default="/mnt/adapter", help="Directory where adapters are stored in KAITO preset.")
+
     return parser

-def find_max_available_seq_len(engine_config: EngineConfig) -> int:
+def load_lora_adapters(adapters_dir: str) -> Optional[LoRAModulePath]:
+    lora_list: List[LoRAModulePath] = []
+
+    logger.info(f"Loading LoRA adapters from {adapters_dir}")
+    if not os.path.exists(adapters_dir):
+        return lora_list
+
+    for adapter in os.listdir(adapters_dir):
+        adapter_path = os.path.join(adapters_dir, adapter)
+        if os.path.isdir(adapter_path):
+            lora_list.append(LoRAModulePath(adapter, adapter_path))
+
+    return lora_list
+
+def find_max_available_seq_len(engine_config: EngineConfig, max_probe_steps: int) -> int:
     """
     Load model and run profiler to find max available seq len.
     """
@@ -63,13 +82,6 @@ def find_max_available_seq_len(engine_config: EngineConfig) -> int:
         observability_config=engine_config.observability_config,
     )

-    max_probe_steps = 6
-    if os.getenv("MAX_PROBE_STEPS") is not None:
-        try:
-            max_probe_steps = int(os.getenv("MAX_PROBE_STEPS"))
-        except ValueError:
-            raise ValueError("MAX_PROBE_STEPS must be an integer.")
-
     model_max_blocks = int(engine_config.model_config.max_model_len / engine_config.cache_config.block_size)
     res = binary_search_with_limited_steps(model_max_blocks, max_probe_steps, lambda x: is_context_length_safe(executor, x))

@@ -131,23 +143,34 @@ def is_context_length_safe(executor: ExecutorBase, num_gpu_blocks: int) -> bool:
     parser = make_arg_parser(parser)
     args = parser.parse_args()

+    # set LoRA adapters
+    if args.lora_modules is None:
+        args.lora_modules = load_lora_adapters(args.kaito_adapters_dir)
+
     if args.max_model_len is None:
+        max_probe_steps = 6
+        if os.getenv("MAX_PROBE_STEPS") is not None:
+            try:
+                max_probe_steps = int(os.getenv("MAX_PROBE_STEPS"))
+            except ValueError:
+                raise ValueError("MAX_PROBE_STEPS must be an integer.")
+
         engine_args = EngineArgs.from_cli_args(args)
         # read the model config from hf weights path.
         # vllm will perform different parser for different model architectures
         # and read it into a unified EngineConfig.
         engine_config = engine_args.create_engine_config()

-        logger.info("Try run profiler to find max available seq len")
-        available_seq_len = find_max_available_seq_len(engine_config)
+        max_model_len = engine_config.model_config.max_model_len
+        available_seq_len = max_model_len
+        if max_probe_steps > 0:
+            logger.info("Try run profiler to find max available seq len")
+            available_seq_len = find_max_available_seq_len(engine_config, max_probe_steps)
         # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/worker/worker.py#L262
         if available_seq_len <= 0:
             raise ValueError("No available memory for the cache blocks. "
                              "Try increasing `gpu_memory_utilization` when "
                              "initializing the engine.")
-        max_model_len = engine_config.model_config.max_model_len
         if available_seq_len > max_model_len:
             available_seq_len = max_model_len

         if available_seq_len != max_model_len:
             logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}")
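The probing path above relies on binary_search_with_limited_steps, whose implementation is outside this diff. Below is a rough, hedged sketch of what such a helper does (an assumption, not the committed code): it searches for the largest block count that the profiler reports as safe, but stops after max_probe_steps probes and keeps the best value proven safe so far; with MAX_PROBE_STEPS=0, as the updated test sets, no probing runs and the model's configured max length is used directly.

# Hedged sketch only: the real binary_search_with_limited_steps is not shown in this diff.
from typing import Callable

def binary_search_with_limited_steps_sketch(upper: int, max_steps: int,
                                            is_safe: Callable[[int], bool]) -> int:
    """Largest value in [0, upper] passing is_safe, calling is_safe at most max_steps times."""
    lo, hi = 0, upper
    for _ in range(max_steps):
        if lo >= hi:
            break
        mid = (lo + hi + 1) // 2
        if is_safe(mid):      # e.g. lambda x: is_context_length_safe(executor, x)
            lo = mid          # mid fits in GPU memory; try larger
        else:
            hi = mid - 1      # mid does not fit; try smaller
    return lo                 # conservative: best value proven safe within the step budget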
33 changes: 29 additions & 4 deletions presets/workspace/inference/vllm/tests/test_vllm_inference_api.py
@@ -14,30 +14,42 @@
 sys.path.append(parent_dir)

 from inference_api import binary_search_with_limited_steps
+from huggingface_hub import snapshot_download
+import shutil

 TEST_MODEL = "facebook/opt-125m"
+TEST_ADAPTER_NAME1 = "mylora1"
+TEST_ADAPTER_NAME2 = "mylora2"
 CHAT_TEMPLATE = ("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}"
                  "{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
                  "{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}")

-@pytest.fixture
-def setup_server(request):
+@pytest.fixture(scope="session", autouse=True)
+def setup_server(request, tmp_path_factory, autouse=True):
     if os.getenv("DEVICE") == "cpu":
         pytest.skip("Skipping test on cpu device")
     print("\n>>> Doing setup")
     port = find_available_port()
     global TEST_PORT
     TEST_PORT = port

+    tmp_file_dir = tmp_path_factory.mktemp("adapter")
+    print(f"Downloading adapter image to {tmp_file_dir}")
+    snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME1))
+    snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME2))
+
     args = [
         "python3",
         os.path.join(parent_dir, "inference_api.py"),
         "--model", TEST_MODEL,
         "--chat-template", CHAT_TEMPLATE,
-        "--port", str(TEST_PORT)
+        "--port", str(TEST_PORT),
+        "--kaito-adapters-dir", tmp_file_dir,
     ]
     print(f">>> Starting server on port {TEST_PORT}")
-    process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    env = os.environ.copy()
+    env["MAX_PROBE_STEPS"] = "0"
+    process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

     def fin():
         process.terminate()
@@ -47,6 +59,8 @@ def fin():
         stdout = process.stdout.read().decode()
         print(f">>> Server stdout: {stdout}")
         print ("\n>>> Doing teardown")
+        print(f"Removing adapter image from {tmp_file_dir}")
+        shutil.rmtree(tmp_file_dir)

     if not is_port_open("localhost", TEST_PORT):
         fin()
@@ -115,6 +129,17 @@ def test_chat_completions_api(setup_server):
         assert "content" in choice["message"], "Each message should contain a 'content' key"
         assert len(choice["message"]["content"]) > 0, "The completion text should not be empty"

+def test_model_list(setup_server):
+    response = requests.get(f"http://127.0.0.1:{TEST_PORT}/v1/models")
+    data = response.json()
+
+    assert "data" in data, f"The response should contain a 'data' key, but got {data}"
+    assert len(data["data"]) == 3, f"The response should contain three models, but got {data['data']}"
+    assert data["data"][0]["id"] == TEST_MODEL, f"The first model should be the test model, but got {data['data'][0]['id']}"
+    assert data["data"][1]["id"] == TEST_ADAPTER_NAME2, f"The second model should be the test adapter, but got {data['data'][1]['id']}"
+    assert data["data"][1]["parent"] == TEST_MODEL, f"The second model should have the test model as parent, but got {data['data'][1]['parent']}"
+    assert data["data"][2]["id"] == TEST_ADAPTER_NAME1, f"The third model should be the test adapter, but got {data['data'][2]['id']}"
+    assert data["data"][2]["parent"] == TEST_MODEL, f"The third model should have the test model as parent, but got {data['data'][2]['parent']}"
+
 def test_binary_search_with_limited_steps():

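For completeness, a hedged client-side sketch (not part of the committed tests) of how the registered adapters are expected to be reached through the OpenAI-compatible API: the new test_model_list shows each adapter listed in /v1/models next to the base model, and vLLM's multi-LoRA serving lets a request select an adapter by passing its name as the model. Host, port, and adapter name below are illustrative.

# Illustrative only; assumes the server from this commit is running locally on port 5000.
import requests

BASE = "http://127.0.0.1:5000"

# The base model and every discovered adapter appear in the model list.
models = requests.get(f"{BASE}/v1/models").json()
print([m["id"] for m in models["data"]])   # e.g. ['facebook/opt-125m', 'mylora2', 'mylora1']

# Passing an adapter name as "model" routes the request through that LoRA module.
resp = requests.post(f"{BASE}/v1/chat/completions", json={
    "model": "mylora1",   # adapter name == subdirectory name under --kaito-adapters-dir
    "messages": [{"role": "user", "content": "This movie was"}],
    "max_tokens": 16,
})
print(resp.json()["choices"][0]["message"]["content"])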
