Skip to content

Commit

Permalink
Create local model directory as needed, only prompt for token if pulling model artifacts.
Browse files: browse the repository at this point in the history
  • Loading branch information
Kyle Lilly committed May 24, 2024
1 parent 384de66 commit 02881d9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
10 changes: 9 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ modelCheck:
[ $$? != 0 ]; \
then \
localModelDir="./models"; \
if \
[ ! -d "$$localModelDir" ]; \
then \
mkdir -p "$$localModelDir"; \
fi; \
echo; \
echo "Preparing to download, convert, and upload safetensors for model: $(MODEL_ID)"; \
echo "Local directory: '$$localModelDir' will be used to store downloaded and converted model weights"; \
Expand All @@ -171,10 +176,13 @@ modelCheck:
if \
[ $${confirm_download:-'N'} = 'y' ]; \
then \
mkdir -p $$localModelDir; \
echo "What is your huggingface access token? "; \
read -s access_token; \
echo "Converting and uploading safetensors for model: $(MODEL_ID)"; \
tgiImage=$$(yq '[.$(ENV).ecsModels[] | select(.inferenceContainer == "tgi") | .containerConfig.image.baseImage] | first' $(PROJECT_DIR)/config.yaml | sed "s/^['\"]//;s/['\"]$$//"); \
echo $$tgiImage; \
$(PROJECT_DIR)/scripts/convert-and-upload-model.sh -m $(MODEL_ID) -s $(MODEL_BUCKET) -t $$tgiImage -d $$localModelDir; \
$(PROJECT_DIR)/scripts/convert-and-upload-model.sh -m $(MODEL_ID) -s $(MODEL_BUCKET) -a $$access_token -t $$tgiImage -d $$localModelDir; \
fi; \
fi; \
)
Expand Down
9 changes: 5 additions & 4 deletions scripts/convert-and-upload-model.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

set -e

ACCESS_TOKEN=""
TGI_IMAGE=""
MODEL_DIR=""
MODEL_ID=""
Expand All @@ -11,6 +12,7 @@ OUTPUT_ID=""
usage(){
cat << EOF >&2
Usage: $0
[ -a | --access-token - huggingface access token for accessing restricted models
[ -t | --tgi-image - docker image and tag for TGI
[ -d | --output-dir - local directory to use for model storage
[ -m | --model-id - the huggingface model-id or path to local model dir
Expand All @@ -23,6 +25,8 @@ while true; do
case "$1" in
-t | --tgi-container )
TGI_IMAGE="$2"; shift 2 ;;
-a | --access-token )
ACCESS_TOKEN="$2"; shift 2 ;;
-d | --model-dir )
MODEL_DIR="$2"; shift 2 ;;
-m | --model-id )
Expand All @@ -41,9 +45,6 @@ done
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ABS_MODEL_DIR=`realpath ${MODEL_DIR}`

echo What is your huggingface access token?
read -s AUTH_TOKEN

docker run \
--rm \
--entrypoint python \
Expand All @@ -54,7 +55,7 @@ docker run \
/code/convert-to-safetensors.py \
--output-dir /model \
--model-id $MODEL_ID \
--auth-token $AUTH_TOKEN
--access-token $ACCESS_TOKEN

if [ -n "${S3_BUCKET}" ]; then
echo "Uploading model to ${S3_BUCKET}/${MODEL_ID}"
Expand Down
10 changes: 5 additions & 5 deletions scripts/convert-to-safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
----------
--model-id: str = None
huggingface model id or local model dir
--auth-token: str = None
huggingface auth token for restricted models
--access-token: str = None
huggingface access token for restricted models
--output-dir: str = None
directory to output model to
--max-shard-size: str = "2GB"
Expand All @@ -34,7 +34,7 @@

parser = argparse.ArgumentParser()
parser.add_argument("--model-id", help="huggingface model ID or local model dir", default=None)
parser.add_argument("--auth-token", help="huggingface authtoken for restricted models", default=None)
parser.add_argument("--access-token", help="huggingface access token for restricted models", default=None)
parser.add_argument("--output-dir", help="local directory to output model to", default=None)
parser.add_argument("--max-shard-size", help="maximum size of safetensor shard", default="2GB")
args = parser.parse_args()
Expand All @@ -44,8 +44,8 @@
if not model_dir_fpath.exists():
from huggingface_hub import login, snapshot_download

if args.auth_token:
login(token=args.auth_token)
if args.access_token:
login(token=args.access_token)
snapshot_download(
args.model_id,
local_dir=model_dir_fpath,
Expand Down

0 comments on commit 02881d9

Please sign in to comment.