
release: Update llama models endpoint #737

Closed
wants to merge 9 commits into from
216 changes: 216 additions & 0 deletions presets/inference/text-generation/inference_api.py
@@ -0,0 +1,216 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional

import GPUtil
import torch
import transformers
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Extra, Field
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          GenerationConfig, HfArgumentParser)


@dataclass
class ModelConfig:
    """
    Transformers Model Configuration Parameters
    """
    pipeline: str = field(metadata={"help": "The model pipeline for the pre-trained model"})
    pretrained_model_name_or_path: Optional[str] = field(default="/workspace/tfs/weights", metadata={"help": "Path to the pretrained model or model identifier from huggingface.co/models"})
    state_dict: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "State dictionary for the model"})
    cache_dir: Optional[str] = field(default=None, metadata={"help": "Cache directory for the model"})
    from_tf: bool = field(default=False, metadata={"help": "Load model from a TensorFlow checkpoint"})
    force_download: bool = field(default=False, metadata={"help": "Force the download of the model"})
    resume_download: bool = field(default=False, metadata={"help": "Resume an interrupted download"})
    proxies: Optional[str] = field(default=None, metadata={"help": "Proxy configuration for downloading the model"})
    output_loading_info: bool = field(default=False, metadata={"help": "Output additional loading information"})
    allow_remote_files: bool = field(default=False, metadata={"help": "Allow using remote files, default is local only"})
    revision: str = field(default="main", metadata={"help": "Specific model version to use"})
    trust_remote_code: bool = field(default=False, metadata={"help": "Enable trusting remote code when loading the model"})
    load_in_4bit: bool = field(default=False, metadata={"help": "Load model in 4-bit mode"})
    load_in_8bit: bool = field(default=False, metadata={"help": "Load model in 8-bit mode"})
    torch_dtype: Optional[str] = field(default=None, metadata={"help": "The torch dtype for the pre-trained model"})
    device_map: str = field(default="auto", metadata={"help": "The device map for the pre-trained model"})

    # Method to process additional arguments
    def process_additional_args(self, addt_args: List[str]):
        """
        Process additional cmd line args and update the model configuration accordingly.
        """
        addt_args_dict = {}
        i = 0
        while i < len(addt_args):
            key = addt_args[i].lstrip('-')  # Remove leading dashes
            if i + 1 < len(addt_args) and not addt_args[i + 1].startswith('--'):
                value = addt_args[i + 1]
                i += 2  # Move past the current key-value pair
            else:
                value = True  # Assign a True value for standalone flags
                i += 1  # Move to the next item

            addt_args_dict[key] = value

        # Update the ModelConfig instance with the additional args
        self.__dict__.update(addt_args_dict)

    def __post_init__(self):
        """
        Post-initialization to validate some ModelConfig values
        """
        if self.torch_dtype and not hasattr(torch, self.torch_dtype):
            raise ValueError(f"Invalid torch dtype: {self.torch_dtype}")
        self.torch_dtype = getattr(torch, self.torch_dtype) if self.torch_dtype else None

        supported_pipelines = {"conversational", "text-generation"}
        if self.pipeline not in supported_pipelines:
            raise ValueError(f"Unsupported pipeline: {self.pipeline}")

parser = HfArgumentParser(ModelConfig)
args, additional_args = parser.parse_args_into_dataclasses(
    return_remaining_strings=True
)

args.process_additional_args(additional_args)

model_args = asdict(args)
model_args["local_files_only"] = not model_args.pop('allow_remote_files')
model_pipeline = model_args.pop('pipeline')

app = FastAPI()
tokenizer = AutoTokenizer.from_pretrained(**model_args)
model = AutoModelForCausalLM.from_pretrained(**model_args)

pipeline_kwargs = {
    "trust_remote_code": args.trust_remote_code,
    "device_map": args.device_map,
}

if args.torch_dtype:
    pipeline_kwargs["torch_dtype"] = args.torch_dtype

pipeline = transformers.pipeline(
    model_pipeline,
    model=model,
    tokenizer=tokenizer,
    **pipeline_kwargs
)

try:
    # Attempt to load the generation configuration; fall back to an empty dict on failure
    default_generate_config = GenerationConfig.from_pretrained(
        args.pretrained_model_name_or_path,
        local_files_only=model_args["local_files_only"]
    ).to_dict()
except Exception:
    default_generate_config = {}

@app.get('/')
def home():
    return "Server is running", 200

@app.get("/healthz")
def health_check():
if not torch.cuda.is_available():
raise HTTPException(status_code=500, detail="No GPU available")
if not model:
raise HTTPException(status_code=500, detail="Model not initialized")
if not pipeline:
raise HTTPException(status_code=500, detail="Pipeline not initialized")
return {"status": "Healthy"}

class GenerateKwargs(BaseModel):
    max_length: int = 200  # Length of input prompt + max_new_tokens
    min_length: int = 0
    do_sample: bool = False
    early_stopping: bool = False
    num_beams: int = 1
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 1
    typical_p: float = 1
    repetition_penalty: float = 1
    pad_token_id: Optional[int] = tokenizer.pad_token_id
    eos_token_id: Optional[int] = tokenizer.eos_token_id

    class Config:
        extra = Extra.allow  # Allows for additional fields not explicitly defined

class UnifiedRequestModel(BaseModel):
    # Fields for text generation
    prompt: Optional[str] = Field(None, description="Prompt for text generation")
    return_full_text: Optional[bool] = Field(True, description="Return full text if True, else only added text")
    clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output")
    prefix: Optional[str] = Field(None, description="Prefix added to prompt")
    handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation")
    generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method")

    # Field for conversational model
    messages: Optional[List[Dict[str, str]]] = Field(None, description="Messages for conversational model")

@app.post("/chat")
def generate_text(request_model: UnifiedRequestModel):
user_generate_kwargs = request_model.generate_kwargs.dict() if request_model.generate_kwargs else {}
generate_kwargs = {**default_generate_config, **user_generate_kwargs}

if args.pipeline == "text-generation":
if not request_model.prompt:
raise HTTPException(status_code=400, detail="Text generation parameter prompt required")
sequences = pipeline(
request_model.prompt,
# return_tensors=request_model.return_tensors,
# return_text=request_model.return_text,
return_full_text=request_model.return_full_text,
clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces,
prefix=request_model.prefix,
handle_long_generation=request_model.handle_long_generation,
**generate_kwargs
)

result = ""
for seq in sequences:
print(f"Result: {seq['generated_text']}")
result += seq['generated_text']

return {"Result": result}

elif args.pipeline == "conversational":
if not request_model.messages:
raise HTTPException(status_code=400, detail="Conversational parameter messages required")

response = pipeline(
request_model.messages,
clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces,
**generate_kwargs
)
return {"Result": str(response[-1])}

else:
raise HTTPException(status_code=400, detail="Invalid pipeline type")

@app.get("/metrics")
def get_metrics():
try:
gpus = GPUtil.getGPUs()
gpu_info = []
for gpu in gpus:
gpu_info.append({
"id": gpu.id,
"name": gpu.name,
"load": f"{gpu.load * 100:.2f}%", # Format as percentage
"temperature": f"{gpu.temperature} C",
"memory": {
"used": f"{gpu.memoryUsed / 1024:.2f} GB",
"total": f"{gpu.memoryTotal / 1024:.2f} GB"
}
})
return {"gpu_info": gpu_info}
except Exception as e:
return {"error": str(e)}

if __name__ == "__main__":
local_rank = int(os.environ.get("LOCAL_RANK", 0)) # Default to 0 if not set
port = 5000 + local_rank # Adjust port based on local rank
uvicorn.run(app=app, host='0.0.0.0', port=port)
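
For reference, once the server above is running with the text-generation pipeline, a client call to the /chat endpoint might look like the sketch below. This is illustrative only and not part of the PR: the host and port assume the default uvicorn settings in __main__, and the prompt and generation values are placeholders. A conversational request would instead send a "messages" list of role/content dictionaries, as exercised in the tests that follow.

# Hypothetical client call against the server above (not part of this PR).
# Assumes the server is reachable on localhost:5000 with --pipeline text-generation.
import requests

payload = {
    "prompt": "The sky is",  # placeholder prompt
    "return_full_text": True,
    "generate_kwargs": {"max_length": 50, "do_sample": True},
}
resp = requests.post("http://localhost:5000/chat", json=payload)
print(resp.json()["Result"])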
165 changes: 165 additions & 0 deletions presets/inference/text-generation/tests/test_inference_api.py
@@ -0,0 +1,165 @@
import importlib
import sys
from pathlib import Path
from unittest.mock import patch

import pytest
import torch
from fastapi.testclient import TestClient

# Get the parent directory of the current file
parent_dir = str(Path(__file__).resolve().parent.parent)
# Add the parent directory to sys.path
sys.path.append(parent_dir)

@pytest.fixture(params=[
    {"pipeline": "text-generation", "model_path": "stanford-crfm/alias-gpt2-small-x21"},
    {"pipeline": "conversational", "model_path": "stanford-crfm/alias-gpt2-small-x21"},
])
def configured_app(request):
    original_argv = sys.argv.copy()
    # Use request.param to set correct test arguments for each configuration
    test_args = [
        'program_name',
        '--pipeline', request.param['pipeline'],
        '--pretrained_model_name_or_path', request.param['model_path'],
        '--allow_remote_files', 'True'
    ]
    sys.argv = test_args

    import inference_api
    importlib.reload(inference_api)  # Reload to pick up the new CLI args despite module caching
    from inference_api import app

    # Attach the request params to the app instance for access in tests
    app.test_config = request.param
    yield app

    sys.argv = original_argv

def test_conversational(configured_app):
    if configured_app.test_config['pipeline'] != 'conversational':
        pytest.skip("Skipping non-conversational tests")
    client = TestClient(configured_app)
    messages = [
        {"role": "user", "content": "What is your favourite condiment?"},
        {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
        {"role": "user", "content": "Do you have mayonnaise recipes?"}
    ]
    request_data = {
        "messages": messages,
        "generate_kwargs": {"max_new_tokens": 20, "do_sample": True}
    }
    response = client.post("/chat", json=request_data)

    assert response.status_code == 200
    data = response.json()
    assert "Result" in data
    assert len(data["Result"]) > 0  # Check that the conversation result is not empty

def test_missing_messages_for_conversation(configured_app):
    if configured_app.test_config['pipeline'] != 'conversational':
        pytest.skip("Skipping non-conversational tests")
    client = TestClient(configured_app)
    request_data = {
        # "messages" is missing for conversational pipeline
    }
    response = client.post("/chat", json=request_data)
    assert response.status_code == 400  # Expecting a Bad Request response due to missing messages
    assert "Conversational parameter messages required" in response.json().get("detail", "")

def test_text_generation(configured_app):
    if configured_app.test_config['pipeline'] != 'text-generation':
        pytest.skip("Skipping non-text-generation tests")
    client = TestClient(configured_app)
    request_data = {
        "prompt": "Hello, world!",
        "return_full_text": True,
        "clean_up_tokenization_spaces": False,
        "generate_kwargs": {"max_length": 50, "min_length": 10}  # Example generate_kwargs
    }
    response = client.post("/chat", json=request_data)
    assert response.status_code == 200
    data = response.json()
    assert "Result" in data
    assert len(data["Result"]) > 0  # Check that the result text is not empty

def test_missing_prompt(configured_app):
    if configured_app.test_config['pipeline'] != 'text-generation':
        pytest.skip("Skipping non-text-generation tests")
    client = TestClient(configured_app)
    request_data = {
        # "prompt" is missing
        "return_full_text": True,
        "clean_up_tokenization_spaces": False,
        "generate_kwargs": {"max_length": 50}
    }
    response = client.post("/chat", json=request_data)
    assert response.status_code == 400  # Expecting a Bad Request response due to missing prompt
    assert "Text generation parameter prompt required" in response.json().get("detail", "")

def test_read_main(configured_app):
    client = TestClient(configured_app)
    response = client.get("/")
    server_msg, status_code = response.json()
    assert server_msg == "Server is running"
    assert status_code == 200

def test_health_check(configured_app):
    device = "GPU" if torch.cuda.is_available() else "CPU"
    if device != "GPU":
        pytest.skip("Skipping healthz endpoint check - running on CPU")
    client = TestClient(configured_app)
    response = client.get("/healthz")
    # Assuming we have a GPU available
    assert response.status_code == 200
    assert response.json() == {"status": "Healthy"}

def test_get_metrics(configured_app):
    client = TestClient(configured_app)
    response = client.get("/metrics")
    assert response.status_code == 200
    assert "gpu_info" in response.json()

def test_get_metrics_no_gpus(configured_app):
    client = TestClient(configured_app)
    with patch('GPUtil.getGPUs', return_value=[]) as mock_getGPUs:
        response = client.get("/metrics")
        assert response.status_code == 200
        assert response.json()["gpu_info"] == []

def test_default_generation_params(configured_app):
    if configured_app.test_config['pipeline'] != 'text-generation':
        pytest.skip("Skipping non-text-generation tests")

    client = TestClient(configured_app)

    request_data = {
        "prompt": "Test default params",
        "return_full_text": True,
        "clean_up_tokenization_spaces": False
        # Note: generate_kwargs is not provided, so defaults should be used
    }

    with patch('inference_api.pipeline') as mock_pipeline:
        mock_pipeline.return_value = [{"generated_text": "Mocked response"}]  # Mock the response of the pipeline function

        response = client.post("/chat", json=request_data)

        assert response.status_code == 200
        data = response.json()
        assert "Result" in data
        assert len(data["Result"]) > 0

        # Check the default args
        _, kwargs = mock_pipeline.call_args
        assert kwargs['max_length'] == 200
        assert kwargs['min_length'] == 0
        assert kwargs['do_sample'] is False
        assert kwargs['temperature'] == 1.0
        assert kwargs['top_k'] == 50
        assert kwargs['top_p'] == 1
        assert kwargs['typical_p'] == 1
        assert kwargs['repetition_penalty'] == 1
        assert kwargs['num_beams'] == 1
        assert kwargs['early_stopping'] is False
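
The test module above can be run on its own with pytest; a minimal sketch of an invocation is shown below (illustrative only, not part of this PR: it assumes pytest, torch, transformers, fastapi, uvicorn and GPUtil are installed, and the path mirrors the file location in this change).

# Minimal, illustrative test invocation from the repository root (not part of this PR).
import pytest

pytest.main(["presets/inference/text-generation/tests/test_inference_api.py", "-v"])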