fix: binary search for best context length avoiding oom (#705)
**Reason for Change**:

fix: binary search for best context length avoiding oom

**Issue Fixed**:

`find_max_available_seq_len` runs out of memory (OOM) when running on a V100 16GB GPU with a 128K context.

**Notes for Reviewers**:

In the worst case, it takes about 1 minute to find the best length (running with the Phi-3 medium model and a 128K search space).

We set the context length to a safe value to avoid OOM. If the serving server receives a request whose token length exceeds `max_model_len`, the server will reject it.

Example error message: `This model's maximum context length is 2 tokens. However, you requested 19 tokens (9 in the messages, 10 in the completion). Please reduce the length of the messages or completion.`
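
For illustration only, a minimal sketch of how such a rejection surfaces to a client of the OpenAI-compatible endpoint; the base URL, model name, and prompt size below are placeholder assumptions, not values from this change:

```python
import requests

# Placeholder address of a running vLLM serving server (adjust to your deployment).
BASE_URL = "http://localhost:8000"

resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "placeholder-model",
        # Deliberately oversized prompt so messages + completion exceed max_model_len.
        "messages": [{"role": "user", "content": "hello " * 200_000}],
        "max_tokens": 10,
    },
)

# Expect a 4xx status and an error body similar to the message quoted above.
print(resp.status_code, resp.json())
```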

---------

Signed-off-by: jerryzhuang <[email protected]>
zhuangqh authored Nov 20, 2024
1 parent aff764e commit 1517106
Showing 4 changed files with 104 additions and 8 deletions.
3 changes: 2 additions & 1 deletion Makefile
@@ -113,7 +113,8 @@ tuning-metrics-server-test:

inference-api-e2e:
pip install -r ./presets/dependencies/requirements-test.txt
pytest -o log_cli=true -o log_cli_level=INFO presets/inference
pytest -o log_cli=true -o log_cli_level=INFO presets/inference/vllm
pytest -o log_cli=true -o log_cli_level=INFO presets/inference/text-generation

# Ginkgo configurations
GINKGO_FOCUS ?=
5 changes: 4 additions & 1 deletion presets/inference/text-generation/inference_api.py
@@ -24,7 +24,10 @@
# Initialize logger
logger = logging.getLogger(__name__)
debug_mode = os.environ.get('DEBUG_MODE', 'false').lower() == 'true'
logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)
logging.basicConfig(
    level=logging.DEBUG if debug_mode else logging.INFO,
    format='%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
    datefmt='%m-%d %H:%M:%S')

ADAPTERS_DIR = '/mnt/adapter'

67 changes: 63 additions & 4 deletions presets/inference/vllm/inference_api.py
@@ -3,17 +3,22 @@
import logging
import gc
import os
from typing import Callable

import uvloop
import torch
from vllm.utils import FlexibleArgumentParser
import vllm.entrypoints.openai.api_server as api_server
from vllm.engine.llm_engine import (LLMEngine, EngineArgs, EngineConfig)
from vllm.executor.executor_base import ExecutorBase

# Initialize logger
logger = logging.getLogger(__name__)
debug_mode = os.environ.get('DEBUG_MODE', 'false').lower() == 'true'
logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)
logging.basicConfig(
    level=logging.DEBUG if debug_mode else logging.INFO,
    format='%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
    datefmt='%m-%d %H:%M:%S')

def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    local_rank = int(os.environ.get("LOCAL_RANK",
@@ -58,15 +63,67 @@ def find_max_available_seq_len(engine_config: EngineConfig) -> int:
        observability_config=engine_config.observability_config,
    )

    # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L477
    num_gpu_blocks, _ = executor.determine_num_available_blocks()
    max_probe_steps = 6
    if os.getenv("MAX_PROBE_STEPS") is not None:
        try:
            max_probe_steps = int(os.getenv("MAX_PROBE_STEPS"))
        except ValueError:
            raise ValueError("MAX_PROBE_STEPS must be an integer.")

    model_max_blocks = int(engine_config.model_config.max_model_len / engine_config.cache_config.block_size)
    res = binary_search_with_limited_steps(model_max_blocks, max_probe_steps, lambda x: is_context_length_safe(executor, x))

    # release memory
    del executor
    gc.collect()
    torch.cuda.empty_cache()

    return engine_config.cache_config.block_size * num_gpu_blocks
    return engine_config.cache_config.block_size * res

def binary_search_with_limited_steps(upper: int, max_probe_steps: int, is_valid_fn: Callable[[int], bool]) -> int:
    """
    Finds the maximum valid value within a limited number of steps.
    Parameters:
    - upper (int): The upper bound of the search space ([0, upper]).
    - max_probe_steps (int): Maximum number of steps to try.
    - is_valid_fn (Callable[[int], bool]): A function that checks if a given value is valid.
    Returns:
    - int: The maximum valid value.
    """
    probe_steps = 0
    low = 0
    # Double the upper bound so the first probe lands on `upper` itself,
    # because the valid value is likely to be close to the upper bound.
    high = upper * 2
    while low < high and probe_steps < max_probe_steps:
        mid = (low + high + 1) // 2
        if mid > upper:
            break

        if is_valid_fn(mid):
            low = mid
        else:
            high = mid - 1

        probe_steps += 1

    return low

def is_context_length_safe(executor: ExecutorBase, num_gpu_blocks: int) -> bool:
    """
    Check if the available GPU blocks are enough for the given num_gpu_blocks.
    """
    context_length = executor.cache_config.block_size * num_gpu_blocks
    executor.scheduler_config.max_num_batched_tokens = context_length

    try:
        # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L477
        available_gpu_blocks, _ = executor.determine_num_available_blocks()
    except torch.OutOfMemoryError as e:
        return False

    return available_gpu_blocks >= num_gpu_blocks

if __name__ == "__main__":
    parser = FlexibleArgumentParser(description='vLLM serving server')
@@ -95,6 +152,8 @@ def find_max_available_seq_len(engine_config: EngineConfig) -> int:
    if available_seq_len != max_model_len:
        logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}")
        args.max_model_len = available_seq_len
    else:
        logger.info(f"Using model default max_model_len {max_model_len}")

    # Run the serving server
    logger.info(f"Starting server on port {args.port}")
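To make the probing behavior above concrete, here is a small standalone sketch (not part of the commit) that exercises the same search logic with a fake safety check in place of `is_context_length_safe`. In the real path each probe triggers an expensive `determine_num_available_blocks` call, so capping the steps via `MAX_PROBE_STEPS` (default 6) bounds the probing time at the cost of a slightly conservative result; the block size and the simulated GPU capacity below are illustrative assumptions.

```python
from typing import Callable

# Copy of the helper from the diff above so this sketch runs on its own.
def binary_search_with_limited_steps(upper: int, max_probe_steps: int,
                                     is_valid_fn: Callable[[int], bool]) -> int:
    probe_steps = 0
    low = 0
    # Double the upper bound so the first probe lands on `upper` itself.
    high = upper * 2
    while low < high and probe_steps < max_probe_steps:
        mid = (low + high + 1) // 2
        if mid > upper:
            break
        if is_valid_fn(mid):
            low = mid
        else:
            high = mid - 1
        probe_steps += 1
    return low

BLOCK_SIZE = 16          # illustrative KV-cache block size
MAX_MODEL_LEN = 131072   # 128K context window
model_max_blocks = MAX_MODEL_LEN // BLOCK_SIZE  # 8192 blocks in the search space

def fake_is_safe(num_gpu_blocks: int) -> bool:
    # Stand-in for is_context_length_safe(): pretend the GPU fits at most 3000 blocks.
    return num_gpu_blocks <= 3000

best_blocks = binary_search_with_limited_steps(model_max_blocks, 6, fake_is_safe)
print(best_blocks * BLOCK_SIZE)  # 45056: safe and a bit below the true limit of 48000
```

Raising `MAX_PROBE_STEPS` narrows the gap to the true limit at the price of more probe runs.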
37 changes: 35 additions & 2 deletions presets/inference/vllm/tests/test_vllm_inference_api.py
@@ -13,12 +13,14 @@
# Add the parent directory to sys.path
sys.path.append(parent_dir)

from inference_api import binary_search_with_limited_steps

TEST_MODEL = "facebook/opt-125m"
CHAT_TEMPLATE = ("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}"
"{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
"{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}")

@pytest.fixture(scope="module", autouse=True)
@pytest.fixture
def setup_server(request):
    if os.getenv("DEVICE") == "cpu":
        pytest.skip("Skipping test on cpu device")
@@ -111,4 +113,35 @@ def test_chat_completions_api(setup_server):
    for choice in data["choices"]:
        assert "message" in choice, "Each choice should contain a 'message' key"
        assert "content" in choice["message"], "Each message should contain a 'content' key"
        assert len(choice["message"]["content"]) > 0, "The completion text should not be empty"
        assert len(choice["message"]["content"]) > 0, "The completion text should not be empty"


def test_binary_search_with_limited_steps():

    def is_safe_fn(x):
        return x <= 10

    # Test case 1: all values are safe
    result = binary_search_with_limited_steps(10, 1, is_safe_fn)
    assert result == 10, f"Expected 10, but got {result}"

    result = binary_search_with_limited_steps(10, 10, is_safe_fn)
    assert result == 10, f"Expected 10, but got {result}"

    # Test case 2: partial safe, find the exact value
    result = binary_search_with_limited_steps(20, 3, is_safe_fn)
    assert result == 10, f"Expected 10, but got {result}"

    result = binary_search_with_limited_steps(30, 6, is_safe_fn)
    assert result == 10, f"Expected 10, but got {result}"

    # Test case 3: partial safe, find an approximate value
    result = binary_search_with_limited_steps(30, 3, is_safe_fn)
    assert result == 7, f"Expected 7, but got {result}"

    # Test case 4: all values are unsafe
    result = binary_search_with_limited_steps(10, 1, lambda x: False)
    assert result == 0, f"Expected 0, but got {result}"

    result = binary_search_with_limited_steps(20, 100, lambda x: False)
    assert result == 0, f"Expected 0, but got {result}"
