feat: bump accelerate to 1.0.0 (#739)
**Reason for Change**:

feat: bump accelerate to 1.0.0

- add chat_template support to the tuning code

Signed-off-by: jerryzhuang <[email protected]>
zhuangqh authored Dec 3, 2024
1 parent 1b440d5 commit 5269bd7
Showing 7 changed files with 55 additions and 22 deletions.
12 changes: 9 additions & 3 deletions .github/workflows/preset-image-build.yml
@@ -23,7 +23,10 @@ on:
         type: boolean
         default: false
         description: "Run all models for build"
-
+      force-run-all-public:
+        type: boolean
+        default: false
+        description: "Run all public models for build"
 env:
   GO_VERSION: "1.22"
   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
@@ -49,14 +52,17 @@ jobs:
 
       - name: Set FORCE_RUN_ALL Flag
         id: set_force_run_all
-        run: echo "FORCE_RUN_ALL=${{ github.event_name == 'workflow_dispatch' && github.event.inputs.force-run-all == 'true' }}" >> $GITHUB_OUTPUT
-
+        run: |
+          echo "FORCE_RUN_ALL=${{ github.event_name == 'workflow_dispatch' && github.event.inputs.force-run-all == 'true' }}" >> $GITHUB_OUTPUT
+          echo "FORCE_RUN_ALL_PUBLIC=${{ github.event_name == 'workflow_dispatch' && github.event.inputs.force-run-all-public == 'true' }}" >> $GITHUB_OUTPUT
       # This script should output a JSON array of model names
       - name: Determine Affected Models
         id: affected_models
         run: |
           PR_BRANCH=${{ env.BRANCH_NAME }} \
           FORCE_RUN_ALL=${{ steps.set_force_run_all.outputs.FORCE_RUN_ALL }} \
+          FORCE_RUN_ALL_PUBLIC=${{ steps.set_force_run_all.outputs.FORCE_RUN_ALL_PUBLIC }} \
           python3 .github/workflows/kind-cluster/determine_models.py
       - name: Print Determined Models
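The two dispatch inputs reach determine_models.py as environment variables. As a sketch of how the new FORCE_RUN_ALL_PUBLIC flag might be consumed there (determine_models.py is not part of this diff, so the model metadata shape below is an assumption):

```python
import json
import os

def select_models(all_models: list[dict]) -> list[str]:
    force_all = os.environ.get("FORCE_RUN_ALL", "false") == "true"
    force_all_public = os.environ.get("FORCE_RUN_ALL_PUBLIC", "false") == "true"
    if force_all:
        return [m["name"] for m in all_models]
    if force_all_public:
        # "public" is an assumed attribute; the real script may key off
        # registry or license metadata instead.
        return [m["name"] for m in all_models if m.get("public", False)]
    return []  # otherwise fall back to git-diff-based detection (not shown)

# Per the workflow comment, the script should output a JSON array of model names.
print(json.dumps(select_models([{"name": "phi-3-mini", "public": True}])))
```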
5 changes: 3 additions & 2 deletions docker/presets/models/tfs/Dockerfile
@@ -7,6 +7,9 @@ ARG VERSION
 # Set the working directory
 WORKDIR /workspace
 
+# Model weights
+COPY ${WEIGHTS_PATH} /workspace/weights
+
 COPY kaito/presets/workspace/dependencies/requirements.txt /workspace/requirements.txt
 
 RUN pip install --no-cache-dir -r /workspace/requirements.txt
@@ -26,8 +29,6 @@ COPY kaito/presets/workspace/inference/vllm/inference_api.py /workspace/vllm/inf
 # Chat template
 ADD kaito/presets/workspace/inference/chat_templates /workspace/chat_templates
 
-# Model weights
-COPY ${WEIGHTS_PATH} /workspace/weights
 RUN echo $VERSION > /workspace/version.txt && \
     ln -s /workspace/weights /workspace/tfs/weights && \
     ln -s /workspace/weights /workspace/vllm/weights
4 changes: 2 additions & 2 deletions presets/workspace/dependencies/requirements.txt
@@ -2,9 +2,9 @@
 
 # Core Dependencies
 vllm==0.6.3
-transformers >= 4.45.0
+transformers == 4.45.0
 torch==2.4.0
-accelerate==0.30.1
+accelerate==1.0.0
 fastapi>=0.111.0,<0.112.0 # Allow patch updates
 pydantic>=2.9
 uvicorn[standard]>=0.29.0,<0.30.0 # Allow patch updates
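transformers moves from a floor (>= 4.45.0) to an exact pin alongside the accelerate bump, presumably to keep the image on a combination validated against accelerate 1.0.0. A minimal runtime sanity check matching these pins (stdlib only; illustrative, not part of the commit):

```python
from importlib.metadata import version

# Verify the installed packages match the pins in requirements.txt.
assert version("accelerate") == "1.0.0"
assert version("transformers") == "4.45.0"
assert version("torch").startswith("2.4.0")  # local builds may carry a +cuXXX suffix
```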
@@ -103,7 +103,7 @@ def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
         resolved_chat_template = Path(chat_template).read_text()
 
     logger.info("Chat template loaded successfully")
-    logger.info("Chat template: %s", resolved_chat_template)
+    logger.info("Chat template:\n%s", resolved_chat_template)
     return resolved_chat_template
9 changes: 7 additions & 2 deletions presets/workspace/tuning/text-generation/cli.py
@@ -2,8 +2,6 @@
 # Licensed under the MIT license.
 import os
 from dataclasses import dataclass, field
-from datetime import datetime
-from enum import Enum, auto
 from typing import Any, Dict, List, Optional
 
 import torch
@@ -78,6 +76,7 @@ class ModelConfig:
     load_in_8bit: bool = field(default=False, metadata={"help": "Load model in 8-bit mode"})
     torch_dtype: Optional[str] = field(default=None, metadata={"help": "The torch dtype for the pre-trained model"})
     device_map: str = field(default="auto", metadata={"help": "The device map for the pre-trained model"})
+    chat_template: Optional[str] = field(default=None, metadata={"help": "The file path to the chat template, or the template in single-line form for the specified model"})
 
     def __post_init__(self):
         """
@@ -89,6 +88,12 @@ def __post_init__(self):
         elif not isinstance(self.torch_dtype, torch.dtype):
             raise ValueError(f"Invalid torch dtype: {self.torch_dtype}")
 
+    def get_tokenizer_args(self):
+        return {k: v for k, v in self.__dict__.items() if k not in ["torch_dtype", "chat_template"]}
+
+    def get_model_args(self):
+        return {k: v for k, v in self.__dict__.items() if k not in ["chat_template"]}
+
 @dataclass
 class QuantizationConfig(BitsAndBytesConfig):
     """
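The two new helpers split one ModelConfig into per-consumer kwargs: chat_template is excluded from both from_pretrained() calls, while torch_dtype goes only to the model (AutoTokenizer does not accept it). A minimal sketch, assuming the fields not shown in this diff all have defaults:

```python
cfg = ModelConfig(
    torch_dtype="bfloat16",  # converted to a torch.dtype in __post_init__
    chat_template="/workspace/chat_templates/example.jinja",  # assumed path
)

model_args = cfg.get_model_args()          # keeps torch_dtype, drops chat_template
tokenizer_args = cfg.get_tokenizer_args()  # drops torch_dtype and chat_template

assert "chat_template" not in model_args
assert "torch_dtype" not in tokenizer_args
```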
21 changes: 12 additions & 9 deletions presets/workspace/tuning/text-generation/fine_tuning.py
@@ -2,26 +2,26 @@
 # Licensed under the MIT license.
 import logging
 import os
-import sys
 from dataclasses import asdict
-from datetime import datetime
-from parser import parse_configs
+from parser import parse_configs, load_chat_template
 
 import torch
 import transformers
 from accelerate import Accelerator
 from dataset import DatasetManager
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig, HfArgumentParser, Trainer,
-                          TrainerCallback, TrainerControl, TrainerState,
-                          TrainingArguments)
+                          BitsAndBytesConfig,
+                          TrainerCallback, TrainerControl, TrainerState)
 from trl import SFTTrainer
 
 # Initialize logger
 logger = logging.getLogger(__name__)
 debug_mode = os.environ.get('DEBUG_MODE', 'false').lower() == 'true'
-logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)
+logging.basicConfig(
+    level=logging.DEBUG if debug_mode else logging.INFO,
+    format='%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
+    datefmt='%m-%d %H:%M:%S')
 
 CONFIG_YAML = os.environ.get('YAML_FILE_PATH', '/mnt/config/training_config.yaml')
 parsed_configs = parse_configs(CONFIG_YAML)
@@ -36,7 +36,7 @@
 accelerator = Accelerator()
 
 # Load Model Args
-model_args = asdict(model_config)
+model_args = model_config.get_model_args()
 if accelerator.distributed_type != "NO":  # Meaning we require distributed training
     logger.debug("Setting device map for distributed training")
     model_args["device_map"] = {"": accelerator.process_index}
@@ -47,10 +47,13 @@
 enable_qlora = bnb_config.is_quantizable()
 
 # Load the Pre-Trained Tokenizer
-tokenizer_args = {key: value for key, value in model_args.items() if key != "torch_dtype"}
+tokenizer_args = model_config.get_tokenizer_args()
+resolved_chat_template = load_chat_template(model_config.chat_template)
 tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)
 if not tokenizer.pad_token:
     tokenizer.pad_token = tokenizer.eos_token
+if resolved_chat_template is not None:
+    tokenizer.chat_template = resolved_chat_template
 if dc_args.mlm and tokenizer.mask_token is None:
     logger.warning(
         "This tokenizer does not have a mask token which is necessary for masked language modeling. "
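Once tokenizer.chat_template is overridden, conversational samples are rendered through the standard transformers API during tuning. A sketch of the effect (the model name is an example; apply_chat_template is the stock transformers method):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  # example model
messages = [
    {"role": "user", "content": "What does Kaito do?"},
    {"role": "assistant", "content": "It automates AI model deployment on Kubernetes."},
]
# Renders the conversation through whatever template is set on the tokenizer,
# including one injected via the new chat_template config option.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```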
24 changes: 21 additions & 3 deletions presets/workspace/tuning/text-generation/parser.py
@@ -1,14 +1,17 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 import os
-import sys
-from collections import defaultdict
+import logging
 from dataclasses import asdict, fields
+import codecs
+from pathlib import Path
+from typing import Optional
 
 import yaml
 from cli import (DatasetConfig, ExtDataCollator, ExtLoraConfig, ModelConfig, QuantizationConfig)
 from transformers import HfArgumentParser, TrainingArguments
 
+logger = logging.getLogger(__name__)
+
 # Mapping from config section names to data classes
 CONFIG_CLASS_MAP = {
     'ModelConfig': ModelConfig,
@@ -69,3 +72,18 @@ def parse_configs(config_yaml):
         parsed_configs[section_name] = CONFIG_CLASS_MAP[section_name](**filtered_config)
 
     return parsed_configs
+
+def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+    logger.info(chat_template)
+    if chat_template is None:
+        return None
+
+    JINJA_CHARS = "{}\n"
+    if any(c in chat_template for c in JINJA_CHARS):
+        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
+    else:
+        resolved_chat_template = Path(chat_template).read_text()
+
+    logger.info("Chat template loaded successfully")
+    logger.info("Chat template:\n%s", resolved_chat_template)
+    return resolved_chat_template
