Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mask as optional input to UltimateSD Sampler #96

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def __init__(
uniform_tile_mode,
tiled_decode,
custom_sampler=None,
custom_sigmas=None
custom_sigmas=None,
mask=None,
):
# Variables used by the USDU script
self.init_images = [init_img]
self.image_mask = None
self.mask = mask
self.mask_blur = 0
self.inpaint_full_res_padding = 0
self.width = init_img.width
Expand Down Expand Up @@ -150,6 +152,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.mask_blur > 0:
image_mask = image_mask.filter(ImageFilter.GaussianBlur(p.mask_blur))


# Crop the images to get the tiles that will be used for generation
tiles = [img.crop(crop_region) for img in shared.batch]

Expand All @@ -168,7 +171,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
# Encode the image
batched_tiles = torch.cat([pil_to_tensor(tile) for tile in tiles], dim=0)
(latent,) = p.vae_encoder.encode(p.vae, batched_tiles)

if p.mask is not None:
latent["noise_mask"] = p.mask.reshape((-1, 1, p.mask.shape[-2], p.mask.shape[-1]))
# Generate samples
samples = sample(p.model, p.seed, p.steps, p.cfg, p.sampler_name, p.scheduler, positive_cropped,
negative_cropped, latent, p.denoise, p.custom_sampler, p.custom_sigmas)
Expand Down
47 changes: 25 additions & 22 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def USDU_base_inputs():
("tiled_decode", ("BOOLEAN", {"default": False})),
]

optional = []
optional = [("mask", ("MASK", {"default": None})), ]

return required, optional

Expand Down Expand Up @@ -104,8 +104,8 @@ def upscale(self, image, model, positive, negative, vae, upscale_by, seed,
steps, cfg, sampler_name, scheduler, denoise, upscale_model,
mode_type, tile_width, tile_height, mask_blur, tile_padding,
seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur,
seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode,
custom_sampler=None, custom_sigmas=None):
seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode,
custom_sampler=None, custom_sigmas=None, mask=None):
# Store params
self.tile_width = tile_width
self.tile_height = tile_height
Expand Down Expand Up @@ -136,7 +136,7 @@ def upscale(self, image, model, positive, negative, vae, upscale_by, seed,
self.sdprocessing = StableDiffusionProcessing(
tensor_to_pil(image), model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise, upscale_by, force_uniform_tiles, tiled_decode,
custom_sampler, custom_sigmas
custom_sampler, custom_sigmas, mask
)

# Disable logging
Expand All @@ -148,13 +148,15 @@ def upscale(self, image, model, positive, negative, vae, upscale_by, seed,
# Running the script
#
script = usdu.Script()
processed = script.run(p=self.sdprocessing, _=None, tile_width=self.tile_width, tile_height=self.tile_height,
mask_blur=self.mask_blur, padding=self.tile_padding, seams_fix_width=self.seam_fix_width,
seams_fix_denoise=self.seam_fix_denoise, seams_fix_padding=self.seam_fix_padding,
upscaler_index=0, save_upscaled_image=False, redraw_mode=MODES[self.mode_type],
save_seams_fix_image=False, seams_fix_mask_blur=self.seam_fix_mask_blur,
seams_fix_type=SEAM_FIX_MODES[self.seam_fix_mode], target_size_type=2,
custom_width=None, custom_height=None, custom_scale=self.upscale_by)
processed = script.run(p=self.sdprocessing, _=None, tile_width=self.tile_width,
tile_height=self.tile_height,
mask_blur=self.mask_blur, padding=self.tile_padding,
seams_fix_width=self.seam_fix_width,
seams_fix_denoise=self.seam_fix_denoise, seams_fix_padding=self.seam_fix_padding,
upscaler_index=0, save_upscaled_image=False, redraw_mode=MODES[self.mode_type],
save_seams_fix_image=False, seams_fix_mask_blur=self.seam_fix_mask_blur,
seams_fix_type=SEAM_FIX_MODES[self.seam_fix_mode], target_size_type=2,
custom_width=None, custom_height=None, custom_scale=self.upscale_by)

# Return the resulting images
images = [pil_to_tensor(img) for img in shared.batch]
Expand All @@ -164,6 +166,7 @@ def upscale(self, image, model, positive, negative, vae, upscale_by, seed,
# Restore the original logging level
logger.setLevel(old_level)


class UltimateSDUpscaleNoUpscale(UltimateSDUpscale):
@classmethod
def INPUT_TYPES(s):
Expand All @@ -181,14 +184,15 @@ def upscale(self, upscaled_image, model, positive, negative, vae, seed,
steps, cfg, sampler_name, scheduler, denoise,
mode_type, tile_width, tile_height, mask_blur, tile_padding,
seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur,
seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode):
seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode, mask=None):
upscale_by = 1.0
return super().upscale(upscaled_image, model, positive, negative, vae, upscale_by, seed,
steps, cfg, sampler_name, scheduler, denoise, None,
mode_type, tile_width, tile_height, mask_blur, tile_padding,
seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur,
seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode)

seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode, mask=mask)


class UltimateSDUpscaleCustomSample(UltimateSDUpscale):
@classmethod
def INPUT_TYPES(s):
Expand All @@ -198,7 +202,7 @@ def INPUT_TYPES(s):
optional.append(("custom_sampler", ("SAMPLER",)))
optional.append(("custom_sigmas", ("SIGMAS",)))
return prepare_inputs(required, optional)

RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
CATEGORY = "image/upscaling"
Expand All @@ -208,14 +212,13 @@ def upscale(self, image, model, positive, negative, vae, upscale_by, seed,
mode_type, tile_width, tile_height, mask_blur, tile_padding,
seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur,
seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode,
upscale_model=None,
custom_sampler=None, custom_sigmas=None):
upscale_model=None, custom_sampler=None, custom_sigmas=None, mask=None):
return super().upscale(image, model, positive, negative, vae, upscale_by, seed,
steps, cfg, sampler_name, scheduler, denoise, upscale_model,
mode_type, tile_width, tile_height, mask_blur, tile_padding,
seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur,
seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode,
custom_sampler, custom_sigmas)
steps, cfg, sampler_name, scheduler, denoise, upscale_model,
mode_type, tile_width, tile_height, mask_blur, tile_padding,
seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur,
seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode,
custom_sampler, custom_sigmas, mask)


# A dictionary that contains all nodes you want to export with their names
Expand Down