Skip to content

Commit

Permalink
Merge pull request #7 from ernestchu/main
Browse files Browse the repository at this point in the history
refactor and change meshgrid implementation
  • Loading branch information
ernestchu authored Sep 18, 2023
2 parents 2f1c9ee + 7d83e0e commit d79b98f
Showing 1 changed file with 70 additions and 65 deletions.
135 changes: 70 additions & 65 deletions diffusers-0.20.0/src/diffusers/pipelines/medm/pipeline_medm.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ def __call__(
mixer = None,
flows: Optional[torch.FloatTensor] = None,
occlusions: Optional[torch.BoolTensor] = None,
encoded_pixels: Optional[torch.IntTensor] = None,
backward_coding: bool = True,
no_control_from_step: int = 999,
no_mix_from_step: int = 999,
Expand Down Expand Up @@ -997,7 +998,7 @@ def __exit__(self, _, __, ___):
# 7.5. CUSTOM. Prepare pixel mixer based on given flows and occlusions
if harmonization_scale > 0:
with Timer('coding'):
mixer = mixer or FlowCoding(flows, occlusions, backward_coding)
mixer = mixer or FlowCoding(flows, occlusions, backward_coding, encoded_pixels)
mixer.to(device=device, dtype=controlnet.dtype) # MEMSAVE

# 8. Denoising loop
Expand Down Expand Up @@ -1180,81 +1181,85 @@ def __exit__(self, _, __, ___):


class FlowCoding(torch.nn.Module):
def __init__(self, flows, occlusions, backward_coding):
def __init__(self, flows, occlusions, backward_coding, encoded_pixels):
'''
backward flow coding:
required: optical flow map, occlusion mask, target shape
return: number of unique pixels, encoded pixels (index to the unique pixels)
'''
super().__init__()
# unbatched (batched inference is not implemented yet)
flows = flows[0]
occlusions = occlusions[0]

# flip along temporal dimension for backward coding
if backward_coding:
flows = flows.flip(0)
occlusions = occlusions.flip(0)

# coordinate matrix
shape = occlusions.shape[1:]
meshgrid = torch.from_numpy(
np.array([i for i in np.ndindex(*shape)]).reshape(*shape, len(shape)))

# start encoding
encoded_pixels = []
n_pixels = shape[0] * shape[1]

enc = torch.arange(n_pixels).view(shape)
encoded_pixels.append(enc)

for i in range(len(flows)):
flow = flows[i]
occlusion = occlusions[i]
# prepare ingredients
prev_enc = enc
enc = torch.zeros_like(enc)

unfulfilled_mask = torch.ones_like(occlusion)

dest_float = flow + meshgrid
dest = dest_float.round().to(int)

# discard out-of-range
valid_mask = ((dest >= 0) & (dest < torch.tensor(dest.shape[:2]))).all(-1)
# maskout where long range flow already handled
valid_mask = torch.logical_and(valid_mask, unfulfilled_mask)
v_dest = dest[valid_mask]
v_grid = meshgrid[valid_mask]

# get the common pixels from the previous frame
enc[v_grid[:, 0], v_grid[:, 1]] = prev_enc[v_dest[:, 0], v_dest[:, 1]]
# set the warped pixels to fulfilled
unfulfilled_mask = torch.logical_and(unfulfilled_mask, ~valid_mask)
# unset the occluded pixels
unfulfilled_mask = torch.logical_or(unfulfilled_mask, occlusion)

# novel pixels = occlusions + invalid meshgrid (long and short)
# add novel pixels to global pixels
novel_px = meshgrid[unfulfilled_mask]

offset = n_pixels
enc[novel_px[:, 0], novel_px[:, 1]] = torch.arange(len(novel_px)) + offset
if encoded_pixels is None:
# unbatched (batched inference is not implemented yet)
flows = flows[0]
occlusions = occlusions[0]

# flip along temporal dimension for backward coding
if backward_coding:
flows = flows.flip(0)
occlusions = occlusions.flip(0)

n_pixels += len(novel_px)
# coordinate matrix
shape = occlusions.shape[1:]
meshgrid = torch.stack(torch.where(torch.ones(shape))).T.view(*shape, -1)

# start encoding
encoded_pixels = []
n_pixels = shape[0] * shape[1]

enc = torch.arange(n_pixels).view(shape)
encoded_pixels.append(enc)

for i in range(len(flows)):
flow = flows[i]
occlusion = occlusions[i]
# prepare ingredients
prev_enc = enc
enc = torch.zeros_like(enc)

unfulfilled_mask = torch.ones_like(occlusion)

dest_float = flow + meshgrid
dest = dest_float.round().to(int)

# discard out-of-range
valid_mask = ((dest >= 0) & (dest < torch.tensor(dest.shape[:2]))).all(-1)
# maskout where long range flow already handled
valid_mask = torch.logical_and(valid_mask, unfulfilled_mask)
v_dest = dest[valid_mask]
v_grid = meshgrid[valid_mask]

# get the common pixels from the previous frame
enc[v_grid[:, 0], v_grid[:, 1]] = prev_enc[v_dest[:, 0], v_dest[:, 1]]
# set the warped pixels to fulfilled
unfulfilled_mask = torch.logical_and(unfulfilled_mask, ~valid_mask)
# unset the occluded pixels
unfulfilled_mask = torch.logical_or(unfulfilled_mask, occlusion)

# novel pixels = occlusions + invalid meshgrid (long and short)
# add novel pixels to global pixels
novel_px = meshgrid[unfulfilled_mask]

offset = n_pixels
enc[novel_px[:, 0], novel_px[:, 1]] = torch.arange(len(novel_px)) + offset

n_pixels += len(novel_px)
encoded_pixels.append(enc)


encoded_pixels = torch.stack(encoded_pixels)
if backward_coding:
# flip back to original temporal order
encoded_pixels = encoded_pixels.flip(0)

# add a dummy batch dimension for future batched implementation
self.encoded_pixels = encoded_pixels[None]


encoded_pixels = torch.stack(encoded_pixels)
if backward_coding:
# flip back to original temporal order
encoded_pixels = encoded_pixels.flip(0)
else:
self.encoded_pixels = encoded_pixels

self.shape = shape
self.shape = self.encoded_pixels.shape[-2:]
# n_pixels is the sum of the number of unique pixels of the entire batch
self.n_pixels = n_pixels
# add a dummy batch dimension for future batched implementation
self.encoded_pixels = encoded_pixels[None]
self.n_pixels = int(self.encoded_pixels.max() + 1)

# init computation buffers
self.register_buffer('values', torch.zeros(self.n_pixels, 3))
Expand Down

0 comments on commit d79b98f

Please sign in to comment.