LoRA adapter support for vLLM
wendy-aw committed Jun 18, 2024
1 parent 70e1bfc commit c8f098a
Showing 4 changed files with 44 additions and 7 deletions.
15 changes: 13 additions & 2 deletions README.md
@@ -174,7 +174,7 @@ We also support loading a peft adapter here as well via the `-a` flag. Note that

### vLLM

We also have a [vllm](https://blog.vllm.ai/) runner which uses the vLLM engine to run the inference altogether as a single batch. It is much faster to do so especially when `num_beams` > 1. You would have to pass in a single set of merged model weights, and the model architecture needs to be supported by vLLM. Here's a sample command:
We also have a [vllm](https://blog.vllm.ai/) runner which uses the vLLM engine to run the inference altogether as a single batch. It is much faster to do so especially when `num_beams` > 1. You would have to pass in a single set of merged model weights and, if applicable, a path to LoRA adapters; the model architecture needs to be supported by vLLM. Here's a sample command:
```bash
python -W ignore main.py \
-db postgres \
@@ -183,6 +183,7 @@ python -W ignore main.py \
-g vllm \
-f "prompts/prompt.md" \
-m defog/llama-3-sqlcoder-8b \
-a path/to_adapter \
-c 0
```

@@ -200,7 +201,16 @@ We also provide our custom modification of the vllm api server, which only retur
python -m vllm.entrypoints.api_server \
--model defog/sqlcoder-7b-2 \
--tensor-parallel-size 4 \
--dtype float16

# to set up a vllm server that supports LoRA adapters
python -m vllm.entrypoints.api_server \
--model defog/sqlcoder-7b-2 \
--tensor-parallel-size 1 \
--dtype float16 \
--max-model-len 4096 \
--enable-lora \
--max-lora-rank 64

# to use our modified api server
python utils/api_server.py \
@@ -218,6 +228,7 @@ python main.py \
-f prompts/prompt.md \
--api_url "http://localhost:8000/generate" \
--api_type "vllm" \
-a path/to_adapter_if_applicable \
-p 8
```
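
For reference, here is a rough sketch of the request the runner ends up posting to `/generate` when an adapter is supplied. The field names mirror `mk_vllm_json` in `eval/api_runner.py` (changed further down in this commit); the prompt string and `path/to_adapter` are placeholders rather than real values. The modified server uses `sql_lora_path` to build a `LoRARequest` on its side.

```python
import requests

# Rough sketch of one /generate request with a LoRA adapter path attached.
# Field names mirror mk_vllm_json in eval/api_runner.py; the prompt and the
# adapter path are placeholders.
payload = {
    "prompt": "-- placeholder text-to-SQL prompt --",
    "n": 1,
    "stop": [";", "```"],
    "max_tokens": 4000,
    "seed": 42,
    "sql_lora_path": "path/to_adapter",
}
response = requests.post("http://localhost:8000/generate", json=payload)
print(response.json())
```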

10 changes: 8 additions & 2 deletions eval/api_runner.py
@@ -14,7 +14,7 @@
from utils.reporting import upload_results


def mk_vllm_json(prompt, num_beams, logprobs=False):
def mk_vllm_json(prompt, num_beams, logprobs=False, sql_lora_path=None):
payload = {
"prompt": prompt,
"n": 1,
@@ -24,7 +24,10 @@ def mk_vllm_json(prompt, num_beams, logprobs=False):
"stop": [";", "```"],
"max_tokens": 4000,
"seed": 42,
"sql_lora_path": sql_lora_path,
}
if sql_lora_path:
print("Using LoRA adapter at:", sql_lora_path)
if logprobs:
payload["logprobs"] = 2
return payload
@@ -51,12 +54,13 @@ def process_row(
num_beams: int,
decimal_points: int,
logprobs: bool = False,
sql_lora_path: Optional[str] = None,
):
start_time = time()
if api_type == "tgi":
json_data = mk_tgi_json(row["prompt"], num_beams)
elif api_type == "vllm":
json_data = mk_vllm_json(row["prompt"], num_beams, logprobs)
json_data = mk_vllm_json(row["prompt"], num_beams, logprobs, sql_lora_path)
else:
# add any custom JSON data here, e.g. for a custom API
json_data = {
@@ -186,6 +190,7 @@ def run_api_eval(args):
decimal_points = args.decimal_points
logprobs = args.logprobs
cot_table_alias = args.cot_table_alias
sql_lora_path = args.adapter if args.adapter else None

if logprobs:
# check that the eval-visualizer/public directory exists
@@ -252,6 +257,7 @@
num_beams,
decimal_points,
logprobs,
sql_lora_path,
)
)

15 changes: 14 additions & 1 deletion eval/vllm_runner.py
@@ -3,6 +3,7 @@
from typing import List
import sqlparse
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from eval.eval import compare_query_results
import pandas as pd
from utils.gen_prompt import generate_prompt
@@ -27,18 +28,29 @@ def run_vllm_eval(args):
k_shot = args.k_shot
db_type = args.db_type
cot_table_alias = args.cot_table_alias
enable_lora = bool(args.adapter)
lora_request = LoRARequest("sql_adapter", 1, args.adapter) if args.adapter else None

# initialize model only once as it takes a while
print(f"Preparing {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
if not args.quantized:
llm = LLM(model=model_name, tensor_parallel_size=1)
llm = LLM(
model=model_name,
tensor_parallel_size=1,
enable_lora=enable_lora,
max_model_len=4096,
max_lora_rank=64,
)
else:
llm = LLM(
model=model_name,
tensor_parallel_size=1,
quantization="AWQ",
enable_lora=enable_lora,
max_model_len=4096,
max_lora_rank=64,
)

sampling_params = SamplingParams(
@@ -129,6 +141,7 @@ def chunk_dataframe(df, chunk_size):
sampling_params=sampling_params,
prompt_token_ids=prompt_tokens,
use_tqdm=False,
lora_request=lora_request,
)
print(
f"Generated {len(outputs)} completions in {time.time() - start_time:.2f} seconds"
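
Taken together, the runner changes above enable LoRA when constructing the engine and attach an adapter per `generate()` call. A minimal standalone sketch of that flow, with the model name, adapter path, and prompt as placeholders:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Minimal sketch of the offline LoRA flow used in eval/vllm_runner.py above.
# Model name, adapter path, and prompt are placeholders.
llm = LLM(
    model="defog/llama-3-sqlcoder-8b",
    tensor_parallel_size=1,
    enable_lora=True,
    max_model_len=4096,
    max_lora_rank=64,
)
sampling_params = SamplingParams(temperature=0, max_tokens=600, stop=[";"])
outputs = llm.generate(
    ["-- placeholder text-to-SQL prompt --"],
    sampling_params=sampling_params,
    lora_request=LoRARequest("sql_adapter", 1, "path/to_adapter"),
)
print(outputs[0].outputs[0].text)
```
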
11 changes: 9 additions & 2 deletions utils/api_server.py
@@ -10,6 +10,7 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.lora.request import LoRARequest
from vllm import __version__ as vllm_version

TIMEOUT_KEEP_ALIVE = 5 # seconds.
@@ -22,13 +23,15 @@
# - don't add special_tokens (bos/eos) and only add it if it's missing from the prompt
# You can start it similar to how you would with the usual vllm api server:
# ```
# python3 -m utils/api_server.py \
# python3 utils/api_server.py \
# --model "${model_path}" \
# --tensor-parallel-size 4 \
# --dtype float16 \
# --max-model-len 4096 \
# --port 5000 \
# --gpu-memory-utilization 0.90
# --gpu-memory-utilization 0.90 \
# --enable-lora \
# --max-lora-rank 64


@app.get("/health")
@@ -49,6 +52,8 @@ async def generate(request: Request) -> Response:
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sql_lora_path = request_dict.pop("sql_lora_path", None)
lora_request = LoRARequest("sql_adapter", 1, sql_lora_path) if sql_lora_path else None
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
tokenizer = await engine.get_tokenizer()
@@ -62,13 +67,15 @@
inputs={"prompt_token_ids": prompt_token_ids},
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
)
else:
results_generator = engine.generate(
prompt=None,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request,
)

# Streaming case
