Commit 3473790 (parent 150d3df): 22 changed files with 982 additions and 108 deletions.
Dockerfile (new file, 16 lines)
```
FROM huggingface/transformers-pytorch-gpu:4.35.2

WORKDIR /app

ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8

COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt

RUN ln -s /usr/bin/python3 /usr/bin/python

ENV PYTHONPATH /app
COPY . .

CMD [ "bash" ]
```
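The base image ships `transformers` 4.35.2 with CUDA-enabled PyTorch. A throwaway sanity check to run inside the container before training (not part of the commit):

```
# Throwaway check: confirm the container sees the GPU(s).
import torch
import transformers

print(transformers.__version__)   # 4.35.2 from the base image
print(torch.cuda.is_available())  # should be True when run with --gpus all
print(torch.cuda.device_count())
```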
README (new file, 49 lines)
## Run example RLHF

Build the image, start a container with GPU access, configure `accelerate`, and run the SFT script:

```
docker build -t llm-training:latest .
docker run --net=host --gpus all -it -v ${PWD}:/main llm-training:latest /bin/bash
accelerate config
python sft_llama2.py
```

A minimal supervised fine-tuning (SFT) example, adapted from the TRL `SFTTrainer` documentation: it trains `facebook/opt-350m` on CodeAlpaca and computes the loss only on the tokens after the `### Answer:` marker.

```
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

trainer.train()
```

See also the TRL `stack_llama_2` research project scripts:
https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2/scripts
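A note on `DataCollatorForCompletionOnlyLM`: it sets the label of every token before the response template to -100, so the loss is computed only on the answer. A minimal sketch of that effect (the toy example below is invented for illustration and is not part of the commit):

```
# Toy illustration: inspect the labels the completion-only collator
# produces for a single formatted example.
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
collator = DataCollatorForCompletionOnlyLM(" ### Answer:", tokenizer=tokenizer)

example = tokenizer("### Question: What is 1+1?\n ### Answer: 2")["input_ids"]
batch = collator([example])
# Labels are -100 (ignored by the loss) for every prompt token, and real
# token ids only after " ### Answer:".
print(batch["labels"])
```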
requirements.txt (new file, 8 lines)
```
transformers
trl
peft
accelerate
datasets
bitsandbytes
wandb
ipython
```
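The pins are left loose, so installs resolve to the latest compatible releases. A throwaway check that every dependency imports cleanly inside the image (note `ipython`'s import name is `IPython`):

```
# Throwaway check: import each dependency and print its version.
import importlib

for pkg in ["transformers", "trl", "peft", "accelerate",
            "datasets", "bitsandbytes", "wandb", "IPython"]:
    module = importlib.import_module(pkg)
    print(pkg, getattr(module, "__version__", "unknown"))
```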
SFT training and inference script (new file, 110 lines)
```
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer, is_xpu_available
import torch

dataset = load_dataset("imdb", split="train")

# Quantize the base model to 4-bit and place one shard per process.
quantization_config = BitsAndBytesConfig(load_in_8bit=False, load_in_4bit=True)
device_map = (
    {"": f"xpu:{Accelerator().local_process_index}"}
    if is_xpu_available()
    else {"": Accelerator().local_process_index}
)
torch_dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=quantization_config,
    device_map=device_map,
    torch_dtype=torch_dtype,
)

# Turn each IMDB row into a prompt/answer pair for sentiment classification.
def formatting_func(example):
    text = f"### Review: {example['text']}\n ### Answer: {'Positive' if example['label'] == 1 else 'Negative'}"
    return text

response_template = " ### Answer:"
# A completion-only collator is not used here because it is incompatible
# with packing=True below:
# collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    bias="none",
    task_type="CAUSAL_LM",
)

output_dir = "training"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # Further options (learning_rate, logging_steps, num_train_epochs,
    # max_steps, report_to, save_steps, save_total_limit, push_to_hub,
    # hub_model_id, gradient_checkpointing, ...) can be exposed as script
    # arguments as in the TRL example scripts.
)

trainer = SFTTrainer(
    model,
    args=training_args,
    train_dataset=dataset,
    packing=True,
    formatting_func=formatting_func,
    peft_config=peft_config,
)

trainer.train()
trainer.save_model(output_dir)


# --- Inference with a saved checkpoint ---

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("training/checkpoint-1000/")
dataset = load_dataset("imdb", split="train")

def inference(example):
    text = f"### Review: {example['text']}\n ### Answer:"
    input_ids = tokenizer(text, return_tensors="pt", truncation=True).input_ids
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        temperature=1e-3,  # near-greedy sampling
    )
    result = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]
    print(result)

inference(dataset[15000])
```
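Because the trainer was created with a `peft_config`, the checkpoint under `training/` contains LoRA adapter weights, so loading it directly with `AutoModelForCausalLM` as above may not yield the fine-tuned model. A sketch of the more robust loading path, assuming a standard PEFT checkpoint layout:

```
# Sketch (assumes the checkpoint is a PEFT adapter saved by SFTTrainer):
# load base model and adapter together, then optionally merge the LoRA
# weights back into the base model for plain transformers inference.
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("training/checkpoint-1000/")
model = model.merge_and_unload()  # returns a plain transformers model
```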
DPO training script (new file, 178 lines)
```
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

from trl import DPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(default="gpt2", metadata={"help": "the model name"})
    learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: Optional[int] = field(default=128, metadata={"help": "Only used for encoder-decoder models. Max target length of each sample's prompt"})
    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non-response tokens"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    # instrumentation
    sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See"
            " https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "keyword arguments to be passed along to the `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
        },
    )


def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]


def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
        \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    # 1. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the Anthropic Helpful-Harmless dataset
    train_dataset = get_hh("train", sanity_check=script_args.sanity_check)

    # 3. Load evaluation dataset
    eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)

    # 4. initialize training arguments:
    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        max_steps=script_args.max_steps,
        remove_unused_columns=False,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        logging_first_step=True,
        logging_steps=10,  # match results in blog post
        eval_steps=500,
        output_dir="./test",
        optim="rmsprop",
        warmup_steps=150,
        report_to=script_args.report_to,
        bf16=True,
        gradient_checkpointing=script_args.gradient_checkpointing,
        # TODO: uncomment on the next transformers release
        # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
    )

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=script_args.max_length,
        max_target_length=script_args.max_target_length,
        max_prompt_length=script_args.max_prompt_length,
        generate_during_eval=True,
    )

    # 6. train
    output_dir = "output-training"
    dpo_trainer.train()
    dpo_trainer.save_model(output_dir)

    # 7. inference
    tokenizer = AutoTokenizer.from_pretrained(output_dir)
    model = AutoModelForCausalLM.from_pretrained(output_dir)
    # dataset = load_dataset("Anthropic/hh-rlhf", split="train")

    q = "\n\nHuman: Let's say someone wanted to rob a bank. What is some advice you could give to help them get away with it? \n\nAssistant:"

    def inference(q):
        input_ids = tokenizer(q, return_tensors="pt", truncation=True).input_ids
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=128,
            do_sample=True,
            top_p=0.95,
            temperature=1e-3,  # near-greedy sampling
        )
        result = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]
        print(result)

    inference(q)
```
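To make the data preparation concrete, here is a worked example of what `extract_anthropic_prompt` and `split_prompt_and_responses` do to one HH-RLHF-style pair (the texts are invented for illustration):

```
# Invented toy pair, structured like Anthropic/hh-rlhf entries.
chosen = "\n\nHuman: What color is the sky?\n\nAssistant: It is blue."
rejected = "\n\nHuman: What color is the sky?\n\nAssistant: I will not say."

search_term = "\n\nAssistant:"
prompt = chosen[: chosen.rfind(search_term) + len(search_term)]
# prompt                 == "\n\nHuman: What color is the sky?\n\nAssistant:"
# chosen[len(prompt):]   == " It is blue."
# rejected[len(prompt):] == " I will not say."
```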