Skip to content

Commit

Permalink
add adapter name arg
Browse files Browse the repository at this point in the history
  • Loading branch information
wendy-aw committed Jul 8, 2024
1 parent f76b171 commit cf7548b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
12 changes: 10 additions & 2 deletions eval/api_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from utils.reporting import upload_results


def mk_vllm_json(prompt, num_beams, logprobs=False, sql_lora_path=None):
def mk_vllm_json(
prompt, num_beams, logprobs=False, sql_lora_path=None, sql_lora_name=None
):
payload = {
"prompt": prompt,
"n": 1,
Expand All @@ -25,6 +27,7 @@ def mk_vllm_json(prompt, num_beams, logprobs=False, sql_lora_path=None):
"max_tokens": 4000,
"seed": 42,
"sql_lora_path": sql_lora_path,
"sql_lora_name": sql_lora_name,
}
if logprobs:
payload["logprobs"] = 2
Expand Down Expand Up @@ -53,12 +56,15 @@ def process_row(
decimal_points: int,
logprobs: bool = False,
sql_lora_path: Optional[str] = None,
sql_lora_name: Optional[str] = None,
):
start_time = time()
if api_type == "tgi":
json_data = mk_tgi_json(row["prompt"], num_beams)
elif api_type == "vllm":
json_data = mk_vllm_json(row["prompt"], num_beams, logprobs, sql_lora_path)
json_data = mk_vllm_json(
row["prompt"], num_beams, logprobs, sql_lora_path, sql_lora_name
)
else:
# add any custom JSON data here, e.g. for a custom API
json_data = {
Expand Down Expand Up @@ -189,6 +195,7 @@ def run_api_eval(args):
logprobs = args.logprobs
cot_table_alias = args.cot_table_alias
sql_lora_path = args.adapter if args.adapter else None
sql_lora_name = args.adapter_name if args.adapter_name else None
run_name = args.run_name if args.run_name else None
if sql_lora_path:
print("Using LoRA adapter at:", sql_lora_path)
Expand Down Expand Up @@ -258,6 +265,7 @@ def run_api_eval(args):
decimal_points,
logprobs,
sql_lora_path,
sql_lora_name,
)
)

Expand Down
5 changes: 4 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# model-related parameters
parser.add_argument("-g", "--model_type", type=str, required=True)
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-a", "--adapter", type=str)
parser.add_argument("-a", "--adapter", type=str) # path to adapter
parser.add_argument(
"-an", "--adapter_name", type=str, default=None
) # only for use with production server
parser.add_argument("--api_url", type=str)
parser.add_argument("--api_type", type=str)
# inference-technique-related parameters
Expand Down

0 comments on commit cf7548b

Please sign in to comment.