diff --git a/presets/workspace/inference/vllm/inference_api.py b/presets/workspace/inference/vllm/inference_api.py
index f6d32b576..3425b3945 100644
--- a/presets/workspace/inference/vllm/inference_api.py
+++ b/presets/workspace/inference/vllm/inference_api.py
@@ -3,12 +3,13 @@
 import logging
 import gc
 import os
-from typing import Callable
+from typing import Callable, Optional, List
 
 import uvloop
 import torch
 from vllm.utils import FlexibleArgumentParser
 import vllm.entrypoints.openai.api_server as api_server
+from vllm.entrypoints.openai.serving_engine import LoRAModulePath
 from vllm.engine.llm_engine import (LLMEngine, EngineArgs, EngineConfig)
 from vllm.executor.executor_base import ExecutorBase
@@ -26,12 +27,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
 
     port = 5000 + local_rank # Adjust port based on local rank
     server_default_args = {
-        "disable-frontend-multiprocessing": False,
-        "port": port
+        "disable_frontend_multiprocessing": False,
+        "port": port,
     }
     parser.set_defaults(**server_default_args)
 
-    # See https://docs.vllm.ai/en/latest/models/engine_args.html for more args
+    # See https://docs.vllm.ai/en/stable/models/engine_args.html for more args
     engine_default_args = {
         "model": "/workspace/vllm/weights",
         "cpu_offload_gb": 0,
@@ -42,9 +43,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     }
     parser.set_defaults(**engine_default_args)
 
+    # KAITO only args
+    # They should start with "kaito-" prefix to avoid conflict with vllm args
+    parser.add_argument("--kaito-adapters-dir", type=str, default="/mnt/adapter", help="Directory where adapters are stored in KAITO preset.")
+
     return parser
 
-def find_max_available_seq_len(engine_config: EngineConfig) -> int:
+def load_lora_adapters(adapters_dir: str) -> List[LoRAModulePath]:
+    lora_list: List[LoRAModulePath] = []
+
+    logger.info(f"Loading LoRA adapters from {adapters_dir}")
+    if not os.path.exists(adapters_dir):
+        return lora_list
+
+    for adapter in os.listdir(adapters_dir):
+        adapter_path = os.path.join(adapters_dir, adapter)
+        if os.path.isdir(adapter_path):
+            lora_list.append(LoRAModulePath(adapter, adapter_path))
+
+    return lora_list
+
+def find_max_available_seq_len(engine_config: EngineConfig, max_probe_steps: int) -> int:
     """
     Load model and run profiler to find max available seq len.
""" @@ -63,13 +82,6 @@ def find_max_available_seq_len(engine_config: EngineConfig) -> int: observability_config=engine_config.observability_config, ) - max_probe_steps = 6 - if os.getenv("MAX_PROBE_STEPS") is not None: - try: - max_probe_steps = int(os.getenv("MAX_PROBE_STEPS")) - except ValueError: - raise ValueError("MAX_PROBE_STEPS must be an integer.") - model_max_blocks = int(engine_config.model_config.max_model_len / engine_config.cache_config.block_size) res = binary_search_with_limited_steps(model_max_blocks, max_probe_steps, lambda x: is_context_length_safe(executor, x)) @@ -131,23 +143,34 @@ def is_context_length_safe(executor: ExecutorBase, num_gpu_blocks: int) -> bool: parser = make_arg_parser(parser) args = parser.parse_args() + # set LoRA adapters + if args.lora_modules is None: + args.lora_modules = load_lora_adapters(args.kaito_adapters_dir) + if args.max_model_len is None: + max_probe_steps = 6 + if os.getenv("MAX_PROBE_STEPS") is not None: + try: + max_probe_steps = int(os.getenv("MAX_PROBE_STEPS")) + except ValueError: + raise ValueError("MAX_PROBE_STEPS must be an integer.") + engine_args = EngineArgs.from_cli_args(args) # read the model config from hf weights path. # vllm will perform different parser for different model architectures # and read it into a unified EngineConfig. engine_config = engine_args.create_engine_config() - logger.info("Try run profiler to find max available seq len") - available_seq_len = find_max_available_seq_len(engine_config) + max_model_len = engine_config.model_config.max_model_len + available_seq_len = max_model_len + if max_probe_steps > 0: + logger.info("Try run profiler to find max available seq len") + available_seq_len = find_max_available_seq_len(engine_config, max_probe_steps) # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/worker/worker.py#L262 if available_seq_len <= 0: raise ValueError("No available memory for the cache blocks. 
" "Try increasing `gpu_memory_utilization` when " "initializing the engine.") - max_model_len = engine_config.model_config.max_model_len - if available_seq_len > max_model_len: - available_seq_len = max_model_len if available_seq_len != max_model_len: logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}") diff --git a/presets/workspace/inference/vllm/tests/test_vllm_inference_api.py b/presets/workspace/inference/vllm/tests/test_vllm_inference_api.py index 811d93306..20d5eeeb6 100644 --- a/presets/workspace/inference/vllm/tests/test_vllm_inference_api.py +++ b/presets/workspace/inference/vllm/tests/test_vllm_inference_api.py @@ -14,14 +14,18 @@ sys.path.append(parent_dir) from inference_api import binary_search_with_limited_steps +from huggingface_hub import snapshot_download +import shutil TEST_MODEL = "facebook/opt-125m" +TEST_ADAPTER_NAME1 = "mylora1" +TEST_ADAPTER_NAME2 = "mylora2" CHAT_TEMPLATE = ("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}" "{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}" "{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}") -@pytest.fixture -def setup_server(request): +@pytest.fixture(scope="session", autouse=True) +def setup_server(request, tmp_path_factory, autouse=True): if os.getenv("DEVICE") == "cpu": pytest.skip("Skipping test on cpu device") print("\n>>> Doing setup") @@ -29,15 +33,23 @@ def setup_server(request): global TEST_PORT TEST_PORT = port + tmp_file_dir = tmp_path_factory.mktemp("adapter") + print(f"Downloading adapter image to {tmp_file_dir}") + snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME1)) + snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME2)) + args = [ "python3", os.path.join(parent_dir, "inference_api.py"), "--model", TEST_MODEL, "--chat-template", CHAT_TEMPLATE, - "--port", str(TEST_PORT) + "--port", str(TEST_PORT), + "--kaito-adapters-dir", tmp_file_dir, ] print(f">>> Starting server on port {TEST_PORT}") - process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + env = os.environ.copy() + env["MAX_PROBE_STEPS"] = "0" + process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) def fin(): process.terminate() @@ -47,6 +59,8 @@ def fin(): stdout = process.stdout.read().decode() print(f">>> Server stdout: {stdout}") print ("\n>>> Doing teardown") + print(f"Removing adapter image from {tmp_file_dir}") + shutil.rmtree(tmp_file_dir) if not is_port_open("localhost", TEST_PORT): fin() @@ -115,6 +129,17 @@ def test_chat_completions_api(setup_server): assert "content" in choice["message"], "Each message should contain a 'content' key" assert len(choice["message"]["content"]) > 0, "The completion text should not be empty" +def test_model_list(setup_server): + response = requests.get(f"http://127.0.0.1:{TEST_PORT}/v1/models") + data = response.json() + + assert "data" in data, f"The response should contain a 'data' key, but got {data}" + assert len(data["data"]) == 3, f"The response should contain three models, but got {data['data']}" + assert data["data"][0]["id"] == TEST_MODEL, f"The first model should be the test model, but got {data['data'][0]['id']}" + assert data["data"][1]["id"] == TEST_ADAPTER_NAME2, f"The second model should be the test adapter, but got {data['data'][1]['id']}" + assert data["data"][1]["parent"] 
== TEST_MODEL, f"The second model should have the test model as parent, but got {data['data'][1]['parent']}" + assert data["data"][2]["id"] == TEST_ADAPTER_NAME1, f"The third model should be the test adapter, but got {data['data'][2]['id']}" + assert data["data"][2]["parent"] == TEST_MODEL, f"The third model should have the test model as parent, but got {data['data'][2]['parent']}" def test_binary_search_with_limited_steps():
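
With this change, every subdirectory under --kaito-adapters-dir is registered as a vLLM LoRA module and exposed as an additional model in the OpenAI-compatible API, with the base model as its parent (this is what test_model_list above asserts). The snippet below is a minimal client-side sketch of that behavior, not part of the diff: it assumes a server running locally on the default port 5000 (local rank 0) with an adapter directory named "mylora1"; the adapter name, prompt, and sampling parameters are illustrative placeholders.

    import requests

    BASE_URL = "http://127.0.0.1:5000"  # assumed: default port from make_arg_parser at local rank 0

    # List served models: the base model plus one entry per adapter directory.
    models = requests.get(f"{BASE_URL}/v1/models").json()
    print([m["id"] for m in models["data"]])

    # Route a completion through a specific LoRA adapter by passing its
    # directory name as the model id (hypothetical adapter name "mylora1").
    payload = {
        "model": "mylora1",
        "prompt": "The movie was",
        "max_tokens": 16,
    }
    resp = requests.post(f"{BASE_URL}/v1/completions", json=payload)
    print(resp.json()["choices"][0]["text"])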