feat: support config file for vllm runtime #780

Merged: 2 commits, Dec 17, 2024

19 changes: 19 additions & 0 deletions charts/kaito/workspace/templates/inference-params.yaml
@@ -0,0 +1,19 @@
apiVersion: v1
kind: ConfigMap
metadata:
name: inference-params-template
namespace: {{ .Release.Namespace }}
data:
inference_config.yaml: |
# Maximum number of steps to find the max available seq len fitting in the GPU memory.
max_probe_steps: 6

vllm:
cpu-offload-gb: 0
gpu-memory-utilization: 0.95
swap-space: 4

# max-seq-len-to-capture: 8192
# num-scheduler-steps: 1
# enable-chunked-prefill: false
# see https://docs.vllm.ai/en/stable/models/engine_args.html for more options.
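
For context, a minimal sketch (not part of this diff) of how the vllm section above is consumed, assuming the workspace mounts this ConfigMap and passes the file to the runtime via --kaito-config-file: the file is read with yaml.safe_load and each key/value pair under vllm is appended to the vLLM command line as a --key value flag, so the commented-out options only need to be uncommented to be forwarded.

import yaml

# Illustrative only; mirrors the flag-building loop in KAITOArgumentParser.parse_args below.
config_text = """
max_probe_steps: 6
vllm:
  cpu-offload-gb: 0
  gpu-memory-utilization: 0.95
  swap-space: 4
"""
config_data = yaml.safe_load(config_text)

extra_flags = []
for key, value in config_data.get("vllm", {}).items():
    extra_flags += [f"--{key}", str(value)]

print(extra_flags)
# ['--cpu-offload-gb', '0', '--gpu-memory-utilization', '0.95', '--swap-space', '4']
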
187 changes: 126 additions & 61 deletions presets/workspace/inference/vllm/inference_api.py
@@ -3,7 +3,10 @@
import logging
import gc
import os
from typing import Callable, Optional, List
import argparse
from typing import Callable, Optional, List, Any
import yaml
from dataclasses import dataclass

import uvloop
import torch
@@ -21,33 +24,88 @@
format='%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
datefmt='%m-%d %H:%M:%S')

def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
local_rank = int(os.environ.get("LOCAL_RANK",
0)) # Default to 0 if not set
port = 5000 + local_rank # Adjust port based on local rank

server_default_args = {
"disable_frontend_multiprocessing": False,
"port": port,
}
parser.set_defaults(**server_default_args)

# See https://docs.vllm.ai/en/stable/models/engine_args.html for more args
engine_default_args = {
"model": "/workspace/vllm/weights",
"cpu_offload_gb": 0,
"gpu_memory_utilization": 0.95,
"swap_space": 4,
"disable_log_stats": False,
"uvicorn_log_level": "error"
}
parser.set_defaults(**engine_default_args)

# KAITO only args
# They should start with "kaito-" prefix to avoid conflict with vllm args
parser.add_argument("--kaito-adapters-dir", type=str, default="/mnt/adapter", help="Directory where adapters are stored in KAITO preset.")

return parser
class KAITOArgumentParser(argparse.ArgumentParser):
vllm_parser = FlexibleArgumentParser(description="vLLM serving server")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Initialize vllm parser
self.vllm_parser = api_server.make_arg_parser(self.vllm_parser)
self._reset_vllm_defaults()

# KAITO-only args
# They should start with the "kaito-" prefix to avoid conflicts with vLLM args
self.add_argument("--kaito-adapters-dir", type=str, default="/mnt/adapter", help="Directory where adapters are stored in KAITO preset.")
self.add_argument("--kaito-config-file", type=str, default="", help="Additional args for KAITO preset.")
self.add_argument("--kaito-max-probe-steps", type=int, default=6, help="Maximum number of steps to find the max available seq len fitting in the GPU memory.")

def _reset_vllm_defaults(self):
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # Default to 0 if not set
port = 5000 + local_rank # Adjust port based on local rank

server_default_args = {
"disable_frontend_multiprocessing": False,
"port": port,
}
self.vllm_parser.set_defaults(**server_default_args)

# See https://docs.vllm.ai/en/stable/models/engine_args.html for more args
engine_default_args = {
"model": "/workspace/vllm/weights",
"cpu_offload_gb": 0,
"gpu_memory_utilization": 0.95,
"swap_space": 4,
"disable_log_stats": False,
"uvicorn_log_level": "error"
}
self.vllm_parser.set_defaults(**engine_default_args)

def parse_args(self, *args, **kwargs):
args = super().parse_known_args(*args, **kwargs)
kaito_args = args[0]
runtime_args = args[1] # Remaining args

# Load KAITO config
if kaito_args.kaito_config_file:
file_config = KaitoConfig.from_yaml(kaito_args.kaito_config_file)
if kaito_args.kaito_max_probe_steps is None:
kaito_args.kaito_max_probe_steps = file_config.max_probe_steps

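# Config-file values are appended after the remaining CLI args; because argparse keeps
# the last value of a repeated option, values from --kaito-config-file take precedence
# over the same flags passed on the command line (see the --max-model-len override in
# the tests below).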
for key, value in file_config.vllm.items():
runtime_args.append(f"--{key}")
runtime_args.append(str(value))

vllm_args = self.vllm_parser.parse_args(runtime_args, **kwargs)
# Merge KAITO and vLLM args
return argparse.Namespace(**vars(kaito_args), **vars(vllm_args))

def print_help(self, file=None):
super().print_help(file)
print("\noriginal vLLM server arguments:\n")
self.vllm_parser.print_help(file)

@dataclass
class KaitoConfig:
# Extra arguments for the vLLM serving server; these are forwarded to the server as CLI flags.
# This should be in key-value format.
vllm: dict[str, Any]

# Maximum number of steps to find the max available seq len fitting in the GPU memory.
max_probe_steps: int

@staticmethod
def from_yaml(yaml_file: str) -> 'KaitoConfig':
with open(yaml_file, 'r') as file:
config_data = yaml.safe_load(file)
return KaitoConfig(
vllm=config_data.get('vllm', {}),
max_probe_steps=config_data.get('max_probe_steps', 6)
)

def to_yaml(self) -> str:
return yaml.dump(self.__dict__)

def load_lora_adapters(adapters_dir: str) -> Optional[LoRAModulePath]:
lora_list: List[LoRAModulePath] = []
@@ -130,53 +188,60 @@ def is_context_length_safe(executor: ExecutorBase, num_gpu_blocks: int) -> bool:
executor.scheduler_config.max_num_batched_tokens = context_length

try:
logger.info(f"Trying to determine available gpu blocks for context length {context_length}")
# see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L477
available_gpu_blocks, _ = executor.determine_num_available_blocks()
except torch.OutOfMemoryError as e:
return False

return available_gpu_blocks >= num_gpu_blocks

def try_set_max_available_seq_len(args: argparse.Namespace):
if args.max_model_len is not None:
logger.info(f"max_model_len is set to {args.max_model_len}, skip probing.")
return

max_probe_steps = 0
if args.kaito_max_probe_steps is not None:
try:
max_probe_steps = int(args.kaito_max_probe_steps)
except ValueError:
raise ValueError("kaito_max_probe_steps must be an integer.")

if max_probe_steps <= 0:
return

engine_args = EngineArgs.from_cli_args(args)
# Read the model config from the HF weights path.
# vLLM applies a different parser for each model architecture
# and reads the result into a unified EngineConfig.
engine_config = engine_args.create_engine_config()

max_model_len = engine_config.model_config.max_model_len
available_seq_len = max_model_len
logger.info("Running profiler to find the max available seq len")
available_seq_len = find_max_available_seq_len(engine_config, max_probe_steps)
# see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/worker/worker.py#L262
if available_seq_len <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")

if available_seq_len != max_model_len:
logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}")
args.max_model_len = available_seq_len
else:
logger.info(f"Using model default max_model_len {max_model_len}")
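
# Illustrative sketch (not part of this diff) of the limited-step binary search behind
# max_probe_steps: halve the candidate range at most max_probe_steps times and keep the
# largest value for which a hypothetical fits() check succeeds. The real probe is
# binary_search_with_limited_steps/find_max_available_seq_len, whose fit check goes
# through determine_num_available_blocks as shown in is_context_length_safe above.
def probe_largest_fitting(low: int, high: int, fits, max_probe_steps: int) -> int:
    best = 0
    for _ in range(max_probe_steps):
        if low > high:
            break
        mid = (low + high) // 2
        if fits(mid):
            best, low = mid, mid + 1
        else:
            high = mid - 1
    return best

# Example: probe_largest_fitting(1, 131072, lambda n: n <= 100000, 6) returns 98304,
# within 131072 / 2**6 = 2048 tokens of the true limit and never above it.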

if __name__ == "__main__":
parser = FlexibleArgumentParser(description='vLLM serving server')
parser = api_server.make_arg_parser(parser)
parser = make_arg_parser(parser)
parser = KAITOArgumentParser(description='KAITO wrapper of vLLM serving server')
args = parser.parse_args()

# set LoRA adapters
if args.lora_modules is None:
args.lora_modules = load_lora_adapters(args.kaito_adapters_dir)

if args.max_model_len is None:
max_probe_steps = 6
if os.getenv("MAX_PROBE_STEPS") is not None:
try:
max_probe_steps = int(os.getenv("MAX_PROBE_STEPS"))
except ValueError:
raise ValueError("MAX_PROBE_STEPS must be an integer.")

engine_args = EngineArgs.from_cli_args(args)
# read the model config from hf weights path.
# vllm will perform different parser for different model architectures
# and read it into a unified EngineConfig.
engine_config = engine_args.create_engine_config()

max_model_len = engine_config.model_config.max_model_len
available_seq_len = max_model_len
if max_probe_steps > 0:
logger.info("Try run profiler to find max available seq len")
available_seq_len = find_max_available_seq_len(engine_config, max_probe_steps)
# see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/worker/worker.py#L262
if available_seq_len <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")

if available_seq_len != max_model_len:
logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}")
args.max_model_len = available_seq_len
else:
logger.info(f"Using model default max_model_len {max_model_len}")
try_set_max_available_seq_len(args)

# Run the serving server
logger.info(f"Starting server on port {args.port}")
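
To tie the pieces together, a minimal usage sketch (not part of this diff): it assumes vLLM is installed, inference_api.py is importable, and the /tmp path is purely illustrative. The wrapper parser returns one merged namespace carrying both the kaito_* options and the forwarded vLLM options, with config-file values taking precedence over command-line flags.

import yaml
from inference_api import KAITOArgumentParser

# Write a config shaped like the ConfigMap above (illustrative values).
with open("/tmp/inference_config.yaml", "w") as f:
    yaml.safe_dump({"max_probe_steps": 6,
                    "vllm": {"gpu-memory-utilization": 0.9}}, f)

parser = KAITOArgumentParser(description="KAITO wrapper of vLLM serving server")
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--gpu-memory-utilization", "0.95",                 # overridden by the config file
    "--kaito-config-file", "/tmp/inference_config.yaml",
])

print(args.gpu_memory_utilization)   # 0.9, the config-file value wins
print(args.kaito_max_probe_steps)    # 6
print(args.kaito_adapters_dir)       # /mnt/adapter (default)
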
@@ -13,13 +13,15 @@
# Add the parent directory to sys.path
sys.path.append(parent_dir)

from inference_api import binary_search_with_limited_steps
from inference_api import binary_search_with_limited_steps, KaitoConfig
from huggingface_hub import snapshot_download
import shutil

TEST_MODEL = "facebook/opt-125m"
TEST_ADAPTER_NAME1 = "mylora1"
TEST_ADAPTER_NAME2 = "mylora2"
TEST_MODEL_NAME = "mymodel"
TEST_MODEL_LEN = 1024
CHAT_TEMPLATE = ("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}"
"{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
"{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}")
@@ -33,18 +35,33 @@ def setup_server(request, tmp_path_factory, autouse=True):
global TEST_PORT
TEST_PORT = port

# prepare testing adapter images
tmp_file_dir = tmp_path_factory.mktemp("adapter")
print(f"Downloading adapter image to {tmp_file_dir}")
snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME1))
snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME2))

# prepare testing config file
config_file = tmp_file_dir / "config.yaml"
kaito_config = KaitoConfig(
vllm={
"max-model-len": TEST_MODEL_LEN,
"served-model-name": TEST_MODEL_NAME
},
max_probe_steps=0,
)
with open(config_file, "w") as f:
f.write(kaito_config.to_yaml())

args = [
"python3",
os.path.join(parent_dir, "inference_api.py"),
"--model", TEST_MODEL,
"--chat-template", CHAT_TEMPLATE,
"--max-model-len", "2048", # expected to be overridden by config file
"--port", str(TEST_PORT),
"--kaito-adapters-dir", tmp_file_dir,
"--kaito-config-file", config_file,
]
print(f">>> Starting server on port {TEST_PORT}")
env = os.environ.copy()
@@ -90,7 +107,7 @@ def find_available_port(start_port=5000, end_port=8000):

def test_completions_api(setup_server):
request_data = {
"model": TEST_MODEL,
"model": TEST_MODEL_NAME,
"prompt": "Say this is a test",
"max_tokens": 7,
"temperature": 0.5,
@@ -108,7 +125,7 @@ def test_chat_completions_api(setup_server):

def test_chat_completions_api(setup_server):
request_data = {
"model": TEST_MODEL,
"model": TEST_MODEL_NAME,
"messages": [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there! How can I help you today?"}
@@ -135,11 +152,12 @@ def test_model_list(setup_server):

assert "data" in data, f"The response should contain a 'data' key, but got {data}"
assert len(data["data"]) == 3, f"The response should contain three models, but got {data['data']}"
assert data["data"][0]["id"] == TEST_MODEL, f"The first model should be the test model, but got {data['data'][0]['id']}"
assert data["data"][0]["id"] == TEST_MODEL_NAME, f"The first model should be the test model, but got {data['data'][0]['id']}"
assert data["data"][0]["max_model_len"] == TEST_MODEL_LEN, f"The first model should have the test model length, but got {data['data'][0]['max_model_len']}"
assert data["data"][1]["id"] == TEST_ADAPTER_NAME2, f"The second model should be the test adapter, but got {data['data'][1]['id']}"
assert data["data"][1]["parent"] == TEST_MODEL, f"The second model should have the test model as parent, but got {data['data'][1]['parent']}"
assert data["data"][1]["parent"] == TEST_MODEL_NAME, f"The second model should have the test model as parent, but got {data['data'][1]['parent']}"
assert data["data"][2]["id"] == TEST_ADAPTER_NAME1, f"The third model should be the test adapter, but got {data['data'][2]['id']}"
assert data["data"][2]["parent"] == TEST_MODEL, f"The third model should have the test model as parent, but got {data['data'][2]['parent']}"
assert data["data"][2]["parent"] == TEST_MODEL_NAME, f"The third model should have the test model as parent, but got {data['data'][2]['parent']}"

def test_binary_search_with_limited_steps():
