Skip to content

Commit

Permalink
Adding LoRA-Distillation SD training example (#788)
Browse files Browse the repository at this point in the history
Co-authored-by: Xiaoxia (Shirley) Wu <[email protected]>
  • Loading branch information
PareesaMS and xiaoxiawu-microsoft authored Dec 4, 2023
1 parent b116838 commit 0e10c4b
Show file tree
Hide file tree
Showing 6 changed files with 2,120 additions and 0 deletions.
44 changes: 44 additions & 0 deletions training/stable_diffusion/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Lora-enhanced distillation on Stable Diffusion model

This repository contains the implementation of Lora-enhanced distillation applied to the Stable Diffusion (SD) model. By combining the LoRA technique with distillation, we've achieved remarkable results, including a significant reduction in inference time and a 50% decrease in memory consumption. Importantly, this integration of LoRA-enhanced distillation maintains image quality and alignment with the provided prompt. For additional details on this work, please consult our technical report [TODO: add link].

In this implementation, we have adapted the dreambooth finetuning [code](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#dreambooth-training-example) as our baseline. Below, you'll find information regarding input data, training, and inference.

## Installation

You need to have huggingface [diffusers](https://github.com/huggingface/diffusers) installed on your machine. Then install the requirements:

<pre>
pip install -r requirements.txt
</pre>

## Training

### Training Data
Our training data includes a significant dataset of pre-generated images by [SD](https://github.com/poloclub/diffusiondb). You are not required to download the input data. Instead, you can specify or modify it within the training code (`train_sd_distill_lora.py`) as needed.To train the model, follow these steps:

### Training Script

1. Run the `mytrainbash.sh` file.
2. The finetuned model will be saved inside the output directory.

Here's an example command to run the training script:

<pre>
bash mytrainbash.sh
</pre>

Make sure to customize the training parameters in the script to suit your specific requirements.

## Inference

For inference, you can use the `inf-loop.py` Python code. Follow these steps:

1. Provide your desired prompts as input in the script.
2. Run the `inf_txt2img_loop.py` script.

Here's an example command to run the inference script:

<pre>
deepspeed inf_txt2img_loop.py
</pre>
56 changes: 56 additions & 0 deletions training/stable_diffusion/inf_txt2img_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import deepspeed
import torch
import os
from local_pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers import StableDiffusionPipeline as StableDiffusionPipelineBaseline
import argparse

seed = 123450011
parser = argparse.ArgumentParser()
parser.add_argument("--ft_model", default="new_sd-distill-v21-10k-1e", type=str, help="Path to the fine-tuned model")
parser.add_argument("--b_model", default="stabilityai/stable-diffusion-2-1-base", type=str, help="Path to the baseline model")
parser.add_argument("--out_dir", default="image_out/", type=str, help="Path to the generated images")
parser.add_argument('--guidance_scale', type=float, default=7.5, help='Guidance Scale')
parser.add_argument("--use_local_pipe", action='store_true', help="Use local SD pipeline")
parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank")
args = parser.parse_args()


local_rank = int(os.getenv("LOCAL_RANK", "0"))
device = torch.device(f"cuda:{local_rank}")
world_size = int(os.getenv('WORLD_SIZE', '1'))


if not os.path.exists(args.out_dir):
os.makedirs(args.out_dir)
print(f"Directory '{args.out_dir}' has been created to store the generated images.")
else:
print(f"Directory '{args.out_dir}' already exists and stores the generated images.")


prompts = ["A boy is watching TV",
"A photo of a person dancing in the rain",
"A photo of a boy jumping over a fence",
"A photo of a boy is kicking a ball",
"A beach with a lot of waves on it",
"A road that is going down a hill",
"3d rendering of 5 tennis balls on top of a cake",
"A person holding a drink of soda",
"A person is squeezing a lemon",
"A person holding a cat"]


for prompt in prompts:
#--- new image
pipe_new = StableDiffusionPipeline.from_pretrained(args.ft_model, torch_dtype=torch.float16).to("cuda")
generator = torch.Generator("cuda").manual_seed(seed)
pipe_new = deepspeed.init_inference(pipe_new, mp_size=world_size, dtype=torch.half)
image_new = pipe_new(prompt, num_inference_steps=50, guidance_scale=args.guidance_scale, generator=generator).images[0]
image_new.save(args.out_dir+"/NEW__seed_"+str(seed)+"_"+prompt[0:100]+".png")

#--- baseline image
pipe_baseline = StableDiffusionPipelineBaseline.from_pretrained(args.b_model, torch_dtype=torch.float16).to("cuda")
generator = torch.Generator("cuda").manual_seed(seed)
pipe_baseline = deepspeed.init_inference(pipe_baseline, mp_size=world_size, dtype=torch.half)
image_baseline = pipe_baseline(prompt, num_inference_steps=50, guidance_scale=args.guidance_scale, generator=generator).images[0]
image_baseline.save(args.out_dir+"/BASELINE_seed_"+str(seed)+"_"+prompt[0:100]+".png")
Loading

0 comments on commit 0e10c4b

Please sign in to comment.