refactor and make it easier to add new backends for benchmarking
mrwyattii committed Feb 28, 2024
1 parent fea1a79 commit 966469a
Showing 6 changed files with 195 additions and 182 deletions.
6 changes: 4 additions & 2 deletions benchmarks/inference/mii/run_benchmark.py
@@ -20,7 +20,8 @@ def run_benchmark() -> None:
args = parse_args(server_args=True, client_args=True)

for server_args in get_args_product(args, which=SERVER_PARAMS):
start_server(server_args)
if server_args.backend != "aml":
start_server(server_args)

for client_args in get_args_product(server_args, which=CLIENT_PARAMS):
if results_exist(client_args) and not args.overwrite_results:
@@ -33,7 +34,8 @@ def run_benchmark() -> None:
print_summary(client_args, response_details)
save_json_results(client_args, response_details)

stop_server(server_args)
if server_args.backend != "aml":
stop_server(server_args)


if __name__ == "__main__":
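The net effect in run_benchmark.py: the local inference server is only started and stopped for backends that actually run locally, while the "aml" backend calls a remote Azure ML endpoint directly. Below is a condensed sketch of the resulting loop, using only the helper names visible in this diff; the body hidden behind the collapsed hunk (e.g. the run_client call) is inferred and may differ.

# Condensed sketch of run_benchmark.py after this commit (imports from src/ omitted;
# run_client is assumed from the collapsed hunk, not shown verbatim in this diff).
def run_benchmark() -> None:
    args = parse_args(server_args=True, client_args=True)

    for server_args in get_args_product(args, which=SERVER_PARAMS):
        # Remote AML endpoints have no local server lifecycle to manage.
        if server_args.backend != "aml":
            start_server(server_args)

        for client_args in get_args_product(server_args, which=CLIENT_PARAMS):
            if results_exist(client_args) and not args.overwrite_results:
                continue
            response_details = run_client(client_args)  # assumed; hidden in the collapsed lines
            print_summary(client_args, response_details)
            save_json_results(client_args, response_details)

        if server_args.backend != "aml":
            stop_server(server_args)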
191 changes: 82 additions & 109 deletions benchmarks/inference/mii/src/client.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

import argparse
import asyncio
import json
import multiprocessing
@@ -12,23 +13,30 @@
import requests
import threading
import time
from typing import List, Iterable
from typing import List, Iterable, Union

import numpy as np
from transformers import AutoTokenizer

#from .postprocess_results import ResponseDetails
#from .random_query_generator import RandomQueryGenerator
#from .sample_input import all_text
#from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS
try:
from .postprocess_results import ResponseDetails
from .random_query_generator import RandomQueryGenerator
from .sample_input import all_text
from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS
except ImportError:
from postprocess_results import ResponseDetails
from random_query_generator import RandomQueryGenerator
from sample_input import all_text
from utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS

from postprocess_results import ResponseDetails
from random_query_generator import RandomQueryGenerator
from sample_input import all_text
from utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS

def call_fastgen(
input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
import mii

client = mii.client(args.deployment_name)

def call_mii(client, input_tokens, max_new_tokens, stream):
output_tokens = []
token_gen_time = []
time_last_token = 0
@@ -43,7 +51,7 @@ def callback(response):

time_last_token = start_time = time.time()
token_gen_time = []
if stream:
if args.stream:
output_tokens = []
client.generate(
input_tokens, max_new_tokens=max_new_tokens, streaming_fn=callback
@@ -62,7 +70,12 @@ def callback(response):
)


def call_vllm(input_tokens, max_new_tokens, stream=True):
def call_vllm(
input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
if not args.stream:
raise NotImplementedError("Not implemented for non-streaming")

api_url = "http://localhost:26500/generate"
headers = {"User-Agent": "Benchmark Client"}
pload = {
@@ -73,7 +86,7 @@ def call_vllm(input_tokens, max_new_tokens, stream=True):
"top_p": 0.9,
"max_tokens": max_new_tokens,
"ignore_eos": False,
"stream": stream,
"stream": args.stream,
}

def clear_line(n: int = 1) -> None:
@@ -95,42 +108,41 @@ def get_streaming_response(
yield output, time_now - time_last_token
time_last_token = time_now

# For non-streaming, but currently non-streaming is not fully implemented
def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data["text"]
return output

token_gen_time = []
start_time = time.time()
response = requests.post(api_url, headers=headers, json=pload, stream=stream)
if stream:
token_gen_time = []
for h, t in get_streaming_response(response, start_time):
output = h
token_gen_time.append(t)

return ResponseDetails(
generated_tokens=output,
prompt=input_tokens,
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time,
)
else:
output = get_response(response)
raise NotImplementedError("Not implemented for non-streaming")
response = requests.post(api_url, headers=headers, json=pload, stream=args.stream)
for h, t in get_streaming_response(response, start_time):
output = h
token_gen_time.append(t)

return ResponseDetails(
generated_tokens=output,
prompt=input_tokens,
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time,
)


## TODO (lekurile): Create AML call function
def call_aml(input_tokens, max_new_tokens, stream=False):
# TODO (lekurile): Hardcoded for now
api_url = 'https://alexander256v100-utiwh.southcentralus.inference.ml.azure.com/score'
api_key = ''
if not api_key:
raise Exception("A key should be provided to invoke the endpoint")
headers = {'Content-Type':'application/json', 'Authorization':('Bearer '+ api_key), 'azureml-model-deployment': 'mistralai-mixtral-8x7b-v01-4' }
token_gen_time = []
print(f"\ninput_tokens = {input_tokens}")
def call_aml(
input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
if args.stream:
raise NotImplementedError("Not implemented for streaming")

headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + args.aml_api_key),
"azureml-model-deployment": args.deployment_name,
}
pload = {
"input_data": {
"input_string": [
@@ -139,23 +151,20 @@ def call_aml(input_tokens, max_new_tokens, stream=False):
"parameters": {
"max_new_tokens": max_new_tokens,
"do_sample": True,
"return_full_text": False
}
"return_full_text": False,
},
}
}

def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
#print(f"data = {data}")
output = data[0]["0"]
return output

token_gen_time = []
start_time = time.time()
response = requests.post(api_url, headers=headers, json=pload, stream=stream)
print(f"response = {response}")

response = requests.post(args.aml_api_url, headers=headers, json=pload)
output = get_response(response)
print(f"output = {output}")

return ResponseDetails(
generated_tokens=output,
@@ -165,58 +174,41 @@ def get_response(response: requests.Response) -> List[str]:
model_time=0,
token_gen_time=token_gen_time,
)
## TODO (lekurile): Create AML call function


def _run_parallel(
deployment_name,
warmup,
barrier,
query_queue,
result_queue,
num_clients,
stream,
vllm,
aml,
barrier: Union[threading.Barrier, multiprocessing.Barrier],
query_queue: Union[queue.Queue, multiprocessing.Queue],
result_queue: Union[queue.Queue, multiprocessing.Queue],
args: argparse.Namespace,
):
pid = os.getpid()
session_id = f"test_session_p{pid}_t{threading.get_ident()}"

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
if not (vllm or aml):
import mii

client = mii.client(deployment_name)
backend_call_fns = {"fastgen": call_fastgen, "vllm": call_vllm, "aml": call_aml}
call_fn = backend_call_fns[args.backend]

barrier.wait()

for _ in range(warmup):
for _ in range(args.warmup):
print(f"warmup queue size: {query_queue.qsize()} ({pid})", flush=True)
input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0)

if vllm:
call_vllm(input_tokens, req_max_new_tokens, stream)
elif aml:
call_aml(input_tokens, req_max_new_tokens)
else:
call_mii(client, input_tokens, req_max_new_tokens, stream)
_ = call_fn(input_tokens, req_max_new_tokens, args)
# call_fastgen(client, input_tokens, req_max_new_tokens, stream)

barrier.wait()

time.sleep(random.uniform(0, num_clients) * 0.01)
time.sleep(random.uniform(0, args.num_clients) * 0.01)
try:
while not query_queue.empty():
print(f"queue size: {query_queue.qsize()} ({pid})", flush=True)
input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0)

# Set max_new_tokens following normal distribution
if vllm:
r = call_vllm(input_tokens, req_max_new_tokens)
elif aml:
r = call_aml(input_tokens, req_max_new_tokens)
else:
r = call_mii(client, input_tokens, req_max_new_tokens, stream)
r = call_fn(input_tokens, req_max_new_tokens, args)
# r = call_fastgen(client, input_tokens, req_max_new_tokens, stream)

result_queue.put(r)
except queue.Empty:
@@ -237,23 +229,7 @@ def run_client(args):
6. The main process marks the end time after receiving `num_requests' results
"""

# Unpack arguments
model = args.model
deployment_name = args.deployment_name
mean_prompt_length = args.mean_prompt_length
mean_max_new_tokens = args.mean_max_new_tokens
num_clients = args.num_clients
num_requests = args.num_requests
warmup = args.warmup
max_prompt_length = args.max_prompt_length
prompt_length_var = args.prompt_length_var
max_new_tokens_var = args.max_new_tokens_var
stream = args.stream
vllm = args.vllm
aml = args.aml
use_thread = args.use_thread

if use_thread:
if args.use_thread:
runnable_cls = threading.Thread
barrier_cls = threading.Barrier
queue_cls = queue.Queue
@@ -262,43 +238,40 @@
barrier_cls = multiprocessing.Barrier
queue_cls = multiprocessing.Queue

barrier = barrier_cls(num_clients + 1)
barrier = barrier_cls(args.num_clients + 1)
query_queue = queue_cls()
result_queue = queue_cls()

processes = [
runnable_cls(
target=_run_parallel,
args=(
deployment_name,
warmup,
barrier,
query_queue,
result_queue,
num_clients,
stream,
vllm,
aml,
args,
),
)
for i in range(num_clients)
for i in range(args.num_clients)
]
for p in processes:
p.start()

tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer = AutoTokenizer.from_pretrained(args.model)
query_generator = RandomQueryGenerator(all_text, tokenizer, seed=42)
request_text = query_generator.get_random_request_text(
mean_prompt_length,
mean_prompt_length * prompt_length_var,
max_prompt_length,
num_requests + warmup * num_clients,
args.mean_prompt_length,
args.mean_prompt_length * args.prompt_length_var,
args.max_prompt_length,
args.num_requests + args.warmup * args.num_clients,
)

for t in request_text:
# Set max_new_tokens following normal distribution
req_max_new_tokens = int(
np.random.normal(
mean_max_new_tokens, max_new_tokens_var * mean_max_new_tokens
args.mean_max_new_tokens,
args.max_new_tokens_var * args.mean_max_new_tokens,
)
)
query_queue.put((t, req_max_new_tokens))
@@ -311,10 +284,10 @@ def run_client(args):
barrier.wait()

response_details = []
while len(response_details) < num_requests:
while len(response_details) < args.num_requests:
res = result_queue.get()
# vLLM returns concatenated tokens
if vllm:
if args.backend == "vllm":
all_tokens = tokenizer.tokenize(res.generated_tokens)
res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)) :]
response_details.append(res)
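The client.py refactor replaces the separate vllm/aml boolean flags with a single dispatch table, backend_call_fns, keyed by args.backend. Each backend is a function with the signature (input_tokens: str, max_new_tokens: int, args: argparse.Namespace) -> ResponseDetails, so supporting a new backend means writing one such function and registering it. A hypothetical sketch follows; the endpoint URL, payload, and response shape are illustrative assumptions, not anything this commit ships.

# Hypothetical new backend; everything about the HTTP API below is assumed for illustration.
import argparse
import json
import time

import requests

from postprocess_results import ResponseDetails  # same dataclass the existing backends return


def call_my_backend(
    input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
    if args.stream:
        raise NotImplementedError("Not implemented for streaming")

    start_time = time.time()
    response = requests.post(
        "http://localhost:8000/generate",  # assumed endpoint
        headers={"Content-Type": "application/json"},
        json={"prompt": input_tokens, "max_tokens": max_new_tokens},  # assumed payload
    )
    output = json.loads(response.content)["text"]  # assumed response shape

    return ResponseDetails(
        generated_tokens=output,
        prompt=input_tokens,
        start_time=start_time,
        end_time=time.time(),
        model_time=0,
        token_gen_time=[],
    )

# Registration is then a one-line edit to the dispatch table in _run_parallel:
#   backend_call_fns = {"fastgen": call_fastgen, "vllm": call_vllm,
#                       "aml": call_aml, "mybackend": call_my_backend}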
2 changes: 2 additions & 0 deletions benchmarks/inference/mii/src/defaults.py
@@ -4,6 +4,8 @@
# DeepSpeed Team

ARG_DEFAULTS = {
"model": "meta-llama/Llama-2-7b-hf",
"deployment_name": "benchmark-deployment",
"tp_size": 1,
"max_ragged_batch_size": 768,
"num_replicas": 1,
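defaults.py now ships a usable model and deployment name, so the benchmark can run without specifying either. How ARG_DEFAULTS is consumed lives in src/utils.py, which this diff does not show; the snippet below is only a sketch of the usual pattern for folding such a dict into argparse, offered as an assumption rather than the repository's actual wiring.

# Sketch only; the real wiring is in src/utils.py, which is not part of the hunks shown above.
import argparse

ARG_DEFAULTS = {
    "model": "meta-llama/Llama-2-7b-hf",
    "deployment_name": "benchmark-deployment",
    "tp_size": 1,
    "max_ragged_batch_size": 768,
    "num_replicas": 1,
}

parser = argparse.ArgumentParser()
for name, value in ARG_DEFAULTS.items():
    parser.add_argument(f"--{name}", type=type(value), default=value)

args = parser.parse_args([])  # no CLI flags: every value falls back to ARG_DEFAULTS
assert args.model == "meta-llama/Llama-2-7b-hf"
assert args.tp_size == 1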