feat: Debug Flag and tuning /metrics endpoint (#544)
**Reason for Change**:
Added a debug flag and a tuning /metrics endpoint.
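
For illustration only, a minimal client-side sketch of how the two additions might be exercised together: set `DEBUG_MODE=true` on the container to raise log verbosity, and query the new metrics server (which listens on port 5000 + LOCAL_RANK) for GPU or CPU utilization. The host name and response handling below are assumptions, not part of this change:

```python
# Hypothetical usage sketch (not part of this commit):
# - DEBUG_MODE=true in the container environment switches logging to DEBUG.
# - The metrics server listens on port 5000 + LOCAL_RANK inside the container.
import requests

resp = requests.get("http://localhost:5000/metrics", timeout=10)
resp.raise_for_status()
metrics = resp.json()

if metrics.get("gpu_info"):
    for gpu in metrics["gpu_info"]:
        print(f"GPU {gpu['id']} ({gpu['name']}): load={gpu['load']}, "
              f"memory={gpu['memory']['used']}/{gpu['memory']['total']}")
else:
    cpu = metrics["cpu_info"]
    print(f"CPU load: {cpu['load_percentage']}%, "
          f"memory={cpu['memory']['used']}/{cpu['memory']['total']}")
```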

---------

Signed-off-by: Ishaan Sehgal <[email protected]>
ishaansehgal99 authored Aug 21, 2024
1 parent 9d42673 commit 5f2f531
Showing 10 changed files with 245 additions and 19 deletions.
3 changes: 3 additions & 0 deletions docker/presets/models/tfs/Dockerfile
@@ -28,5 +28,8 @@ COPY kaito/presets/tuning/${MODEL_TYPE}/fine_tuning.py /workspace/tfs/fine_tunin
COPY kaito/presets/tuning/${MODEL_TYPE}/parser.py /workspace/tfs/parser.py
COPY kaito/presets/tuning/${MODEL_TYPE}/dataset.py /workspace/tfs/dataset.py

# Copy the metrics server
COPY kaito/presets/tuning/${MODEL_TYPE}/metrics/metrics_server.py /workspace/tfs/metrics_server.py

# Copy the entire model weights to the weights directory
COPY ${WEIGHTS_PATH} /workspace/tfs/weights
29 changes: 22 additions & 7 deletions presets/inference/text-generation/inference_api.py
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import subprocess
from dataclasses import asdict, dataclass, field
@@ -17,7 +18,13 @@
from transformers import (AutoModelForCausalLM, AutoTokenizer,
GenerationConfig, HfArgumentParser)

# Initialize logger
logger = logging.getLogger(__name__)
debug_mode = os.environ.get('DEBUG_MODE', 'false').lower() == 'true'
logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)

ADAPTERS_DIR = '/mnt/adapter'

@dataclass
class ModelConfig:
"""
@@ -128,13 +135,14 @@ def __post_init__(self): # validate parameters

active_adapters = model.active_adapters
if len(active_adapters) != 1 or active_adapters[0] != "combined_adapter":
raise ValueError(f"Adpaters is input but not merged correctlly")
print("Adapter added:", ', '.join(sorted(adapter_names)))
raise ValueError(f"Adapters not merged correctly")
logger.info("Adapter added: %s", ', '.join(sorted(adapter_names)))
else:
print("Warning: Did not find any valid adapters mounted, using base model")
logger.warning("Did not find any valid adapters mounted, using base model")
model = base_model

print("Model:", model)

logger.info("Model loaded successfully")
logger.info("Model: %s", model)

pipeline_kwargs = {
"trust_remote_code": args.trust_remote_code,
@@ -206,8 +214,10 @@ class HealthStatus(BaseModel):
)
def health_check():
if not model:
logger.error("Model not initialized")
raise HTTPException(status_code=500, detail="Model not initialized")
if not pipeline:
logger.error("Pipeline not initialized")
raise HTTPException(status_code=500, detail="Pipeline not initialized")
return {"status": "Healthy"}

@@ -351,6 +361,7 @@ def generate_text(

if args.pipeline == "text-generation":
if not request_model.prompt:
logger.error("Text generation parameter prompt required")
raise HTTPException(status_code=400, detail="Text generation parameter prompt required")
sequences = pipeline(
request_model.prompt,
@@ -365,13 +376,14 @@

result = ""
for seq in sequences:
print(f"Result: {seq['generated_text']}")
logger.debug(f"Result: {seq['generated_text']}")
result += seq['generated_text']

return {"Result": result}

elif args.pipeline == "conversational":
if not request_model.messages:
logger.error("Conversational parameter messages required")
raise HTTPException(status_code=400, detail="Conversational parameter messages required")

response = pipeline(
@@ -382,6 +394,7 @@
return {"Result": str(response[-1])}

else:
logger.error("Invalid pipeline type")
raise HTTPException(status_code=400, detail="Invalid pipeline type")

class MemoryInfo(BaseModel):
@@ -476,9 +489,11 @@ def get_metrics():
)
return MetricsResponse(cpu_info=cpu_info)
except Exception as e:
logger.error(f"Error fetching metrics: {e}")
raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
local_rank = int(os.environ.get("LOCAL_RANK", 0)) # Default to 0 if not set
port = 5000 + local_rank # Adjust port based on local rank
uvicorn.run(app=app, host='0.0.0.0', port=port)
logger.info(f"Starting server on port {port}")
uvicorn.run(app=app, host='0.0.0.0', port=port)
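
Aside: the DEBUG_MODE handling added above is a standard root-logger toggle. A standalone sketch of the same pattern, separate from the file itself (the message strings are illustrative):

```python
# Standalone sketch of the DEBUG_MODE pattern: the environment variable chooses the
# root logging level once at startup; logger.debug output is suppressed at INFO level.
import logging
import os

debug_mode = os.environ.get("DEBUG_MODE", "false").lower() == "true"
logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)

logger = logging.getLogger(__name__)
logger.debug("Only emitted when DEBUG_MODE=true")
logger.info("Emitted at both INFO and DEBUG levels")
```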
2 changes: 1 addition & 1 deletion presets/models/falcon/model.go
@@ -44,7 +44,7 @@ var (
"Falcon40BInstruct": "0.0.6",
}

baseCommandPresetFalcon = "accelerate launch"
baseCommandPresetFalcon = "python3 metrics_server.py & accelerate launch"
falconRunParams = map[string]string{
"torch_dtype": "bfloat16",
"pipeline": "text-generation",
2 changes: 1 addition & 1 deletion presets/models/mistral/model.go
@@ -31,7 +31,7 @@ var (
"Mistral7BInstruct": "0.0.6",
}

baseCommandPresetMistral = "accelerate launch"
baseCommandPresetMistral = "python3 metrics_server.py & accelerate launch"
mistralRunParams = map[string]string{
"torch_dtype": "bfloat16",
"pipeline": "text-generation",
2 changes: 1 addition & 1 deletion presets/models/phi2/model.go
@@ -25,7 +25,7 @@ var (
"Phi2": "0.0.4",
}

baseCommandPresetPhi = "accelerate launch"
baseCommandPresetPhi = "python3 metrics_server.py & accelerate launch"
phiRunParams = map[string]string{
"torch_dtype": "float16",
"pipeline": "text-generation",
2 changes: 1 addition & 1 deletion presets/models/phi3/model.go
@@ -43,7 +43,7 @@ var (
"Phi3Medium128kInstruct": "0.0.1",
}

baseCommandPresetPhi = "accelerate launch"
baseCommandPresetPhi = "python3 metrics_server.py & accelerate launch"
phiRunParams = map[string]string{
"torch_dtype": "auto",
"pipeline": "text-generation",
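
The preset command changes above prepend `python3 metrics_server.py &` so the metrics server is started in the background of the same container before `accelerate launch` runs in the foreground. A rough, purely illustrative Python equivalent of that shell behavior (the presets execute the shell string as-is; the launch arguments below are assumptions):

```python
# Illustrative only: roughly what "python3 metrics_server.py & accelerate launch ..."
# does at the process level. The metrics server is started without waiting for it,
# then the main training/inference command runs in the foreground.
import subprocess

metrics_proc = subprocess.Popen(["python3", "metrics_server.py"])  # backgrounded, like "&"
subprocess.run(["accelerate", "launch", "fine_tuning.py"], check=True)  # args are hypothetical
```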
20 changes: 13 additions & 7 deletions presets/tuning/text-generation/fine_tuning.py
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import sys
from dataclasses import asdict
@@ -17,8 +18,12 @@
TrainingArguments)
from trl import SFTTrainer

CONFIG_YAML = os.environ.get('YAML_FILE_PATH', '/mnt/config/training_config.yaml')
# Initialize logger
logger = logging.getLogger(__name__)
debug_mode = os.environ.get('DEBUG_MODE', 'false').lower() == 'true'
logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)

CONFIG_YAML = os.environ.get('YAML_FILE_PATH', '/mnt/config/training_config.yaml')
parsed_configs = parse_configs(CONFIG_YAML)

model_config = parsed_configs.get('ModelConfig')
@@ -33,7 +38,7 @@
# Load Model Args
model_args = asdict(model_config)
if accelerator.distributed_type != "NO": # Meaning we require distributed training
print("Setting device map for distributed training")
logger.debug("Setting device map for distributed training")
model_args["device_map"] = {"": accelerator.process_index}

# Load BitsAndBytesConfig
@@ -47,7 +52,7 @@
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
if dc_args.mlm and tokenizer.mask_token is None:
print(
logger.warning(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead. "
"Setting mlm=False"
@@ -61,14 +66,15 @@
quantization_config=bnb_config if enable_qlora else None,
)

print("Model Loaded")
logger.info("Model Loaded")

if enable_qlora:
# Preparing the Model for QLoRA
model = prepare_model_for_kbit_training(model)
print("QLoRA Enabled")
logger.info("QLoRA Enabled")

if not ext_lora_config:
logger.error("LoraConfig must be specified")
raise ValueError("LoraConfig must be specified")

lora_config_args = asdict(ext_lora_config)
@@ -83,7 +89,7 @@
# Load the dataset
dm.load_data()
if not dm.get_dataset():
print("Failed to load dataset.")
logger.error("Failed to load dataset.")
raise ValueError("Unable to load the dataset.")

# Shuffling the dataset (if needed)
@@ -119,7 +125,7 @@ def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwar

# Write file to signify training completion
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
print("Fine-Tuning completed\n")
logger.info("Fine-Tuning completed\n")
completion_indicator_path = os.path.join(ta_args.output_dir, "fine_tuning_completed.txt")
with open(completion_indicator_path, 'w') as f:
f.write(f"Fine-Tuning completed at {timestamp}\n")
122 changes: 122 additions & 0 deletions presets/tuning/text-generation/metrics/metrics_server.py
@@ -0,0 +1,122 @@
# metrics_server.py
import logging
import os
from typing import List, Optional

import GPUtil
import psutil
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Initialize logger
logger = logging.getLogger(__name__)
debug_mode = os.environ.get('DEBUG_MODE', 'false').lower() == 'true'
logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)

app = FastAPI()

class ErrorResponse(BaseModel):
detail: str

class MemoryInfo(BaseModel):
used: str
total: str

class CPUInfo(BaseModel):
load_percentage: float
physical_cores: int
total_cores: int
memory: MemoryInfo

class GPUInfo(BaseModel):
id: str
name: str
load: str
temperature: str
memory: MemoryInfo

class MetricsResponse(BaseModel):
gpu_info: Optional[List[GPUInfo]] = None
cpu_info: Optional[CPUInfo] = None

@app.get(
"/metrics",
response_model=MetricsResponse,
summary="Metrics Endpoint",
responses={
200: {
"description": "Successful Response",
"content": {
"application/json": {
"examples": {
"gpu_metrics": {
"summary": "Example when GPUs are available",
"value": {
"gpu_info": [{"id": "GPU-1234", "name": "GeForce GTX 950", "load": "25.00%", "temperature": "55 C", "memory": {"used": "1.00 GB", "total": "2.00 GB"}}],
"cpu_info": None # Indicates CPUs info might not be present when GPUs are available
}
},
"cpu_metrics": {
"summary": "Example when only CPU is available",
"value": {
"gpu_info": None, # Indicates GPU info might not be present when only CPU is available
"cpu_info": {"load_percentage": 20.0, "physical_cores": 4, "total_cores": 8, "memory": {"used": "4.00 GB", "total": "16.00 GB"}}
}
}
}
}
}
},
500: {
"description": "Internal Server Error",
"model": ErrorResponse,
}
}
)
def get_metrics():
"""
Provides system metrics, including GPU details if available, or CPU and memory usage otherwise.
Useful for monitoring the resource utilization of the server running the ML models.
"""
try:
if torch.cuda.is_available():
gpus = GPUtil.getGPUs()
gpu_info = [GPUInfo(
id=str(gpu.id),
name=gpu.name,
load=f"{gpu.load * 100:.2f}%",
temperature=f"{gpu.temperature} C",
memory=MemoryInfo(
used=f"{gpu.memoryUsed / (1024 ** 3):.2f} GB",
total=f"{gpu.memoryTotal / (1024 ** 3):.2f} GB"
)
) for gpu in gpus]
return MetricsResponse(gpu_info=gpu_info)
else:
# Gather CPU metrics
cpu_usage = psutil.cpu_percent(interval=1, percpu=False)
physical_cores = psutil.cpu_count(logical=False)
total_cores = psutil.cpu_count(logical=True)
virtual_memory = psutil.virtual_memory()
memory = MemoryInfo(
used=f"{virtual_memory.used / (1024 ** 3):.2f} GB",
total=f"{virtual_memory.total / (1024 ** 3):.2f} GB"
)
cpu_info = CPUInfo(
load_percentage=cpu_usage,
physical_cores=physical_cores,
total_cores=total_cores,
memory=memory
)
return MetricsResponse(cpu_info=cpu_info)
except Exception as e:
logger.error(f"Error fetching metrics: {e}")
raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
local_rank = int(os.environ.get("LOCAL_RANK", 0)) # Default to 0 if not set
port = 5000 + local_rank # Adjust port based on local rank
logger.info(f"Starting server on port {port}")
uvicorn.run(app=app, host='0.0.0.0', port=port)