feat: optimize cuda for vapoursynth
Tohrusky committed Dec 10, 2024
1 parent abe3997 commit a1b86d4
Showing 1 changed file with 37 additions and 26 deletions.
63 changes: 37 additions & 26 deletions ccrestoration/vs/vsr.py
@@ -1,3 +1,4 @@
+from threading import Lock
 from typing import Any, Callable, Dict, Union
 
 import torch
@@ -65,27 +66,32 @@ def inference_vsr_multi_frame_out(
 
     cache: Dict[int, torch.Tensor] = {}
 
+    lock = Lock()
+
     def _inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame:
-        if n not in cache:
-            cache.clear()
+        with lock:
+            if n not in cache:
+                cache.clear()
 
-            img = []
-            for i in range(length):
-                index = n + i
-                if index >= clip.num_frames:
-                    img.append(frame_to_tensor(clip.get_frame(clip.num_frames - 1), device=device).unsqueeze(0))
+                img = []
+                for i in range(length):
+                    index = n + i
+                    if index >= clip.num_frames:
+                        img.append(frame_to_tensor(clip.get_frame(clip.num_frames - 1), device=device).unsqueeze(0))
 
-                else:
-                    img.append(frame_to_tensor(clip.get_frame(n + i), device=device).unsqueeze(0))
+                    else:
+                        img.append(frame_to_tensor(clip.get_frame(n + i), device=device).unsqueeze(0))
 
-            img = torch.stack(img, dim=1)
+                img = torch.stack(img, dim=1)
 
-            output = inference(img)
+                output = inference(img)
 
-            for i in range(output.shape[0]):
-                cache[n + i] = output[0, i, :, :, :]
+                for i in range(output.shape[0]):
+                    cache[n + i] = output[0, i, :, :, :]
+
+            res = tensor_to_frame(cache[n], f[1].copy())
 
-        return tensor_to_frame(cache[n], f[1].copy())
+        return res
 
     new_clip = clip.std.BlankClip(width=clip.width * scale, height=clip.height * scale, keep=True)
     return new_clip.std.ModifyFrame([clip, new_clip], _inference)
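
Both hunks apply the same fix: the shared cache and the model call now run under a single threading.Lock, because VapourSynth's frame server may invoke a ModifyFrame callback from several worker threads at once; without serialization, one thread could clear the cache while another is still filling or reading it. A minimal, self-contained sketch of the pattern (run_model is a hypothetical stand-in for the batched CUDA inference call, not part of this library):

from threading import Lock
from typing import Dict, List

def run_model(n: int) -> List[int]:
    # Hypothetical stand-in for the batched CUDA inference call.
    return [n + i for i in range(5)]

cache: Dict[int, int] = {}
lock = Lock()

def callback(n: int) -> int:
    with lock:                # serialize: only one thread may touch the
        if n not in cache:    # cache or run the model at a time; a thread
            cache.clear()     # that waited here sees what the previous
            for i, out in enumerate(run_model(n)):  # caller cached
                cache[n + i] = out
        return cache[n]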
@@ -117,24 +123,29 @@ def inference_vsr_one_frame_out(
     if length % 2 == 0:
         raise ValueError("The length of the input frames should be odd")
 
+    lock = Lock()
+
     def _inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame:
-        img = []
-        for i in range(length):
-            index = i - length // 2 + n
-            if index < 0:
-                img.append(frame_to_tensor(clip.get_frame(0), device=device).unsqueeze(0))
+        with lock:
+            img = []
+            for i in range(length):
+                index = i - length // 2 + n
+                if index < 0:
+                    img.append(frame_to_tensor(clip.get_frame(0), device=device).unsqueeze(0))
 
-            elif index >= clip.num_frames:
-                img.append(frame_to_tensor(clip.get_frame(clip.num_frames - 1), device=device).unsqueeze(0))
+                elif index >= clip.num_frames:
+                    img.append(frame_to_tensor(clip.get_frame(clip.num_frames - 1), device=device).unsqueeze(0))
 
-            else:
-                img.append(frame_to_tensor(clip.get_frame(index), device=device).unsqueeze(0))
+                else:
+                    img.append(frame_to_tensor(clip.get_frame(index), device=device).unsqueeze(0))
 
-        img = torch.stack(img, dim=1)
+            img = torch.stack(img, dim=1)
 
-        output = inference(img)
+            output = inference(img)
 
-        return tensor_to_frame(output[0, 0, :, :, :], f[1].copy())
+            res = tensor_to_frame(output[0, 0, :, :, :], f[1].copy())
+
+        return res
 
     new_clip = clip.std.BlankClip(width=clip.width * scale, height=clip.height * scale, keep=True)
     return new_clip.std.ModifyFrame([clip, new_clip], _inference)
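
For context, a hypothetical usage sketch. The full function signatures are collapsed in this diff, so the keyword names below are assumptions inferred from the parameters used inside the bodies (clip, inference, scale, length, device), and the toy inference merely mimics the tensor shapes the body implies: it receives a (1, length, C, H, W) window of frames and must return a (1, T, C, scale*H, scale*W) batch.

import torch
import torch.nn.functional as F
import vapoursynth as vs

from ccrestoration.vs.vsr import inference_vsr_one_frame_out

core = vs.core

# A blank RGBS clip keeps the sketch self-contained; any RGBS clip works.
clip = core.std.BlankClip(format=vs.RGBS, width=640, height=360, length=100)

def inference(img: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for a real VSR network: take the centre frame of the
    # (1, length, C, H, W) window and upscale it 2x bilinearly.
    t = img.shape[1]
    out = F.interpolate(img[:, t // 2], scale_factor=2, mode="bilinear")
    return out.unsqueeze(1)  # back to (1, 1, C, 2H, 2W)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sr = inference_vsr_one_frame_out(clip=clip, inference=inference, scale=2, length=5, device=device)
sr.set_output()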
