-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added notebook for uploading vllm eval results to wandb post-hoc
- Loading branch information
1 parent
7b3cded
commit 9dc9398
Showing
1 changed file
with
399 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,399 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import pandas as pd\n", | ||
"import wandb" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setup\n", | ||
"\n", | ||
"### Read Results CSV Files\n", | ||
"\n", | ||
"This assumes that you already have at least 1 result csv file for each step. We would recommend using `run_checkpoints.sh` and `run_checkpoints_cot.sh` to generate the result csv files, where it would automatically save the results in the following format:\n", | ||
"```\n", | ||
"{run_name}/c{checkpoint_number}_api_{benchmark}_{if cot}.csv\n", | ||
"```\n", | ||
"\n", | ||
"Update the `result_folder` variable to point to the folder that contains the csv files. This will import all of the csv files in there.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Found 32 csv files in results/sqlcoder_8b_fullft_ds_013_llama3_mgn1_b1_0900_b2_0990_steps_1000\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Step 1. Specify the folder where the result csv files are stored\n", | ||
"result_folder = (\n", | ||
" \"results/sqlcoder_8b_fullft_ds_013_llama3_mgn1_b1_0900_b2_0990_steps_1000\"\n", | ||
")\n", | ||
"csv_files = []\n", | ||
"for f in os.listdir(result_folder):\n", | ||
" if f.endswith(\".csv\"):\n", | ||
" csv_files.append(f)\n", | ||
"print(f\"Found {len(csv_files)} csv files in {result_folder}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Step 2. Specify the wandb run id\n", | ||
"# We don't do a lookup via the wandb API because different runs may have the same run name\n", | ||
"wandb_run_id = \"qcbad5rx\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loaded 3272 results from 32 csv files\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Load results from csv file into dataframe\n", | ||
"results_dfs = []\n", | ||
"for csv_file_name in csv_files:\n", | ||
" file_path = os.path.join(result_folder, csv_file_name)\n", | ||
" df_i = pd.read_csv(file_path, comment=\"#\")\n", | ||
" df_i[\"model\"] = csv_file_name.rsplit(\".csv\", 1)[0]\n", | ||
" results_dfs.append(df_i)\n", | ||
"results_df = pd.concat(results_dfs, ignore_index=True)\n", | ||
"print(f\"Loaded {results_df.shape[0]} results from {len(csv_files)} csv files\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"s = results_df.groupby(\"model\")[\"correct\"].mean()\n", | ||
"s = pd.DataFrame(s)\n", | ||
"s[\"file_name\"] = s.index\n", | ||
"s[\"benchmark\"] = s[\"file_name\"].str.extract(r\"_(advanced|basic|v1|idk)\")\n", | ||
"s[\"checkpoint\"] = s[\"file_name\"].str.extract(r\"c(\\d+)_\").astype(int)\n", | ||
"s[\"cot\"] = s[\"file_name\"].str.extract(r\"_(cot)\").fillna(\"no_cot\")\n", | ||
"s = s.reset_index(drop=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Found 4 checkpoints: [ 400 600 800 1000]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Get unique checkpoints\n", | ||
"checkpoints = s[\"checkpoint\"].unique()\n", | ||
"checkpoints.sort()\n", | ||
"print(f\"Found {len(checkpoints)} checkpoints: {checkpoints}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", | ||
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mwongjingping\u001b[0m (\u001b[33mdefog\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "b755e8f801154731939fd91ddd19cbe3", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011166706489812996, max=1.0…" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"Tracking run with wandb version 0.17.1" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"Run data is saved locally in <code>/Users/jp/workspace/sql-eval/wandb/run-20240613_122650-qcbad5rx</code>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"Resuming run <strong><a href='https://wandb.ai/defog/huggingface/runs/qcbad5rx' target=\"_blank\">sqlcoder_8b_fullft_ds_013_llama3_mgn1_b1_0900_b2_0990_steps_1000</a></strong> to <a href='https://wandb.ai/defog/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
" View project at <a href='https://wandb.ai/defog/huggingface' target=\"_blank\">https://wandb.ai/defog/huggingface</a>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
" View run at <a href='https://wandb.ai/defog/huggingface/runs/qcbad5rx' target=\"_blank\">https://wandb.ai/defog/huggingface/runs/qcbad5rx</a>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"# Continue existing run, specifying the project and the run ID\n", | ||
"run = wandb.init(project=\"huggingface\", id=wandb_run_id, resume=\"must\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Current step: 1001\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# get current step, so that we can log incrementally after it\n", | ||
"# this is because wandb doesn't allow logging back to previous steps\n", | ||
"current_step = run.step\n", | ||
"print(f\"Current step: {current_step}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Logging checkpoint 400 metrics:\n", | ||
"\tvllm/advanced_cot: 0.75\n", | ||
"\tvllm/advanced: 0.78125\n", | ||
"\tvllm/basic_cot: 0.9\n", | ||
"\tvllm/basic: 0.825\n", | ||
"\tvllm/v1_cot: 0.875\n", | ||
"\tvllm/v1: 0.865\n", | ||
"\tvllm/idk_cot: 0.9238095238095239\n", | ||
"\tvllm/idk: 0.8476190476190476\n", | ||
"Logging checkpoint 600 metrics:\n", | ||
"\tvllm/advanced_cot: 0.703125\n", | ||
"\tvllm/advanced: 0.765625\n", | ||
"\tvllm/basic_cot: 0.9\n", | ||
"\tvllm/basic: 0.85\n", | ||
"\tvllm/v1_cot: 0.85\n", | ||
"\tvllm/v1: 0.84\n", | ||
"\tvllm/idk_cot: 0.9523809523809523\n", | ||
"\tvllm/idk: 0.8952380952380953\n", | ||
"Logging checkpoint 800 metrics:\n", | ||
"\tvllm/advanced_cot: 0.765625\n", | ||
"\tvllm/advanced: 0.765625\n", | ||
"\tvllm/basic_cot: 0.925\n", | ||
"\tvllm/basic: 0.9\n", | ||
"\tvllm/v1_cot: 0.86\n", | ||
"\tvllm/v1: 0.845\n", | ||
"\tvllm/idk_cot: 0.9523809523809523\n", | ||
"\tvllm/idk: 0.8761904761904762\n", | ||
"Logging checkpoint 1000 metrics:\n", | ||
"\tvllm/advanced_cot: 0.78125\n", | ||
"\tvllm/advanced: 0.78125\n", | ||
"\tvllm/basic_cot: 0.925\n", | ||
"\tvllm/basic: 0.9\n", | ||
"\tvllm/v1_cot: 0.865\n", | ||
"\tvllm/v1: 0.845\n", | ||
"\tvllm/idk_cot: 0.9523809523809523\n", | ||
"\tvllm/idk: 0.8761904761904762\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for checkpoint in checkpoints:\n", | ||
" checkpoint_metrics = {}\n", | ||
" for benchmark in [\"advanced\", \"basic\", \"v1\", \"idk\"]:\n", | ||
" for cot in [\"cot\", \"no_cot\"]:\n", | ||
" mask = (\n", | ||
" (s[\"checkpoint\"] == checkpoint)\n", | ||
" & (s[\"benchmark\"] == benchmark)\n", | ||
" & (s[\"cot\"] == cot)\n", | ||
" )\n", | ||
" if mask.sum() == 1:\n", | ||
" row = s[mask]\n", | ||
" metric_name = f\"vllm/{benchmark}\"\n", | ||
" if cot == \"cot\":\n", | ||
" metric_name += \"_cot\"\n", | ||
" metric_value = row[\"correct\"].values[0]\n", | ||
" checkpoint_metrics[metric_name] = metric_value\n", | ||
" print(f\"Logging checkpoint {checkpoint} metrics:\")\n", | ||
" for k, v in checkpoint_metrics.items():\n", | ||
" print(f\"\\t{k}: {v}\")\n", | ||
" # we log the metrics at the current step + checkpoint\n", | ||
" wandb.log(checkpoint_metrics, step=current_step + checkpoint)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "49370a7af50845f4b0b54eebe0500e05", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<style>\n", | ||
" table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n", | ||
" .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n", | ||
" .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n", | ||
" </style>\n", | ||
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>vllm/advanced</td><td>█▁▁█</td></tr><tr><td>vllm/advanced_cot</td><td>▅▁▇█</td></tr><tr><td>vllm/basic</td><td>▁▃██</td></tr><tr><td>vllm/basic_cot</td><td>▁▁██</td></tr><tr><td>vllm/idk</td><td>▁█▅▅</td></tr><tr><td>vllm/idk_cot</td><td>▁███</td></tr><tr><td>vllm/v1</td><td>█▁▂▂</td></tr><tr><td>vllm/v1_cot</td><td>█▁▄▅</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>advanced</td><td>0.45312</td></tr><tr><td>basic</td><td>0.95</td></tr><tr><td>basic_group_order_limit</td><td>1</td></tr><tr><td>basic_join_date_group_order_limit</td><td>0.875</td></tr><tr><td>basic_join_distinct</td><td>1</td></tr><tr><td>basic_join_group_order_limit</td><td>0.875</td></tr><tr><td>basic_left_join</td><td>1</td></tr><tr><td>cat_a</td><td>0</td></tr><tr><td>cat_b</td><td>0</td></tr><tr><td>cat_c</td><td>0</td></tr><tr><td>date_functions</td><td>0.84</td></tr><tr><td>eval/count_mismatch_i_diff_avg</td><td>5.375</td></tr><tr><td>eval/first_index_mismatch_avg</td><td>11.20833</td></tr><tr><td>eval/loss</td><td>0.1782</td></tr><tr><td>eval/mean_mismatch_i_diff_avg</td><td>15.33588</td></tr><tr><td>eval/runtime</td><td>15.7635</td></tr><tr><td>eval/samples_per_second</td><td>1.523</td></tr><tr><td>eval/sql_exact_match_string</td><td>3</td></tr><tr><td>eval/steps_per_second</td><td>0.127</td></tr><tr><td>eval/tokens_match_avg</td><td>0.94784</td></tr><tr><td>group_by</td><td>0.97143</td></tr><tr><td>idk</td><td>0</td></tr><tr><td>instruct</td><td>0.8</td></tr><tr><td>instructions_cte_join</td><td>0.75</td></tr><tr><td>instructions_cte_window</td><td>0</td></tr><tr><td>instructions_date_join</td><td>0.375</td></tr><tr><td>instructions_string_matching</td><td>0.75</td></tr><tr><td>keywords_aggregate</td><td>0.625</td></tr><tr><td>keywords_ratio</td><td>0</td></tr><tr><td>order_by</td><td>0.85714</td></tr><tr><td>overall</td><td>0.5868</td></tr><tr><td>ratio</td><td>0.85714</td></tr><tr><td>table_join</td><td>0.85714</td></tr><tr><td>total_flos</td><td>9.784632598246196e+17</td></tr><tr><td>train/epoch</td><td>1</td></tr><tr><td>train/global_step</td><td>1000</td></tr><tr><td>train/grad_norm</td><td>4</td></tr><tr><td>train/learning_rate</td><td>0.0</td></tr><tr><td>train/loss</td><td>0.1368</td></tr><tr><td>train_loss</td><td>0.15225</td></tr><tr><td>train_runtime</td><td>8061.5451</td></tr><tr><td>train_samples_per_second</td><td>2.977</td></tr><tr><td>train_steps_per_second</td><td>0.124</td></tr><tr><td>v1</td><td>0.865</td></tr><tr><td>vllm/advanced</td><td>0.78125</td></tr><tr><td>vllm/advanced_cot</td><td>0.78125</td></tr><tr><td>vllm/basic</td><td>0.9</td></tr><tr><td>vllm/basic_cot</td><td>0.925</td></tr><tr><td>vllm/idk</td><td>0.87619</td></tr><tr><td>vllm/idk_cot</td><td>0.95238</td></tr><tr><td>vllm/v1</td><td>0.845</td></tr><tr><td>vllm/v1_cot</td><td>0.865</td></tr></table><br/></div></div>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
" View run <strong style=\"color:#cdcd00\">sqlcoder_8b_fullft_ds_013_llama3_mgn1_b1_0900_b2_0990_steps_1000</strong> at: <a href='https://wandb.ai/defog/huggingface/runs/qcbad5rx' target=\"_blank\">https://wandb.ai/defog/huggingface/runs/qcbad5rx</a><br/> View project at: <a href='https://wandb.ai/defog/huggingface' target=\"_blank\">https://wandb.ai/defog/huggingface</a><br/>Synced 4 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"Find logs at: <code>./wandb/run-20240613_122650-qcbad5rx/logs</code>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"# Finish the run\n", | ||
"run.finish()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "base", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |