Showing 6 changed files with 2,120 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# LoRA-enhanced distillation on the Stable Diffusion model | ||
|
||
This repository contains the implementation of LoRA-enhanced distillation applied to the Stable Diffusion (SD) model. Combining low-rank adaptation (LoRA) with distillation yields a significant reduction in inference time and a 50% decrease in memory consumption, while maintaining image quality and alignment with the provided prompt. For additional details on this work, please consult our technical report [TODO: add link]. | ||
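
To give a sense of why LoRA keeps the trainable footprint small, the sketch below counts trainable parameters for one linear layer with and without a low-rank update. It only illustrates the general LoRA idea: the layer width (320) and rank (4) are assumptions chosen for the example, not values taken from this repository's training script, and the count says nothing about the inference-time figures quoted above.

```python
def full_param_count(d_in: int, d_out: int) -> int:
    """Weights in a dense d_out x d_in linear layer (bias ignored)."""
    return d_in * d_out


def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable weights when W is frozen and only the low-rank factors
    A (rank x d_in) and B (d_out x rank) of the update B @ A are trained."""
    return rank * d_in + d_out * rank


# Hypothetical sizes chosen only for illustration.
d_in, d_out, rank = 320, 320, 4
full = full_param_count(d_in, d_out)
lora = lora_param_count(d_in, d_out, rank)
print(f"full: {full}, LoRA: {lora}, ratio: {lora / full:.3f}")
```

The ratio shrinks further as the layer gets wider, because the LoRA count grows linearly in the layer dimensions while the dense count grows with their product.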
|
||
In this implementation, we adapted the DreamBooth fine-tuning [code](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#dreambooth-training-example) as our baseline. The sections below cover input data, training, and inference. | ||
|
||
## Installation | ||
|
||
You need to have Hugging Face [diffusers](https://github.com/huggingface/diffusers) installed on your machine. Then install the requirements: | ||
|
||
<pre> | ||
pip install -r requirements.txt | ||
</pre> | ||
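
After installing, a quick import check can confirm that the environment is usable before a long training run. The snippet below is a minimal sketch. The package list is an assumption based on the imports in the inference script of this commit, so the real `requirements.txt` may list more.

```python
import importlib.util


def missing_packages(names):
    """Return the subset of `names` that cannot be imported in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed package list, taken from the imports in the inference script.
    required = ["torch", "diffusers", "deepspeed"]
    missing = missing_packages(required)
    print("missing:", missing if missing else "none")
```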
|
||
## Training | ||
|
||
### Training Data | ||
Our training data is a large dataset of images pre-generated by SD ([DiffusionDB](https://github.com/poloclub/diffusiondb)). You are not required to download the input data. Instead, you can specify or modify the dataset within the training code (`train_sd_distill_lora.py`) as needed. | ||
|
||
### Training Script | ||
|
||
To train the model, follow these steps: | ||
|
||
1. Run the `mytrainbash.sh` file. | ||
2. The fine-tuned model is saved inside the output directory. | ||
|
||
Here's an example command to run the training script: | ||
|
||
<pre> | ||
bash mytrainbash.sh | ||
</pre> | ||
|
||
Make sure to customize the training parameters in the script to suit your specific requirements. | ||
|
||
## Inference | ||
|
||
For inference, use the `inf_txt2img_loop.py` script. Follow these steps: | ||
|
||
1. Set your desired prompts in the `prompts` list inside the script. | ||
2. Run the `inf_txt2img_loop.py` script. | ||
|
||
Here's an example command to run the inference script: | ||
|
||
<pre> | ||
deepspeed inf_txt2img_loop.py | ||
</pre> |
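
The inference script also accepts command-line options. The sketch below assembles a launcher command in Python. The flags `--ft_model`, `--out_dir`, and `--guidance_scale` are those defined in the script's `argparse` block in this commit, and `--num_gpus` is an option of the `deepspeed` launcher. The helper name `build_inference_command` is hypothetical and is not part of the repository.

```python
import shlex


def build_inference_command(script="inf_txt2img_loop.py", num_gpus=1,
                            ft_model=None, out_dir=None, guidance_scale=None):
    """Assemble a `deepspeed` launcher command for the inference script.

    Only options the caller sets are appended, so the script's own
    argparse defaults apply to everything else.
    """
    cmd = ["deepspeed", "--num_gpus", str(num_gpus), script]
    if ft_model is not None:
        cmd += ["--ft_model", ft_model]
    if out_dir is not None:
        cmd += ["--out_dir", out_dir]
    if guidance_scale is not None:
        cmd += ["--guidance_scale", str(guidance_scale)]
    return cmd


cmd = build_inference_command(out_dir="image_out/", guidance_scale=7.5)
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` when launching from Python rather than from a shell.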
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import deepspeed | ||
import torch | ||
import os | ||
from local_pipeline_stable_diffusion import StableDiffusionPipeline | ||
from diffusers import StableDiffusionPipeline as StableDiffusionPipelineBaseline | ||
import argparse | ||
|
||
seed = 123450011 | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--ft_model", default="new_sd-distill-v21-10k-1e", type=str, help="Path to the fine-tuned model") | ||
parser.add_argument("--b_model", default="stabilityai/stable-diffusion-2-1-base", type=str, help="Path to the baseline model") | ||
parser.add_argument("--out_dir", default="image_out/", type=str, help="Path to the generated images") | ||
parser.add_argument('--guidance_scale', type=float, default=7.5, help='Guidance Scale') | ||
parser.add_argument("--use_local_pipe", action='store_true', help="Use local SD pipeline") | ||
parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank") | ||
args = parser.parse_args() | ||
|
||
|
||
local_rank = args.local_rank  # passed by the DeepSpeed launcher; defaults to the LOCAL_RANK environment variable | ||
device = torch.device(f"cuda:{local_rank}") | ||
world_size = int(os.getenv('WORLD_SIZE', '1')) | ||
|
||
|
||
if not os.path.exists(args.out_dir): | ||
    os.makedirs(args.out_dir) | ||
    print(f"Directory '{args.out_dir}' has been created to store the generated images.") | ||
else: | ||
    print(f"Directory '{args.out_dir}' already exists and stores the generated images.") | ||
|
||
|
||
prompts = [ | ||
    "A boy is watching TV", | ||
    "A photo of a person dancing in the rain", | ||
    "A photo of a boy jumping over a fence", | ||
    "A photo of a boy is kicking a ball", | ||
    "A beach with a lot of waves on it", | ||
    "A road that is going down a hill", | ||
    "3d rendering of 5 tennis balls on top of a cake", | ||
    "A person holding a drink of soda", | ||
    "A person is squeezing a lemon", | ||
    "A person holding a cat", | ||
] | ||
|
||
|
||
# Load both pipelines once, outside the prompt loop, so the weights are not reloaded for every prompt. | ||
pipe_new = StableDiffusionPipeline.from_pretrained(args.ft_model, torch_dtype=torch.float16).to(device) | ||
pipe_new = deepspeed.init_inference(pipe_new, mp_size=world_size, dtype=torch.half) | ||
pipe_baseline = StableDiffusionPipelineBaseline.from_pretrained(args.b_model, torch_dtype=torch.float16).to(device) | ||
pipe_baseline = deepspeed.init_inference(pipe_baseline, mp_size=world_size, dtype=torch.half) | ||
|
||
for prompt in prompts: | ||
    #--- new image (fine-tuned model); re-seed so both models start from the same noise | ||
    generator = torch.Generator(device).manual_seed(seed) | ||
    image_new = pipe_new(prompt, num_inference_steps=50, guidance_scale=args.guidance_scale, generator=generator).images[0] | ||
    image_new.save(os.path.join(args.out_dir, "NEW__seed_"+str(seed)+"_"+prompt[0:100]+".png")) | ||
|
||
    #--- baseline image | ||
    generator = torch.Generator(device).manual_seed(seed) | ||
    image_baseline = pipe_baseline(prompt, num_inference_steps=50, guidance_scale=args.guidance_scale, generator=generator).images[0] | ||
    image_baseline.save(os.path.join(args.out_dir, "BASELINE_seed_"+str(seed)+"_"+prompt[0:100]+".png")) |
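
The script above embeds the raw prompt in each output file name, which works for the listed prompts but would break for prompts containing characters such as `/` or `:`. The following is a minimal sketch of a safer naming helper. The function `image_path` and its sanitization rule are hypothetical and are not part of this commit.

```python
import os
import re


def image_path(out_dir, tag, seed, prompt, max_len=100):
    """Build an output path like <out_dir>/<tag>_seed_<seed>_<prompt>.png.

    Runs of characters other than letters, digits, '-' and '_' are replaced
    with a single '_' so that prompts containing '/', ':' or similar still
    produce valid file names.
    """
    safe_prompt = re.sub(r"[^A-Za-z0-9_-]+", "_", prompt[:max_len]).strip("_")
    return os.path.join(out_dir, f"{tag}_seed_{seed}_{safe_prompt}.png")


path = image_path("image_out", "BASELINE", 123450011, "A road that is going down a hill")
print(path)
```

Keeping the seed and a `NEW`/`BASELINE` tag in the name preserves the side-by-side comparison that the loop is set up to produce.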