diff --git a/Makefile b/Makefile
index 6b87ca141..f74acc7da 100644
--- a/Makefile
+++ b/Makefile
@@ -113,7 +113,8 @@ tuning-metrics-server-test:
 
 inference-api-e2e:
 	pip install -r ./presets/dependencies/requirements-test.txt
-	pytest -o log_cli=true -o log_cli_level=INFO presets/inference
+	pytest -o log_cli=true -o log_cli_level=INFO presets/inference/vllm
+	pytest -o log_cli=true -o log_cli_level=INFO presets/inference/text-generation
 
 # Ginkgo configurations
 GINKGO_FOCUS ?=
diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py
index eb875cdeb..871dfbbc9 100644
--- a/presets/inference/text-generation/inference_api.py
+++ b/presets/inference/text-generation/inference_api.py
@@ -24,7 +24,10 @@
 
 # Initialize logger
 logger = logging.getLogger(__name__)
 debug_mode = os.environ.get('DEBUG_MODE', 'false').lower() == 'true'
-logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)
+logging.basicConfig(
+    level=logging.DEBUG if debug_mode else logging.INFO,
+    format='%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
+    datefmt='%m-%d %H:%M:%S')
 
 ADAPTERS_DIR = '/mnt/adapter'
diff --git a/presets/inference/vllm/inference_api.py b/presets/inference/vllm/inference_api.py
index 10fc3e312..f6d32b576 100644
--- a/presets/inference/vllm/inference_api.py
+++ b/presets/inference/vllm/inference_api.py
@@ -3,17 +3,22 @@
 import logging
 import gc
 import os
+from typing import Callable
 
 import uvloop
 import torch
 from vllm.utils import FlexibleArgumentParser
 import vllm.entrypoints.openai.api_server as api_server
 from vllm.engine.llm_engine import (LLMEngine, EngineArgs, EngineConfig)
+from vllm.executor.executor_base import ExecutorBase
 
 # Initialize logger
 logger = logging.getLogger(__name__)
 debug_mode = os.environ.get('DEBUG_MODE', 'false').lower() == 'true'
-logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)
+logging.basicConfig(
+    level=logging.DEBUG if debug_mode else logging.INFO,
+    format='%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
+    datefmt='%m-%d %H:%M:%S')
 
 def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     local_rank = int(os.environ.get("LOCAL_RANK",
@@ -58,15 +63,67 @@ def find_max_available_seq_len(engine_config: EngineConfig) -> int:
         observability_config=engine_config.observability_config,
     )
 
-    # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L477
-    num_gpu_blocks, _ = executor.determine_num_available_blocks()
+    max_probe_steps = 6
+    if os.getenv("MAX_PROBE_STEPS") is not None:
+        try:
+            max_probe_steps = int(os.getenv("MAX_PROBE_STEPS"))
+        except ValueError:
+            raise ValueError("MAX_PROBE_STEPS must be an integer.")
+
+    model_max_blocks = int(engine_config.model_config.max_model_len / engine_config.cache_config.block_size)
+    res = binary_search_with_limited_steps(model_max_blocks, max_probe_steps, lambda x: is_context_length_safe(executor, x))
 
     # release memory
     del executor
    gc.collect()
     torch.cuda.empty_cache()
 
-    return engine_config.cache_config.block_size * num_gpu_blocks
+    return engine_config.cache_config.block_size * res
+
+def binary_search_with_limited_steps(upper: int, max_probe_steps: int, is_valid_fn: Callable[[int], bool]) -> int:
+    """
+    Finds the maximum valid value within a limited number of steps.
+
+    Parameters:
+    - upper (int): The upper bound of the search space ([0, upper]).
+    - max_probe_steps (int): Maximum number of steps to try.
+    - is_valid_fn (Callable[[int], bool]): A function that checks whether a given value is valid.
+
+    Returns: - int: The maximum valid value.
+    """
+    probe_steps = 0
+    low = 0
+    # Double the upper bound so the first probe lands exactly on the upper value,
+    # because the maximum valid value is likely to be close to the upper bound.
+    high = upper * 2
+    while low < high and probe_steps < max_probe_steps:
+        mid = (low + high + 1) // 2
+        if mid > upper:
+            break
+
+        if is_valid_fn(mid):
+            low = mid
+        else:
+            high = mid - 1
+
+        probe_steps += 1
+
+    return low
+
+def is_context_length_safe(executor: ExecutorBase, num_gpu_blocks: int) -> bool:
+    """
+    Check whether the available GPU blocks are enough for the given num_gpu_blocks.
+    """
+    context_length = executor.cache_config.block_size * num_gpu_blocks
+    executor.scheduler_config.max_num_batched_tokens = context_length
+
+    try:
+        # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L477
+        available_gpu_blocks, _ = executor.determine_num_available_blocks()
+    except torch.OutOfMemoryError:
+        return False
+
+    return available_gpu_blocks >= num_gpu_blocks
 
 if __name__ == "__main__":
     parser = FlexibleArgumentParser(description='vLLM serving server')
@@ -95,6 +152,8 @@ def find_max_available_seq_len(engine_config: EngineConfig) -> int:
     if available_seq_len != max_model_len:
         logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}")
         args.max_model_len = available_seq_len
+    else:
+        logger.info(f"Using model default max_model_len {max_model_len}")
 
     # Run the serving server
     logger.info(f"Starting server on port {args.port}")
diff --git a/presets/inference/vllm/tests/test_vllm_inference_api.py b/presets/inference/vllm/tests/test_vllm_inference_api.py
index 30ae9cc7f..811d93306 100644
--- a/presets/inference/vllm/tests/test_vllm_inference_api.py
+++ b/presets/inference/vllm/tests/test_vllm_inference_api.py
@@ -13,12 +13,14 @@
 # Add the parent directory to sys.path
 sys.path.append(parent_dir)
 
+from inference_api import binary_search_with_limited_steps
+
 TEST_MODEL = "facebook/opt-125m"
 CHAT_TEMPLATE = ("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}"
                  "{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
                  "{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}")
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture
 def setup_server(request):
     if os.getenv("DEVICE") == "cpu":
         pytest.skip("Skipping test on cpu device")
@@ -111,4 +113,35 @@ def test_chat_completions_api(setup_server):
     for choice in data["choices"]:
         assert "message" in choice, "Each choice should contain a 'message' key"
         assert "content" in choice["message"], "Each message should contain a 'content' key"
-        assert len(choice["message"]["content"]) > 0, "The completion text should not be empty"
\ No newline at end of file
+        assert len(choice["message"]["content"]) > 0, "The completion text should not be empty"
+
+
+def test_binary_search_with_limited_steps():
+
+    def is_safe_fn(x):
+        return x <= 10
+
+    # Test case 1: all values are safe
+    result = binary_search_with_limited_steps(10, 1, is_safe_fn)
+    assert result == 10, f"Expected 10, but got {result}"
+
+    result = binary_search_with_limited_steps(10, 10, is_safe_fn)
+    assert result == 10, f"Expected 10, but got {result}"
+
+    # Test case 2: partially safe, find the exact value
+    result = binary_search_with_limited_steps(20, 3, is_safe_fn)
+    assert result == 10, f"Expected 10, but got {result}"
+
+    result = binary_search_with_limited_steps(30, 6, is_safe_fn)
+    assert result == 10, f"Expected 10, but got {result}"
+
+    # Test case 3: partially safe, find an approximate value
+    result = binary_search_with_limited_steps(30, 3, is_safe_fn)
+    assert result == 7, f"Expected 7, but got {result}"
+
+    # Test case 4: all values are unsafe
+    result = binary_search_with_limited_steps(10, 1, lambda x: False)
+    assert result == 0, f"Expected 0, but got {result}"
+
+    result = binary_search_with_limited_steps(20, 100, lambda x: False)
+    assert result == 0, f"Expected 0, but got {result}"
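
Note (not part of the patch): the sketch below restates the limited-step probing pattern that the new find_max_available_seq_len relies on, with a toy predicate standing in for is_context_length_safe (the real check calls executor.determine_num_available_blocks() and treats a CUDA out-of-memory error as "unsafe"). Doubling the upper bound makes the first midpoint land exactly on upper, so the common case where the model's full context fits is settled in a single probe; if the step budget runs out first, the search returns a conservative underestimate. The helper name probe_max_blocks is illustrative only.

# Toy illustration of the limited-step binary search used to size the KV cache.
# `is_valid_fn` is a stand-in for is_context_length_safe(executor, num_gpu_blocks).
from typing import Callable

def probe_max_blocks(upper: int, max_probe_steps: int,
                     is_valid_fn: Callable[[int], bool]) -> int:
    low, high = 0, upper * 2        # first midpoint is exactly `upper`
    steps = 0
    while low < high and steps < max_probe_steps:
        mid = (low + high + 1) // 2
        if mid > upper:             # never probe past the model's own limit
            break
        if is_valid_fn(mid):
            low = mid               # mid blocks fit; search higher
        else:
            high = mid - 1          # unsafe (would OOM in the real check); search lower
        steps += 1
    return low                      # largest value proven safe so far

if __name__ == "__main__":
    fits = lambda blocks: blocks <= 10     # pretend only 10 GPU blocks fit
    print(probe_max_blocks(10, 1, fits))   # 10: full context fits, settled in one probe
    print(probe_max_blocks(30, 6, fits))   # 10: enough steps to converge exactly
    print(probe_max_blocks(30, 3, fits))   # 7:  step budget exhausted, conservative result

The three calls mirror the patch's own test cases: with a sufficient step budget the search converges on the exact limit (10 blocks), while a budget of three probes stops at 7, trading precision for a bounded number of expensive memory probes.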