-
Notifications
You must be signed in to change notification settings - Fork 537
/
dbrx.yaml
94 lines (83 loc) · 2.44 KB
/
dbrx.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Serving Databricks DBRX on your own infra.
#
# Usage:
#
# HF_TOKEN=xxx sky launch dbrx.yaml -c dbrx --env HF_TOKEN
#
# Query the API with curl (/v1/chat/completions):
#
#   IP=$(sky status --ip dbrx)
#   curl $IP:8081/v1/models
#   curl http://$IP:8081/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{
#       "model": "databricks/dbrx-instruct",
#       "messages": [
#         {
#           "role": "system",
#           "content": "You are a helpful assistant."
#         },
#         {
#           "role": "user",
#           "content": "Who are you?"
#         }
#       ]
#     }'
#
# Chat with the model through the Gradio UI. Look for these lines in the job
# output to find its URLs:
#
#   Running on local URL:  http://127.0.0.1:8811
#   Running on public URL: https://<hash>.gradio.live
# Environment variables for this task. MODEL_NAME is referenced by the
# readiness probe and by the `run` script.
envs:
  MODEL_NAME: databricks/dbrx-instruct
  # TODO: Fill with your own huggingface token, or use --env to pass.
  # Written as an explicit null so the "intentionally unset" state is visible.
  HF_TOKEN: null
# Serving configuration: replica count and how readiness is checked.
service:
  replicas: 2
  # An actual request for readiness probe: a one-message chat completion
  # POSTed to the API, capped at a single generated token.
  readiness_probe:
    path: /v1/chat/completions
    post_data:
      model: $MODEL_NAME
      messages:
        - role: user
          content: Hello! What is your name?
      max_tokens: 1
# Hardware requested for each node.
resources:
  # Candidate accelerator configurations, kept in the compact NAME:COUNT set
  # notation (no space after the colon, so each entry is a single scalar).
  accelerators: {A100-80GB:8, A100-80GB:4, A100:8, A100:16}
  cpus: 32+
  memory: 512+
  use_spot: true
  disk_size: 512  # Ensure model checkpoints (~246GB) can fit.
  disk_tier: best
  ports: 8081  # Expose to internet traffic.
# One-time environment preparation: conda env plus pinned Python packages.
setup: |
  # Reuse the `vllm` conda env if it already exists; otherwise create it.
  if ! conda activate vllm; then
    conda create -n vllm python=3.10 -y
    conda activate vllm
  fi

  # DBRX merged on master, 3/27/2024
  pip install git+https://github.com/vllm-project/vllm.git@e24336b5a772ab3aa6ad83527b880f9e5050ea2a
  pip install gradio tiktoken==0.6.0 openai
# Starts the vLLM OpenAI-compatible API server on port 8081, waits until it
# reports readiness, then starts the Gradio chat UI on port 8811.
run: |
  conda activate vllm
  echo 'Starting vllm api server...'

  # https://github.com/vllm-project/vllm/issues/3098
  export PATH=$PATH:/sbin

  # Start from an empty log so the readiness check below neither hits a
  # missing file nor matches a line left behind by an earlier run.
  : > api_server.log

  # NOTE: --gpu-memory-utilization 0.95 needed for 4-GPU nodes.
  python -u -m vllm.entrypoints.openai.api_server \
    --port 8081 \
    --model "$MODEL_NAME" \
    --trust-remote-code --tensor-parallel-size "$SKYPILOT_NUM_GPUS_PER_NODE" \
    --gpu-memory-utilization 0.95 \
    2>&1 | tee api_server.log &
  # PID of the last process in the background pipeline (tee); it exits once
  # the server's output stream is closed.
  server_pipeline_pid=$!

  # Poll the log until the server announces that it is listening.
  until grep -q 'Uvicorn running on' api_server.log; do
    # Stop waiting if the server pipeline has already terminated.
    if ! kill -0 "$server_pipeline_pid" 2>/dev/null; then
      echo 'vllm api server exited before becoming ready; see api_server.log' >&2
      exit 1
    fi
    echo 'Waiting for vllm api server to start...'
    sleep 5
  done

  echo 'Starting gradio server...'
  # `|| true` tolerates an existing checkout from a previous run.
  git clone https://github.com/vllm-project/vllm.git || true
  # Use the example script from the same vLLM commit that `setup` installs,
  # rather than whatever the default branch currently contains.
  git -C vllm checkout --quiet e24336b5a772ab3aa6ad83527b880f9e5050ea2a
  python vllm/examples/gradio_openai_chatbot_webserver.py \
    -m "$MODEL_NAME" \
    --port 8811 \
    --model-url http://localhost:8081/v1