From 966469a3d4818e280127cd05f0ffd610931ecb5c Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 27 Feb 2024 16:53:18 -0800 Subject: [PATCH] refactor and make it easier to add new backends for benchmarking --- benchmarks/inference/mii/run_benchmark.py | 6 +- benchmarks/inference/mii/src/client.py | 191 ++++++++---------- benchmarks/inference/mii/src/defaults.py | 2 + .../inference/mii/src/postprocess_results.py | 23 ++- benchmarks/inference/mii/src/server.py | 97 ++++----- benchmarks/inference/mii/src/utils.py | 58 ++++-- 6 files changed, 195 insertions(+), 182 deletions(-) diff --git a/benchmarks/inference/mii/run_benchmark.py b/benchmarks/inference/mii/run_benchmark.py index 96e88155f..801d45b85 100644 --- a/benchmarks/inference/mii/run_benchmark.py +++ b/benchmarks/inference/mii/run_benchmark.py @@ -20,7 +20,8 @@ def run_benchmark() -> None: args = parse_args(server_args=True, client_args=True) for server_args in get_args_product(args, which=SERVER_PARAMS): - start_server(server_args) + if server_args.backend != "aml": + start_server(server_args) for client_args in get_args_product(server_args, which=CLIENT_PARAMS): if results_exist(client_args) and not args.overwrite_results: @@ -33,7 +34,8 @@ def run_benchmark() -> None: print_summary(client_args, response_details) save_json_results(client_args, response_details) - stop_server(server_args) + if server_args.backend != "aml": + stop_server(server_args) if __name__ == "__main__": diff --git a/benchmarks/inference/mii/src/client.py b/benchmarks/inference/mii/src/client.py index ff31784d9..d290efe07 100644 --- a/benchmarks/inference/mii/src/client.py +++ b/benchmarks/inference/mii/src/client.py @@ -3,6 +3,7 @@ # DeepSpeed Team +import argparse import asyncio import json import multiprocessing @@ -12,23 +13,30 @@ import requests import threading import time -from typing import List, Iterable +from typing import List, Iterable, Union import numpy as np from transformers import AutoTokenizer -#from .postprocess_results import ResponseDetails -#from .random_query_generator import RandomQueryGenerator -#from .sample_input import all_text -#from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS +try: + from .postprocess_results import ResponseDetails + from .random_query_generator import RandomQueryGenerator + from .sample_input import all_text + from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS +except ImportError: + from postprocess_results import ResponseDetails + from random_query_generator import RandomQueryGenerator + from sample_input import all_text + from utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS -from postprocess_results import ResponseDetails -from random_query_generator import RandomQueryGenerator -from sample_input import all_text -from utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS +def call_fastgen( + input_tokens: str, max_new_tokens: int, args: argparse.Namespace +) -> ResponseDetails: + import mii + + client = mii.client(args.deployment_name) -def call_mii(client, input_tokens, max_new_tokens, stream): output_tokens = [] token_gen_time = [] time_last_token = 0 @@ -43,7 +51,7 @@ def callback(response): time_last_token = start_time = time.time() token_gen_time = [] - if stream: + if args.stream: output_tokens = [] client.generate( input_tokens, max_new_tokens=max_new_tokens, streaming_fn=callback @@ -62,7 +70,12 @@ def callback(response): ) -def call_vllm(input_tokens, max_new_tokens, stream=True): +def call_vllm( + input_tokens: str, max_new_tokens: int, args: argparse.Namespace +) -> ResponseDetails: + if not args.stream: + raise NotImplementedError("Not implemented for non-streaming") + api_url = "http://localhost:26500/generate" headers = {"User-Agent": "Benchmark Client"} pload = { @@ -73,7 +86,7 @@ def call_vllm(input_tokens, max_new_tokens, stream=True): "top_p": 0.9, "max_tokens": max_new_tokens, "ignore_eos": False, - "stream": stream, + "stream": args.stream, } def clear_line(n: int = 1) -> None: @@ -95,42 +108,41 @@ def get_streaming_response( yield output, time_now - time_last_token time_last_token = time_now + # For non-streaming, but currently non-streaming is not fully implemented def get_response(response: requests.Response) -> List[str]: data = json.loads(response.content) output = data["text"] return output + token_gen_time = [] start_time = time.time() - response = requests.post(api_url, headers=headers, json=pload, stream=stream) - if stream: - token_gen_time = [] - for h, t in get_streaming_response(response, start_time): - output = h - token_gen_time.append(t) - - return ResponseDetails( - generated_tokens=output, - prompt=input_tokens, - start_time=start_time, - end_time=time.time(), - model_time=0, - token_gen_time=token_gen_time, - ) - else: - output = get_response(response) - raise NotImplementedError("Not implemented for non-streaming") + response = requests.post(api_url, headers=headers, json=pload, stream=args.stream) + for h, t in get_streaming_response(response, start_time): + output = h + token_gen_time.append(t) + + return ResponseDetails( + generated_tokens=output, + prompt=input_tokens, + start_time=start_time, + end_time=time.time(), + model_time=0, + token_gen_time=token_gen_time, + ) ## TODO (lekurile): Create AML call function -def call_aml(input_tokens, max_new_tokens, stream=False): - # TODO (lekurile): Hardcoded for now - api_url = 'https://alexander256v100-utiwh.southcentralus.inference.ml.azure.com/score' - api_key = '' - if not api_key: - raise Exception("A key should be provided to invoke the endpoint") - headers = {'Content-Type':'application/json', 'Authorization':('Bearer '+ api_key), 'azureml-model-deployment': 'mistralai-mixtral-8x7b-v01-4' } - token_gen_time = [] - print(f"\ninput_tokens = {input_tokens}") +def call_aml( + input_tokens: str, max_new_tokens: int, args: argparse.Namespace +) -> ResponseDetails: + if args.stream: + raise NotImplementedError("Not implemented for streaming") + + headers = { + "Content-Type": "application/json", + "Authorization": ("Bearer " + args.aml_api_key), + "azureml-model-deployment": args.deployment_name, + } pload = { "input_data": { "input_string": [ @@ -139,23 +151,20 @@ def call_aml(input_tokens, max_new_tokens, stream=False): "parameters": { "max_new_tokens": max_new_tokens, "do_sample": True, - "return_full_text": False - } + "return_full_text": False, + }, } } def get_response(response: requests.Response) -> List[str]: data = json.loads(response.content) - #print(f"data = {data}") output = data[0]["0"] return output + token_gen_time = [] start_time = time.time() - response = requests.post(api_url, headers=headers, json=pload, stream=stream) - print(f"response = {response}") - + response = requests.post(args.aml_api_url, headers=headers, json=pload) output = get_response(response) - print(f"output = {output}") return ResponseDetails( generated_tokens=output, @@ -165,58 +174,41 @@ def get_response(response: requests.Response) -> List[str]: model_time=0, token_gen_time=token_gen_time, ) -## TODO (lekurile): Create AML call function def _run_parallel( - deployment_name, - warmup, - barrier, - query_queue, - result_queue, - num_clients, - stream, - vllm, - aml, + barrier: Union[threading.Barrier, multiprocessing.Barrier], + query_queue: Union[queue.Queue, multiprocessing.Queue], + result_queue: Union[queue.Queue, multiprocessing.Queue], + args: argparse.Namespace, ): pid = os.getpid() session_id = f"test_session_p{pid}_t{threading.get_ident()}" event_loop = asyncio.new_event_loop() asyncio.set_event_loop(event_loop) - if not (vllm or aml): - import mii - client = mii.client(deployment_name) + backend_call_fns = {"fastgen": call_fastgen, "vllm": call_vllm, "aml": call_aml} + call_fn = backend_call_fns[args.backend] barrier.wait() - for _ in range(warmup): + for _ in range(args.warmup): print(f"warmup queue size: {query_queue.qsize()} ({pid})", flush=True) input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0) - - if vllm: - call_vllm(input_tokens, req_max_new_tokens, stream) - elif aml: - call_aml(input_tokens, req_max_new_tokens) - else: - call_mii(client, input_tokens, req_max_new_tokens, stream) + _ = call_fn(input_tokens, req_max_new_tokens, args) + # call_fastgen(client, input_tokens, req_max_new_tokens, stream) barrier.wait() - time.sleep(random.uniform(0, num_clients) * 0.01) + time.sleep(random.uniform(0, args.num_clients) * 0.01) try: while not query_queue.empty(): print(f"queue size: {query_queue.qsize()} ({pid})", flush=True) input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0) - # Set max_new_tokens following normal distribution - if vllm: - r = call_vllm(input_tokens, req_max_new_tokens) - elif aml: - r = call_aml(input_tokens, req_max_new_tokens) - else: - r = call_mii(client, input_tokens, req_max_new_tokens, stream) + r = call_fn(input_tokens, req_max_new_tokens, args) + # r = call_fastgen(client, input_tokens, req_max_new_tokens, stream) result_queue.put(r) except queue.Empty: @@ -237,23 +229,7 @@ def run_client(args): 6. The main process marks the end time after receiving `num_requests' results """ - # Unpack arguments - model = args.model - deployment_name = args.deployment_name - mean_prompt_length = args.mean_prompt_length - mean_max_new_tokens = args.mean_max_new_tokens - num_clients = args.num_clients - num_requests = args.num_requests - warmup = args.warmup - max_prompt_length = args.max_prompt_length - prompt_length_var = args.prompt_length_var - max_new_tokens_var = args.max_new_tokens_var - stream = args.stream - vllm = args.vllm - aml = args.aml - use_thread = args.use_thread - - if use_thread: + if args.use_thread: runnable_cls = threading.Thread barrier_cls = threading.Barrier queue_cls = queue.Queue @@ -262,7 +238,7 @@ def run_client(args): barrier_cls = multiprocessing.Barrier queue_cls = multiprocessing.Queue - barrier = barrier_cls(num_clients + 1) + barrier = barrier_cls(args.num_clients + 1) query_queue = queue_cls() result_queue = queue_cls() @@ -270,35 +246,32 @@ def run_client(args): runnable_cls( target=_run_parallel, args=( - deployment_name, - warmup, barrier, query_queue, result_queue, - num_clients, - stream, - vllm, - aml, + args, ), ) - for i in range(num_clients) + for i in range(args.num_clients) ] for p in processes: p.start() - tokenizer = AutoTokenizer.from_pretrained(model) + tokenizer = AutoTokenizer.from_pretrained(args.model) query_generator = RandomQueryGenerator(all_text, tokenizer, seed=42) request_text = query_generator.get_random_request_text( - mean_prompt_length, - mean_prompt_length * prompt_length_var, - max_prompt_length, - num_requests + warmup * num_clients, + args.mean_prompt_length, + args.mean_prompt_length * args.prompt_length_var, + args.max_prompt_length, + args.num_requests + args.warmup * args.num_clients, ) for t in request_text: + # Set max_new_tokens following normal distribution req_max_new_tokens = int( np.random.normal( - mean_max_new_tokens, max_new_tokens_var * mean_max_new_tokens + args.mean_max_new_tokens, + args.max_new_tokens_var * args.mean_max_new_tokens, ) ) query_queue.put((t, req_max_new_tokens)) @@ -311,10 +284,10 @@ def run_client(args): barrier.wait() response_details = [] - while len(response_details) < num_requests: + while len(response_details) < args.num_requests: res = result_queue.get() # vLLM returns concatinated tokens - if vllm: + if args.backend == "vllm": all_tokens = tokenizer.tokenize(res.generated_tokens) res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)) :] response_details.append(res) diff --git a/benchmarks/inference/mii/src/defaults.py b/benchmarks/inference/mii/src/defaults.py index 79ce91c97..89255dfa6 100644 --- a/benchmarks/inference/mii/src/defaults.py +++ b/benchmarks/inference/mii/src/defaults.py @@ -4,6 +4,8 @@ # DeepSpeed Team ARG_DEFAULTS = { + "model": "meta-llama/Llama-2-7b-hf", + "deployment_name": "benchmark-deployment", "tp_size": 1, "max_ragged_batch_size": 768, "num_replicas": 1, diff --git a/benchmarks/inference/mii/src/postprocess_results.py b/benchmarks/inference/mii/src/postprocess_results.py index 38fa9ab20..4260f1341 100644 --- a/benchmarks/inference/mii/src/postprocess_results.py +++ b/benchmarks/inference/mii/src/postprocess_results.py @@ -79,17 +79,24 @@ def get_summary(args, response_details): for r in response_details ] ) - #first_token_latency = mean([r.token_gen_time[0] for r in response_details]) - #token_gen_latency_flat = reduce( - # list.__add__, - # [r.token_gen_time[1:-1] for r in response_details if len(r.token_gen_time) > 2], - #) - #token_gen_latency = mean([t for t in token_gen_latency_flat]) + # For non-streaming results, we don't have any token_gen_time information + first_token_latency = 0.0 + token_gen_latency = 0.0 + if response_details[0].token_gen_time: + first_token_latency = mean([r.token_gen_time[0] for r in response_details]) + token_gen_latency_flat = reduce( + list.__add__, + [ + r.token_gen_time[1:-1] + for r in response_details + if len(r.token_gen_time) > 2 + ], + ) + token_gen_latency = mean([t for t in token_gen_latency_flat]) return ProfilingSummary( - #throughput, latency, token_gen_latency, first_token_latency, tokens_per_sec - throughput, latency, 0.0, 0.0, tokens_per_sec + throughput, latency, token_gen_latency, first_token_latency, tokens_per_sec ) diff --git a/benchmarks/inference/mii/src/server.py b/benchmarks/inference/mii/src/server.py index d0ecabaf3..ec04338b5 100644 --- a/benchmarks/inference/mii/src/server.py +++ b/benchmarks/inference/mii/src/server.py @@ -3,37 +3,28 @@ # DeepSpeed Team +import argparse import subprocess import time -import mii -from deepspeed.inference import RaggedInferenceEngineConfig, DeepSpeedTPConfig -from deepspeed.inference.v2.ragged import DSStateManagerConfig -from .utils import parse_args, SERVER_PARAMS +try: + from .utils import parse_args, SERVER_PARAMS +except ImportError: + from utils import parse_args, SERVER_PARAMS -def start_server(args): - vllm = args.vllm - model = args.model - deployment_name = args.deployment_name - tp_size = args.tp_size - num_replicas = args.num_replicas - max_ragged_batch_size = args.max_ragged_batch_size - - if vllm: - start_vllm_server(model=model, tp_size=tp_size) - else: - start_mii_server( - model=model, - deployment_name=deployment_name, - tp_size=tp_size, - num_replicas=num_replicas, - max_ragged_batch_size=max_ragged_batch_size, - ) +def start_server(args: argparse.Namespace) -> None: + start_server_fns = { + "fastgen": start_fastgen_server, + "vllm": start_vllm_server, + "aml": start_aml_server, + } + start_fn = start_server_fns[args.backend] + start_fn(args) -def start_vllm_server(model: str, tp_size: int) -> None: +def start_vllm_server(args: argparse.Namespace) -> None: vllm_cmd = ( "python", "-m", @@ -43,9 +34,9 @@ def start_vllm_server(model: str, tp_size: int) -> None: "--port", "26500", "--tensor-parallel-size", - str(tp_size), + str(args.tp_size), "--model", - model, + args.model, ) p = subprocess.Popen( vllm_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, close_fds=True @@ -67,45 +58,61 @@ def start_vllm_server(model: str, tp_size: int) -> None: time.sleep(0.01) -def start_mii_server( - model, deployment_name, tp_size, num_replicas, max_ragged_batch_size -): - tp_config = DeepSpeedTPConfig(tp_size=tp_size) +def start_fastgen_server(args: argparse.Namespace) -> None: + import mii + from deepspeed.inference import RaggedInferenceEngineConfig, DeepSpeedTPConfig + from deepspeed.inference.v2.ragged import DSStateManagerConfig + + tp_config = DeepSpeedTPConfig(tp_size=args.tp_size) mgr_config = DSStateManagerConfig( - max_ragged_batch_size=max_ragged_batch_size, - max_ragged_sequence_count=max_ragged_batch_size, + max_ragged_batch_size=args.max_ragged_batch_size, + max_ragged_sequence_count=args.max_ragged_batch_size, ) inference_config = RaggedInferenceEngineConfig( tensor_parallel=tp_config, state_manager=mgr_config ) mii.serve( - model, - deployment_name=deployment_name, - tensor_parallel=tp_size, + args.model, + deployment_name=args.deployment_name, + tensor_parallel=args.tp_size, inference_engine_config=inference_config, - replica_num=num_replicas, + replica_num=args.num_replicas, ) -def stop_server(args): - vllm = args.vllm - deployment_name = args.deployment_name +def start_aml_server(args: argparse.Namespace) -> None: + raise NotImplementedError( + "AML server start not implemented. Please use Azure Portal to start the server." + ) - if vllm: - stop_vllm_server() - else: - stop_mii_server(deployment_name) +def stop_server(args: argparse.Namespace) -> None: + stop_server_fns = { + "fastgen": stop_fastgen_server, + "vllm": stop_vllm_server, + "aml": stop_aml_server, + } + stop_fn = stop_server_fns[args.backend] + stop_fn(args) -def stop_vllm_server(): + +def stop_vllm_server(args: argparse.Namespace) -> None: vllm_cmd = ("pkill", "-f", "vllm.entrypoints.api_server") p = subprocess.Popen(vllm_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) p.wait() -def stop_mii_server(deployment_name): - mii.client(deployment_name).terminate_server() +def stop_fastgen_server(args: argparse.Namespace) -> None: + import mii + + mii.client(args.deployment_name).terminate_server() + + +def stop_aml_server(args: argparse.Namespace) -> None: + raise NotImplementedError( + "AML server stop not implemented. Please use Azure Portal to stop the server." + ) if __name__ == "__main__": diff --git a/benchmarks/inference/mii/src/utils.py b/benchmarks/inference/mii/src/utils.py index d5e739734..2f780bd5a 100644 --- a/benchmarks/inference/mii/src/utils.py +++ b/benchmarks/inference/mii/src/utils.py @@ -14,16 +14,20 @@ from pathlib import Path from typing import Iterator, List -#from .defaults import ARG_DEFAULTS, MODEL_DEFAULTS -#from .postprocess_results import get_summary, ResponseDetails -from defaults import ARG_DEFAULTS, MODEL_DEFAULTS -from postprocess_results import get_summary, ResponseDetails +try: + from .defaults import ARG_DEFAULTS, MODEL_DEFAULTS + from .postprocess_results import get_summary, ResponseDetails +except ImportError: + from defaults import ARG_DEFAULTS, MODEL_DEFAULTS + from postprocess_results import get_summary, ResponseDetails # For these arguments, users can provide multiple values when running the # benchmark. The benchmark will iterate over all possible combinations. SERVER_PARAMS = ["tp_size", "max_ragged_batch_size", "num_replicas"] CLIENT_PARAMS = ["mean_prompt_length", "mean_max_new_tokens", "num_clients"] +AML_REQUIRED_PARAMS = ["aml_api_url", "aml_api_key", "deployment_name", "model"] + def parse_args( server_args: bool = False, client_args: bool = False @@ -48,7 +52,7 @@ def parse_args( type=int, nargs="+", default=None, - help="Number of MII model replicas", + help="Number of FastGen model replicas", ) server_parser.add_argument( "cmd", @@ -114,6 +118,18 @@ def parse_args( default="./results/", help="Directory to save result JSON files", ) + client_parser.add_argument( + "--aml_api_url", + type=str, + default=None, + help="When using the AML backend, this is the API URL that points to an AML endpoint", + ) + client_parser.add_argument( + "--aml_api_key", + type=str, + default=None, + help="When using the AML backend, this is the API key for a given aml_api_url", + ) # Create the parser, inheriting from the server and/or client parsers parents = [] @@ -125,16 +141,21 @@ def parse_args( # Common args parser = argparse.ArgumentParser(parents=parents) parser.add_argument( - "--model", type=str, default="meta-llama/Llama-2-7b-hf", help="Model name" + "--model", type=str, default=None, help="HuggingFace.co model name" ) parser.add_argument( "--deployment_name", type=str, - default="mii-benchmark-deployment", - help="Deployment name for MII server", + default=None, + help="When using FastGen backend, specifies which model deployment to use", + ) + parser.add_argument( + "--backend", + type=str, + choices=["aml", "fastgen", "vllm"], + default="fastgen", + help="Which backend to benchmark", ) - parser.add_argument("--vllm", action="store_true", help="Use VLLM") - parser.add_argument("--aml", action="store_true", help="Use AML") parser.add_argument( "--overwrite_results", action="store_true", help="Overwrite existing results" ) @@ -142,6 +163,12 @@ def parse_args( # Parse arguments args = parser.parse_args() + # Verify that AML required parameters are defined before filling in defaults + if args.backend == "aml": + for k in AML_REQUIRED_PARAMS: + if getattr(args, k) is None: + raise ValueError(f"AML backend requires {k} to be specified") + # Set default values for model-specific parameters if args.model in MODEL_DEFAULTS: for k, v in MODEL_DEFAULTS[args.model].items(): @@ -153,8 +180,9 @@ def parse_args( if hasattr(args, k) and getattr(args, k) is None: setattr(args, k, v) + # If we are not running the benchmark, we need to make sure to only have one + # value for the server args if server_args and not client_args: - # If we are not running the benchmark, we need to make sure to only have one value for the server args for k in SERVER_PARAMS: if not isinstance(getattr(args, k), int): setattr(args, k, getattr(args, k)[0]) @@ -179,15 +207,9 @@ def get_args_product( def get_results_path(args: argparse.Namespace) -> Path: - if args.vllm: - lib_path = "vllm" - elif args.aml: - lib_path = "aml" - else: - lib_path = "fastgen" return Path( args.out_json_dir, - f"{lib_path}/", + f"{args.backend}/", "-".join( ( args.model.replace("/", "_"),