Skip to content

Commit

Permalink
download best lora checkpoint for run and run with custom checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
mitya52 committed Oct 26, 2023
1 parent 033a3d4 commit f1dcc49
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 16 deletions.
5 changes: 4 additions & 1 deletion self_hosting_machinery/webgui/static/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,10 @@ h3 {
opacity: 1 !important;
pointer-events: auto;
}
/* Fixed 30px width for the first two cells of every checkpoints-table row. */
.table-checkpoints tr td:nth-of-type(1),
.table-checkpoints tr td:nth-of-type(2) {
    width: 30px;
}
.use-model-pane {
Expand Down
24 changes: 22 additions & 2 deletions self_hosting_machinery/webgui/static/tab-finetune.js
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,18 @@ function render_runs() {
<i class="bi bi-play-fill"></i>
</button>`;
run_download.innerHTML = `
<a href="/lora-download/${run.run_id}.zip" download class="btn btn-hover btn-primary btn-sm" ${item_disabled}>
<a href="/lora-download?run_id=${run.run_id}"
download class="btn btn-hover btn-primary btn-sm" ${item_disabled}>
<i class="bi bi-download"></i>
</a>`;
if (!run_is_working) {
run_active.addEventListener('click', (event) => {
event.stopPropagation();
finetune_activate_run(run_table_row.dataset.run);
});
run_download.addEventListener('click', (event) => {
event.stopPropagation();
});
}
} else {
run_active.innerHTML = ``;
Expand Down Expand Up @@ -384,9 +388,22 @@ function render_checkpoints(data = []) {
row.classList.add('table-success');
}
const activate_cell = document.createElement('td');
activate_cell.innerHTML = `<button class="btn btn-hover btn-primary btn-sm"><i class="bi bi-play-fill"></i></button>`;
const download_cell = document.createElement('td');

activate_cell.innerHTML = `
<button class="btn btn-hover btn-primary btn-sm">
<i class="bi bi-play-fill"></i>
</button>`;
download_cell.innerHTML = `
<a href="/lora-download?run_id=${selected_lora}&checkpoint_id=${element.checkpoint_name}"
download class="btn btn-hover btn-primary btn-sm">
<i class="bi bi-download"></i>
</a>`;

row.appendChild(activate_cell);
row.appendChild(download_cell);
row.appendChild(cell);

checkpoints.appendChild(row);
activate_cell.addEventListener('click', (event) => {
if(!row.classList.contains('table-success')) {
Expand All @@ -398,6 +415,9 @@ function render_checkpoints(data = []) {
}
finetune_activate_run(selected_lora, cell.dataset.checkpoint);
});
activate_cell.addEventListener('click', (event) => {
event.stopPropagation();
});
});
}
}
Expand Down
61 changes: 48 additions & 13 deletions self_hosting_machinery/webgui/tab_loras.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import asyncio
import uuid

import aiofiles
import os
import subprocess

from typing import Union
from pathlib import Path

from fastapi import APIRouter, UploadFile, HTTPException
from fastapi import APIRouter, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import Required

from self_hosting_machinery import env
from self_hosting_machinery.webgui.selfhost_webutils import log
from self_hosting_machinery.webgui.tab_upload import download_file_from_url, UploadViaURL
from self_hosting_machinery.scripts.best_lora import find_best_checkpoint


def rm(f):
Expand Down Expand Up @@ -59,7 +63,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_api_route("/lora-upload", self._upload_lora, methods=["POST"])
self.add_api_route("/lora-upload-url", self._upload_lora_url, methods=["POST"])
self.add_api_route("/lora-download/{run_id}.zip", self._download_lora, methods=["GET"])
self.add_api_route("/lora-download", self._download_lora, methods=["GET"])

async def _upload_lora(self, file: UploadFile):
async def write_to_file() -> JSONResponse:
Expand Down Expand Up @@ -107,32 +111,63 @@ async def _upload_lora_url(self, file: UploadViaURL):
return resp
return JSONResponse("OK", status_code=200)

async def _download_lora(self,
                         run_id: str = Query(default=Required),
                         checkpoint_id: str = Query(default="")):
    """Stream a zip archive of one finetune run holding a single checkpoint.

    Args:
        run_id: Name of the run directory inside ``env.DIR_LORAS``.
        checkpoint_id: Name of the checkpoint to keep in the archive. When
            empty, the id reported by ``find_best_checkpoint(run_id)`` is
            used and the download filename carries no checkpoint suffix.

    Returns:
        A ``StreamingResponse`` that yields the zip file in 1 MiB chunks.

    Raises:
        HTTPException: 400 if either identifier is not a plain directory
            name; 404 if the run or the requested checkpoint is missing.
    """

    def _is_plain_name(name: str) -> bool:
        # Both ids come from the query string and are joined into paths that
        # are handed to `cp` and `zip`: refuse anything that could leave
        # DIR_LORAS or be parsed as a command-line option.
        if not name or name in (".", "..") or name.startswith("-"):
            return False
        return not any(ch in name for ch in "/\\\0")

    async def _run_checked(failure_msg: str, *cmd: str, cwd=None):
        # Run an external command and turn a non-zero exit into RuntimeError.
        process = await asyncio.create_subprocess_exec(*cmd, cwd=cwd)
        await process.wait()
        if process.returncode != 0:
            raise RuntimeError(failure_msg)

    async def _archived_content(run_id: str, checkpoint_id: str):
        # Work in a unique scratch directory so concurrent downloads of the
        # same run cannot clobber each other's files.
        tempdir = Path(env.TMPDIR) / f"lora-download-{uuid.uuid4()}"
        copy_run_dirname = tempdir / run_id
        zipped_run_filename = tempdir / f"{run_id}.zip"
        try:
            tempdir.mkdir(parents=False, exist_ok=False)

            # copy run to the scratch directory
            await _run_checked(
                "run copying failed",
                "cp", "-r", str(Path(env.DIR_LORAS) / run_id), str(copy_run_dirname))

            # remove every checkpoint except the requested one; snapshot the
            # listing first because entries are deleted while looping
            for checkpoint_dir in list((copy_run_dirname / "checkpoints").iterdir()):
                if checkpoint_dir.name != checkpoint_id:
                    rm(str(checkpoint_dir))

            # zip prepared run (relative to tempdir so the archive root is run_id)
            await _run_checked(
                "archive creation failed",
                "zip", "-r", str(zipped_run_filename), run_id,
                cwd=str(tempdir))

            async with aiofiles.open(zipped_run_filename, "rb") as f:
                while contents := await f.read(1024 * 1024):
                    yield contents

        except Exception as e:
            # The response has already started by the time this generator
            # runs, so an HTTPException could no longer change the status
            # code: log and re-raise the real error instead. GeneratorExit /
            # CancelledError (client went away) are deliberately not caught.
            log("Error while downloading: %s" % (e or str(type(e))))
            raise

        finally:
            # Single cleanup point for success, failure and cancellation.
            # NOTE(review): relies on rm() removing directories recursively — confirm.
            rm(str(tempdir))

    if not _is_plain_name(run_id) or (checkpoint_id and not _is_plain_name(checkpoint_id)):
        raise HTTPException(detail="Invalid run_id or checkpoint_id", status_code=400)

    # Validate up front: once streaming begins, errors cannot be reported
    # to the client as an HTTP status any more.
    run_dir = Path(env.DIR_LORAS) / run_id
    if not run_dir.is_dir():
        raise HTTPException(detail=f"Run {run_id} not found", status_code=404)

    download_filename = run_id + (f"-{checkpoint_id}" if checkpoint_id else "") + ".zip"
    if not checkpoint_id:
        checkpoint_id = find_best_checkpoint(run_id)["best_checkpoint_id"]
    if not (run_dir / "checkpoints" / checkpoint_id).exists():
        raise HTTPException(
            detail=f"Checkpoint {checkpoint_id} not found in run {run_id}",
            status_code=404)

    return StreamingResponse(
        _archived_content(run_id, checkpoint_id),
        media_type="application/x-zip-compressed",
        headers={
            "Content-Disposition": f'attachment; filename="{download_filename}"',
        })

0 comments on commit f1dcc49

Please sign in to comment.