forked from Fanghua-Yu/SUPIR
-
Notifications
You must be signed in to change notification settings - Fork 8
/
predict.py
213 lines (193 loc) · 8.42 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import subprocess
import time
from omegaconf import OmegaConf
from PIL import Image
from cog import BasePredictor, Input, Path
from SUPIR.util import (
create_SUPIR_model,
PIL2Tensor,
Tensor2PIL,
convert_dtype,
)
from llava.llava_agent import LLavaAgent
import CKPT_PTH
# ---------------------------------------------------------------------------
# Remote weight locations (all served from weights.replicate.delivery).
# ---------------------------------------------------------------------------
# SUPIR checkpoints, one per selectable model variant ("Q" and "F").
SUPIR_v0Q_URL = (
    "https://weights.replicate.delivery/default/SUPIR-v0Q.ckpt"
)
SUPIR_v0F_URL = (
    "https://weights.replicate.delivery/default/SUPIR-v0F.ckpt"
)
# SDXL base checkpoint and the two CLIP archives fetched alongside it.
SDXL_URL = (
    "https://weights.replicate.delivery/default/stable-diffusion-xl-base-1.0/sd_xl_base_1.0_0.9vae.safetensors"
)
SDXL_CLIP1_URL = (
    "https://weights.replicate.delivery/default/clip-vit-large-patch14.tar"
)
SDXL_CLIP2_URL = "https://weights.replicate.delivery/default/CLIP-ViT-bigG-14-laion2B-39B-b160k.tar"
# LLaVA archive and the CLIP archive downloaded for it.
LLAVA_URL = (
    "https://weights.replicate.delivery/default/llava-v1.5-13b.tar"
)
LLAVA_CLIP_URL = "https://weights.replicate.delivery/default/clip-vit-large-patch14-336.tar"

# ---------------------------------------------------------------------------
# Local destinations.
# ---------------------------------------------------------------------------
# Root cache directory; mirrors the default used by CKPT_PTH.py.
# NOTE: the value ends with "/", so the f-string paths derived from it below
# contain a doubled "//" separator.
MODEL_CACHE = "/opt/data/private/AIGC_pretrain/"

# Paths taken verbatim from the project's CKPT_PTH module.
LLAVA_MODEL_PATH = CKPT_PTH.LLAVA_MODEL_PATH
LLAVA_CLIP_PATH = CKPT_PTH.LLAVA_CLIP_PATH
SDXL_CLIP1_PATH = CKPT_PTH.SDXL_CLIP1_PATH

# Paths derived from MODEL_CACHE.
SUPIR_CKPT_Q = f"{MODEL_CACHE}/SUPIR_cache/SUPIR-v0Q.ckpt"
SUPIR_CKPT_F = f"{MODEL_CACHE}/SUPIR_cache/SUPIR-v0F.ckpt"
SDXL_CKPT = f"{MODEL_CACHE}/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors"
SDXL_CLIP2_CACHE = f"{MODEL_CACHE}/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k"
def download_weights(url, dest, extract=True):
    """Fetch ``url`` into ``dest`` by invoking the ``pget`` command-line tool.

    Progress (source, destination, elapsed seconds) is printed to stdout.

    Args:
        url: Remote location to download.
        dest: Local path passed to ``pget`` as the destination.
        extract: When true, ``-x`` is inserted into the ``pget`` command line
            (presumably so the downloaded archive is unpacked at ``dest`` --
            confirm against the pget documentation).

    Raises:
        subprocess.CalledProcessError: If ``pget`` exits with a non-zero status.
    """
    began_at = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    command = ["pget", "-x", url, dest] if extract else ["pget", url, dest]
    subprocess.check_call(command, close_fds=False)
    print("downloading took: ", time.time() - began_at)
class Predictor(BasePredictor):
    """Cog predictor that runs SUPIR on one image, optionally captioned by LLaVA."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient.

        Creates the cache directories, downloads every weight file or archive
        whose destination does not exist yet, builds both SUPIR variants
        ("Q" and "F") on ``cuda:0`` with bf16 inference dtypes, and creates
        the LLaVA agent.

        Raises:
            subprocess.CalledProcessError: If a ``pget`` download fails.
        """
        for model_dir in (
            MODEL_CACHE,
            f"{MODEL_CACHE}/SUPIR_cache",
            f"{MODEL_CACHE}/SDXL_cache",
        ):
            os.makedirs(model_dir, exist_ok=True)

        # (url, destination, extract) -- each entry is checked against its OWN
        # destination. The CLIP2 archive used to be guarded by, and extracted
        # to, SDXL_CKPT, so it was never downloaded once the SDXL checkpoint
        # existed; it now targets SDXL_CLIP2_CACHE.
        required_weights = (
            (SUPIR_v0Q_URL, SUPIR_CKPT_Q, False),
            (SUPIR_v0F_URL, SUPIR_CKPT_F, False),
            (LLAVA_URL, LLAVA_MODEL_PATH, True),
            (LLAVA_CLIP_URL, LLAVA_CLIP_PATH, True),
            (SDXL_CLIP1_URL, SDXL_CLIP1_PATH, True),
            (SDXL_URL, SDXL_CKPT, False),
            (SDXL_CLIP2_URL, SDXL_CLIP2_CACHE, True),
        )
        for url, dest, extract in required_weights:
            if not os.path.exists(dest):
                download_weights(url, dest, extract=extract)

        self.supir_device = "cuda:0"
        self.llava_device = "cuda:0"
        ae_dtype = "bf16"  # Inference data type of AutoEncoder
        diff_dtype = "bf16"  # Inference data type of Diffusion

        # One SUPIR model per variant, keyed by the SUPIR_sign letter.
        self.models = {
            k: create_SUPIR_model("options/SUPIR_v0.yaml", SUPIR_sign=k).to(
                self.supir_device
            )
            for k in ["Q", "F"]
        }
        for supir_model in self.models.values():
            supir_model.ae_dtype = convert_dtype(ae_dtype)
            supir_model.model.dtype = convert_dtype(diff_dtype)

        # load LLaVA
        self.llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=self.llava_device)

    def predict(
        self,
        model_name: str = Input(
            description="Choose a model. SUPIR-v0Q is the default training settings with paper. SUPIR-v0F is high generalization and high image quality in most cases. Training with light degradation settings. Stage1 encoder of SUPIR-v0F remains more details when facing light degradations.",
            choices=["SUPIR-v0Q", "SUPIR-v0F"],
            default="SUPIR-v0Q",
        ),
        image: Path = Input(description="Low quality input image."),
        upscale: int = Input(
            description="Upsampling ratio of given inputs.", default=1
        ),
        min_size: float = Input(
            description="Minimum resolution of output images.", default=1024
        ),
        edm_steps: int = Input(
            description="Number of steps for EDM Sampling Schedule.",
            ge=1,
            le=500,
            default=50,
        ),
        use_llava: bool = Input(
            description="Use LLaVA model to get captions.", default=True
        ),
        a_prompt: str = Input(
            description="Additive positive prompt for the inputs.",
            default="Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
        ),
        n_prompt: str = Input(
            description="Negative prompt for the inputs.",
            default="painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
        ),
        color_fix_type: str = Input(
            description="Color Fixing Type..",
            choices=["None", "AdaIn", "Wavelet"],
            default="Wavelet",
        ),
        s_stage1: int = Input(
            description="Control Strength of Stage1 (negative means invalid).",
            default=-1,
        ),
        s_churn: float = Input(
            description="Original churn hy-param of EDM.", default=5
        ),
        s_noise: float = Input(
            description="Original noise hy-param of EDM.", default=1.003
        ),
        s_cfg: float = Input(
            description=" Classifier-free guidance scale for prompts.",
            ge=1,
            le=20,
            default=7.5,
        ),
        s_stage2: float = Input(description="Control Strength of Stage2.", default=1.0),
        linear_CFG: bool = Input(
            description="Linearly (with sigma) increase CFG from 'spt_linear_CFG' to s_cfg.",
            default=False,
        ),
        linear_s_stage2: bool = Input(
            description="Linearly (with sigma) increase s_stage2 from 'spt_linear_s_stage2' to s_stage2.",
            default=False,
        ),
        spt_linear_CFG: float = Input(
            description="Start point of linearly increasing CFG.", default=1.0
        ),
        spt_linear_s_stage2: float = Input(
            description="Start point of linearly increasing s_stage2.", default=0.0
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model.

        The input image is converted to a tensor, pre-denoised, optionally
        captioned by LLaVA, and then passed through the selected SUPIR
        model's sampling routine. The first sample is written to
        ``/tmp/out.png``.

        Returns:
            Path to the saved PNG (always ``/tmp/out.png``).
        """
        if seed is None:
            # Two random bytes -> a seed in [0, 65535].
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        model = self.models["Q"] if model_name == "SUPIR-v0Q" else self.models["F"]

        # Open inside a context manager so the file handle is released as soon
        # as the pixels have been turned into a tensor.
        with Image.open(str(image)) as src_img:
            # "upsacle" (sic) is the keyword this file has always passed to
            # PIL2Tensor; it is kept as-is so the call stays compatible.
            lq_img, h0, w0 = PIL2Tensor(src_img, upsacle=upscale, min_size=min_size)
        # Add a batch dimension and keep only the first three channels.
        lq_img = lq_img.unsqueeze(0).to(self.supir_device)[:, :3, :, :]

        # step 1: Pre-denoise for LLaVA
        clean_imgs = model.batchify_denoise(lq_img)
        clean_PIL_img = Tensor2PIL(clean_imgs[0], h0, w0)

        # step 2: LLaVA
        captions = [""]  # empty caption when LLaVA is disabled
        if use_llava:
            captions = self.llava_agent.gen_image_caption([clean_PIL_img])
            print(f"Captions from LLaVA: {captions}")

        # step 3: Diffusion Process
        samples = model.batchify_sample(
            lq_img,
            captions,
            num_steps=edm_steps,
            restoration_scale=s_stage1,
            s_churn=s_churn,
            s_noise=s_noise,
            cfg_scale=s_cfg,
            control_scale=s_stage2,
            seed=seed,
            num_samples=1,
            p_p=a_prompt,
            n_p=n_prompt,
            color_fix_type=color_fix_type,
            use_linear_CFG=linear_CFG,
            use_linear_control_scale=linear_s_stage2,
            cfg_scale_start=spt_linear_CFG,
            control_scale_start=spt_linear_s_stage2,
        )

        # NOTE(review): fixed output path -- every prediction overwrites the
        # previous result file.
        out_path = "/tmp/out.png"
        Tensor2PIL(samples[0], h0, w0).save(out_path)
        return Path(out_path)