diff --git a/.github/workflows/CI-test.yml b/.github/workflows/CI-test.yml new file mode 100644 index 0000000..e1d7ab0 --- /dev/null +++ b/.github/workflows/CI-test.yml @@ -0,0 +1,58 @@ +name: CI-test + +env: + GITHUB_ACTIONS: true + +on: + push: + branches: ["main"] + paths-ignore: + - "**.md" + - "LICENSE" + + pull_request: + branches: ["main"] + paths-ignore: + - "**.md" + - "LICENSE" + + workflow_dispatch: + +jobs: + CI: + strategy: + matrix: + os-version: ["ubuntu-20.04", "macos-14", "windows-latest"] + python-version: ["3.9"] + poetry-version: ["1.8.3"] + + runs-on: ${{ matrix.os-version }} + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - uses: abatilo/actions-poetry@v2 + with: + poetry-version: ${{ matrix.poetry-version }} + + - name: Test + run: | + pip install numpy==1.26.4 + pip install pydantic tenacity opencv-python + pip install pre-commit torch torchvision scikit-image + pip install pytest pytest-cov coverage + pip install mypy ruff types-requests + + make lint + make test + + - name: Codecov + if: matrix.os-version == 'ubuntu-20.04' + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/Release-pypi.yml b/.github/workflows/Release-pypi.yml new file mode 100644 index 0000000..1f75436 --- /dev/null +++ b/.github/workflows/Release-pypi.yml @@ -0,0 +1,34 @@ +name: Release-pypi + +on: + workflow_dispatch: + +jobs: + Pypi: + strategy: + matrix: + python-version: ["3.9"] + poetry-version: ["1.8.3"] + + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - uses: abatilo/actions-poetry@v2 + with: + poetry-version: ${{ matrix.poetry-version }} + + - name: Build package + run: | + make build + + - name: Publish a Python distribution to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..24133cb --- /dev/null +++ b/.gitignore @@ -0,0 +1,171 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +*.DS_Store + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ +/.ruff_cache/ + +/ccrestoration/cache_models/*.pth +/ccrestoration/cache_models/*.pt +/ccrestoration/cache_models/*.pkl + +/assets/*_out.jpg + +*.mp4 +*.mkv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..895e7a9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,43 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-json + - id: check-yaml + - id: check-xml + - id: check-toml + + # autofix json, yaml, markdown... 
+ - repo: https://github.com/pre-commit/mirrors-prettier + rev: v3.1.0 + hooks: + - id: prettier + + # autofix toml + - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.12.0 + hooks: + - id: pretty-format-toml + args: [--autofix] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.6 + hooks: + - id: ruff-format + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.7.1 + hooks: + - id: mypy + args: [ccrestoration, tests] + pass_filenames: false + additional_dependencies: + - types-requests + - types-certifi + - pytest + - pydantic + - tenacity diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1b307b1 --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +.DEFAULT_GOAL := default + +.PHONY: test +test: + poetry run pytest --cov=ccrestoration --cov-report=xml --cov-report=html + +.PHONY: lint +lint: + poetry run pre-commit install + poetry run pre-commit run --all-files + +.PHONY: build +build: + poetry build --format wheel + +.PHONY: vs +vs: + rm -f encoded.mkv + vspipe -c y4m vs.py - | ffmpeg -i - -vcodec libx265 -crf 16 encoded.mkv diff --git a/README.md b/README.md new file mode 100644 index 0000000..3a93c8b --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# dev diff --git a/assets/test.jpg b/assets/test.jpg new file mode 100644 index 0000000..d325ab8 Binary files /dev/null and b/assets/test.jpg differ diff --git a/ccrestoration/__init__.py b/ccrestoration/__init__.py new file mode 100644 index 0000000..9206e56 --- /dev/null +++ b/ccrestoration/__init__.py @@ -0,0 +1,5 @@ +from ccrestoration.auto import AutoModel, AutoConfig # noqa +from ccrestoration.core.type import BaseConfig, ConfigType, BaseModelInterface, ArchType, ModelType # noqa +from ccrestoration.core.model import MODEL_REGISTRY # noqa +from ccrestoration.core.arch import ARCH_REGISTRY # noqa +from ccrestoration.core.config import CONFIG_REGISTRY # noqa diff --git a/ccrestoration/auto/__init__.py b/ccrestoration/auto/__init__.py new file mode 100644 index 0000000..e299b03 --- /dev/null +++ b/ccrestoration/auto/__init__.py @@ -0,0 +1,2 @@ +from ccrestoration.auto.config import AutoConfig # noqa +from ccrestoration.auto.model import AutoModel # noqa diff --git a/ccrestoration/auto/config.py b/ccrestoration/auto/config.py new file mode 100644 index 0000000..391be78 --- /dev/null +++ b/ccrestoration/auto/config.py @@ -0,0 +1,29 @@ +from typing import Any, Optional, Union + +from ccrestoration.core.config import CONFIG_REGISTRY +from ccrestoration.core.type import BaseConfig, ConfigType + + +class AutoConfig: + @staticmethod + def from_pretrained(pretrained_model_name: Union[ConfigType, str]) -> Any: + """ + Get a config instance of a pretrained model configuration. + + :param pretrained_model_name: The name of the pretrained model configuration + :return: + """ + return CONFIG_REGISTRY.get(pretrained_model_name) + + @staticmethod + def register(config: Union[BaseConfig, Any], name: Optional[str] = None) -> None: + """ + Register the given config class instance under the name BaseConfig.name or the given name. + Can be used as a function call. See docstring of this class for usage. + + :param config: The config class instance to register + :param name: The name to register the config class instance under. 
If None, use BaseConfig.name + :return: + """ + # used as a function call + CONFIG_REGISTRY.register(obj=config, name=name) diff --git a/ccrestoration/auto/model.py b/ccrestoration/auto/model.py new file mode 100644 index 0000000..e9ebe00 --- /dev/null +++ b/ccrestoration/auto/model.py @@ -0,0 +1,93 @@ +from typing import Any, Optional, Union + +import torch + +from ccrestoration.core.config import CONFIG_REGISTRY +from ccrestoration.core.model import MODEL_REGISTRY +from ccrestoration.core.type import BaseConfig, ConfigType + + +class AutoModel: + @staticmethod + def from_pretrained( + pretrained_model_name: Union[ConfigType, str], + device: Optional[torch.device] = None, + fp16: bool = True, + compile: bool = False, + compile_backend: Optional[str] = None, + ) -> Any: + """ + Get a model instance from a pretrained model name. + + :param pretrained_model_name: The name of the pretrained model. It should be registered in CONFIG_REGISTRY. + :param device: torch.device + :param fp16: Whether to use fp16 precision. + :param compile: Whether to compile the model. + :param compile_backend: The backend to use for compiling the model. + :return: + """ + + config = CONFIG_REGISTRY.get(pretrained_model_name) + return AutoModel.from_config( + config=config, + device=device, + fp16=fp16, + compile=compile, + compile_backend=compile_backend, + ) + + @staticmethod + def from_config( + config: Union[BaseConfig, Any], + device: Optional[torch.device] = None, + fp16: bool = True, + compile: bool = False, + compile_backend: Optional[str] = None, + ) -> Any: + """ + Get a model instance from a config. + + :param config: The config object. It should be registered in CONFIG_REGISTRY. + :param device: torch.device + :param fp16: Whether to use fp16 precision. + :param compile: Whether to compile the model. + :param compile_backend: The backend to use for compiling the model. + :return: + """ + + model = MODEL_REGISTRY.get(config.model) + model = model( + config=config, + device=device, + fp16=fp16, + compile=compile, + compile_backend=compile_backend, + ) + + return model + + @staticmethod + def register(obj: Optional[Any] = None, name: Optional[str] = None) -> Any: + """ + Register the given object under the name `obj.__name__` or the given name. + Can be used as either a decorator or not. See docstring of this class for usage. + + :param obj: The object to register. If None, this is being used as a decorator. + :param name: The name to register the object under. If None, use `obj.__name__`. 
+ :return: + """ + if obj is None: + # used as a decorator + def deco(func_or_class: Any) -> Any: + _name = name + if _name is None: + _name = func_or_class.__name__ + MODEL_REGISTRY.register(obj=func_or_class, name=_name) + return func_or_class + + return deco + + # used as a function call + if name is None: + name = obj.__name__ + MODEL_REGISTRY.register(obj=obj, name=name) diff --git a/ccrestoration/cache_models/README.md b/ccrestoration/cache_models/README.md new file mode 100644 index 0000000..b855590 --- /dev/null +++ b/ccrestoration/cache_models/README.md @@ -0,0 +1 @@ +## Download models to cache diff --git a/ccrestoration/cache_models/__init__.py b/ccrestoration/cache_models/__init__.py new file mode 100644 index 0000000..6cde809 --- /dev/null +++ b/ccrestoration/cache_models/__init__.py @@ -0,0 +1,76 @@ +import hashlib +import os +from pathlib import Path +from typing import Optional + +from tenacity import retry, stop_after_attempt, stop_after_delay, wait_random +from torch.hub import download_url_to_file + +from ccrestoration.core.type import BaseConfig + +CACHE_PATH = Path(__file__).resolve().parent.absolute() + + +def get_file_sha256(file_path: str, blocksize: int = 1 << 20) -> str: + sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + while True: + data = f.read(blocksize) + if not data: + break + sha256.update(data) + return sha256.hexdigest() + + +def load_file_from_url( + config: BaseConfig, force_download: bool = False, progress: bool = True, model_dir: Optional[str] = None +) -> str: + """ + Load a file from an HTTP URL, downloading the model if necessary. + + Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + :param config: The config object. + :param force_download: Whether to force download the file. + :param progress: Whether to show the download progress. + :param model_dir: The path to save the downloaded model. Should be a full path. If None, use the default cache path.
+ :return: + """ + + if model_dir is None: + model_dir = str(CACHE_PATH) + + cached_file_path = os.path.abspath(os.path.join(model_dir, config.name)) + + if not os.path.exists(cached_file_path) or force_download: + print(f"Downloading: {config.url} to {cached_file_path}\n") + + @retry(wait=wait_random(min=3, max=5), stop=stop_after_delay(10) | stop_after_attempt(30)) + def _download() -> None: + try: + download_url_to_file(str(config.url), cached_file_path, hash_prefix=None, progress=progress) + except Exception as e: + print(f"Download failed: {e}, retrying...") + raise e + + _download() + + if config.hash is not None: + get_hash = get_file_sha256(cached_file_path) + if get_hash != config.hash: + raise ValueError( + f"Hash of {cached_file_path} is {get_hash}, which does not match the expected hash {config.hash} in the config" + ) + + return cached_file_path + + +if __name__ == "__main__": + # get all model files sha256 + for root, _, files in os.walk(CACHE_PATH): + for file in files: + if not file.endswith(".pth") and not file.endswith(".pt"): + continue + file_path = os.path.join(root, file) + name = os.path.basename(file_path) + print(f"{name}: {get_file_sha256(file_path)}") diff --git a/ccrestoration/core/__init__.py b/ccrestoration/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ccrestoration/core/arch/__init__.py b/ccrestoration/core/arch/__init__.py new file mode 100644 index 0000000..c536e70 --- /dev/null +++ b/ccrestoration/core/arch/__init__.py @@ -0,0 +1,6 @@ +from ccrestoration.utils.registry import Registry + +ARCH_REGISTRY: Registry = Registry("ARCH") + +from ccrestoration.core.arch.rrdb_arch import RRDBNet # noqa +from ccrestoration.core.arch.srvgg_arch import SRVGGNetCompact # noqa diff --git a/ccrestoration/core/arch/arch_util.py b/ccrestoration/core/arch/arch_util.py new file mode 100644 index 0000000..d96d93a --- /dev/null +++ b/ccrestoration/core/arch/arch_util.py @@ -0,0 +1,281 @@ +# type: ignore +import collections.abc +import math +import warnings +from itertools import repeat + +import torch +from torch import nn as nn +from torch.nn import functional as F, init as init +from torch.nn.modules.batchnorm import _BatchNorm + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential.
+ """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f"scale {scale} is not supported. Supported scales: 2^n and 3.") + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros", align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode="bilinear", align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. 
+ interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == "ratio": + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == "shape": + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f"Size type should be ratio or shape, but got type {size_type}.") + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners + ) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/ccrestoration/core/arch/rrdb_arch.py b/ccrestoration/core/arch/rrdb_arch.py new file mode 100644 index 0000000..a92f256 --- /dev/null +++ b/ccrestoration/core/arch/rrdb_arch.py @@ -0,0 +1,122 @@ +# type: ignore +import torch +from torch import nn as nn +from torch.nn import functional as F + +from ccrestoration.core.arch import ARCH_REGISTRY +from ccrestoration.core.type import ArchType + +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +@ARCH_REGISTRY.register(name=ArchType.RRDB) +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + scale (int): Upscale factor. Default: 4. + num_feat (int): Channel number of intermediate features. + Default: 64. + num_block (int): Block number in the trunk network. Default: 23. + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth.
+ """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Empirically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Empirically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x diff --git a/ccrestoration/core/arch/srvgg_arch.py b/ccrestoration/core/arch/srvgg_arch.py new file mode 100644 index 0000000..2c9be5f --- /dev/null +++ b/ccrestoration/core/arch/srvgg_arch.py @@ -0,0 +1,72 @@ +# type: ignore +from torch import nn as nn +from torch.nn import functional as F + +from ccrestoration.core.arch import ARCH_REGISTRY +from ccrestoration.core.type import ArchType + + +@ARCH_REGISTRY.register(name=ArchType.SRVGG) +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + + It is a compact network structure, which performs upsampling in the last layer and no convolution is + conducted on the HR feature space. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_conv (int): Number of convolution layers in the body network. Default: 16. + upscale (int): Upsampling factor. Default: 4. + act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 
+ """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type="prelu"): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == "relu": + activation = nn.ReLU(inplace=True) + elif act_type == "prelu": + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == "leakyrelu": + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == "relu": + activation = nn.ReLU(inplace=True) + elif act_type == "prelu": + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == "leakyrelu": + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode="nearest") + out += base + return out diff --git a/ccrestoration/core/config/__init__.py b/ccrestoration/core/config/__init__.py new file mode 100644 index 0000000..17a2d68 --- /dev/null +++ b/ccrestoration/core/config/__init__.py @@ -0,0 +1,5 @@ +from ccrestoration.utils.registry import RegistryConfigInstance + +CONFIG_REGISTRY: RegistryConfigInstance = RegistryConfigInstance("CONFIG") + +from ccrestoration.core.config.realesrgan_config import RealESRGANConfig # noqa diff --git a/ccrestoration/core/config/realesrgan_config.py b/ccrestoration/core/config/realesrgan_config.py new file mode 100644 index 0000000..f46931f --- /dev/null +++ b/ccrestoration/core/config/realesrgan_config.py @@ -0,0 +1,92 @@ +from pydantic import field_validator + +from ccrestoration.core.config import CONFIG_REGISTRY +from ccrestoration.core.type import ArchType, BaseConfig, ConfigType, ModelType + + +class RealESRGANConfig(BaseConfig): + scale: int = 2 + num_in_ch: int = 3 + num_out_ch: int = 3 + num_feat: int = 64 + num_block: int = 23 + num_grow_ch: int = 32 + num_conv: int = 16 + act_type: str = "prelu" + + @field_validator("arch") + def arch_match(cls, v: str) -> str: + if v not in [ArchType.RRDB, ArchType.SRVGG]: + raise ValueError("real esrgan arch must be one of 'RRDB', 'SRVGG'") + return v + + @field_validator("act_type") + def act_type_match(cls, v: str) -> str: + if v not in ["relu", "prelu", "leakyrelu"]: + raise ValueError("act_type must be one of 'relu', 'prelu', 'leakyrelu'") + return v + + +RealESRGANConfigs = [ + RealESRGANConfig( + name=ConfigType.RealESRGAN_RealESRGAN_x4plus_4x, + url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_RealESRGAN_x4plus_4x.pth", + hash="4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1", + arch=ArchType.RRDB, + model=ModelType.RealESRGAN, + scale=4, + ), + RealESRGANConfig( + name=ConfigType.RealESRGAN_RealESRGAN_x4plus_anime_6B_4x, + 
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_RealESRGAN_x4plus_anime_6B_4x.pth", + hash="f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da", + arch=ArchType.RRDB, + model=ModelType.RealESRGAN, + scale=4, + num_block=6, + ), + RealESRGANConfig( + name=ConfigType.RealESRGAN_RealESRGAN_x2plus_2x, + url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_RealESRGAN_x2plus_2x.pth", + hash="49fafd45f8fd7aa8d31ab2a22d14d91b536c34494a5cfe31eb5d89c2fa266abb", + arch=ArchType.RRDB, + model=ModelType.RealESRGAN, + scale=2, + num_block=23, + ), + RealESRGANConfig( + name=ConfigType.RealESRGAN_realesr_animevideov3_4x, + url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_realesr_animevideov3_4x.pth", + hash="b8a8376811077954d82ca3fcf476f1ac3da3e8a68a4f4d71363008000a18b75d", + arch=ArchType.SRVGG, + model=ModelType.RealESRGAN, + scale=4, + ), + RealESRGANConfig( + name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, + url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_AnimeJaNai_HD_V3_Compact_2x.pth", + hash="af7307eee19e5982a8014dd0e4650d3bde2e25aa78d2105a4bdfd947636e4c8f", + arch=ArchType.SRVGG, + model=ModelType.RealESRGAN, + scale=2, + ), + RealESRGANConfig( + name=ConfigType.RealESRGAN_AniScale_2_Compact_2x, + url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_AniScale_2_Compact_2x.pth", + hash="916ddf99eac77008834a8aeb3dc74b64b17eee02932c18bca93cfa093106e85d", + arch=ArchType.SRVGG, + model=ModelType.RealESRGAN, + scale=2, + ), + RealESRGANConfig( + name=ConfigType.RealESRGAN_Ani4Kv2_Compact_2x, + url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_Ani4Kv2_Compact_2x.pth", + hash="fe99290e9e4f95424219566dbe159184a123587622cc00bc632b1eecbd07d7a4", + arch=ArchType.SRVGG, + model=ModelType.RealESRGAN, + scale=2, + ), +] + +for cfg in RealESRGANConfigs: + CONFIG_REGISTRY.register(cfg) diff --git a/ccrestoration/core/model/__init__.py b/ccrestoration/core/model/__init__.py new file mode 100644 index 0000000..38a2f7a --- /dev/null +++ b/ccrestoration/core/model/__init__.py @@ -0,0 +1,6 @@ +from ccrestoration.utils.registry import Registry + +MODEL_REGISTRY: Registry = Registry("MODEL") + +from ccrestoration.core.model.sr_model import SRBaseModel # noqa +from ccrestoration.core.model.realesrgan_model import RealESRGANModel # noqa diff --git a/ccrestoration/core/model/realesrgan_model.py b/ccrestoration/core/model/realesrgan_model.py new file mode 100644 index 0000000..8b1da08 --- /dev/null +++ b/ccrestoration/core/model/realesrgan_model.py @@ -0,0 +1,44 @@ +from typing import Any + +from ccrestoration.core.arch import RRDBNet, SRVGGNetCompact +from ccrestoration.core.config import RealESRGANConfig +from ccrestoration.core.model import MODEL_REGISTRY +from ccrestoration.core.model.sr_model import SRBaseModel +from ccrestoration.core.type import ArchType, ModelType + + +@MODEL_REGISTRY.register(name=ModelType.RealESRGAN) +class RealESRGANModel(SRBaseModel): + def load_model(self) -> Any: + cfg: RealESRGANConfig = self.config + state_dict = self.get_state_dict() + + if "params_ema" in state_dict: + state_dict = state_dict["params_ema"] + elif "params" in state_dict: + state_dict = state_dict["params"] + + if cfg.arch == ArchType.RRDB: + model = RRDBNet( + num_in_ch=cfg.num_in_ch, + num_out_ch=cfg.num_out_ch, + scale=cfg.scale, + num_feat=cfg.num_feat, + num_block=cfg.num_block, 
+ num_grow_ch=cfg.num_grow_ch, + ) + elif cfg.arch == ArchType.SRVGG: + model = SRVGGNetCompact( + num_in_ch=cfg.num_in_ch, + num_out_ch=cfg.num_out_ch, + upscale=cfg.scale, + num_feat=cfg.num_feat, + num_conv=cfg.num_conv, + act_type=cfg.act_type, + ) + else: + raise NotImplementedError(f"Arch {cfg.arch} is not implemented.") + + model.load_state_dict(state_dict, assign=True) + model.eval().to(self.device) + return model diff --git a/ccrestoration/core/model/sr_model.py b/ccrestoration/core/model/sr_model.py new file mode 100644 index 0000000..c6206d6 --- /dev/null +++ b/ccrestoration/core/model/sr_model.py @@ -0,0 +1,68 @@ +from typing import Any + +import cv2 +import numpy as np +import torch +from torchvision import transforms + +from ccrestoration.cache_models import load_file_from_url +from ccrestoration.core.type import BaseConfig, BaseModelInterface + + +class SRBaseModel(BaseModelInterface): + def get_state_dict(self) -> Any: + """ + Load the state dict of the model from config + + :return: The state dict of the model + """ + cfg: BaseConfig = self.config + + if cfg.path is not None: + model_path = str(cfg.path) + else: + try: + model_path = load_file_from_url(cfg) + except Exception as e: + print(f"Error: {e}, retrying with force download...") + model_path = load_file_from_url(cfg, force_download=True) + + return torch.load(model_path, map_location=self.device, weights_only=True) + + @torch.inference_mode() # type: ignore + def inference(self, img: torch.Tensor) -> torch.Tensor: + return self.model(img) + + @torch.inference_mode() # type: ignore + def inference_image(self, img: np.ndarray) -> np.ndarray: + """ + Run inference on an image (BGR) with the model + + :param img: The input image (BGR); cv2 can be used to read it + :return: + """ + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = transforms.ToTensor()(img).unsqueeze(0).to(self.device) + if self.fp16: + img = img.half() + + img = self.inference(img) + img = img.squeeze(0).permute(1, 2, 0).cpu().numpy() + img = (img * 255).clip(0, 255).astype("uint8") + + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + @torch.inference_mode() # type: ignore + def inference_video(self, clip: Any) -> Any: + """ + Run inference on a video with the model; the clip should be a VapourSynth clip + + :param clip: vs.VideoNode + :return: + """ + + from ccrestoration.vs import inference_sr + + return inference_sr(self.inference, clip, self.config.scale, self.device) diff --git a/ccrestoration/core/type/__init__.py b/ccrestoration/core/type/__init__.py new file mode 100644 index 0000000..787fa70 --- /dev/null +++ b/ccrestoration/core/type/__init__.py @@ -0,0 +1,5 @@ +from ccrestoration.core.type.arch import ArchType # noqa +from ccrestoration.core.type.base_config import BaseConfig # noqa +from ccrestoration.core.type.base_model import BaseModelInterface # noqa +from ccrestoration.core.type.config import ConfigType # noqa +from ccrestoration.core.type.model import ModelType # noqa diff --git a/ccrestoration/core/type/arch.py b/ccrestoration/core/type/arch.py new file mode 100644 index 0000000..eb93aaf --- /dev/null +++ b/ccrestoration/core/type/arch.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class ArchType(str, Enum): + RRDB = "RRDB" + SRVGG = "SRVGG" diff --git a/ccrestoration/core/type/base_config.py b/ccrestoration/core/type/base_config.py new file mode 100644 index 0000000..f426f58 --- /dev/null +++ b/ccrestoration/core/type/base_config.py @@ -0,0 +1,15 @@ +from typing import Optional, Union + +from pydantic import BaseModel,
FileUrl, HttpUrl + +from ccrestoration.core.type.arch import ArchType +from ccrestoration.core.type.model import ModelType + + +class BaseConfig(BaseModel): + name: str + url: Optional[HttpUrl] = None + path: Optional[FileUrl] = None + hash: Optional[str] = None + arch: Union[ArchType, str] + model: Union[ModelType, str] diff --git a/ccrestoration/core/type/base_model.py b/ccrestoration/core/type/base_model.py new file mode 100644 index 0000000..af53bd4 --- /dev/null +++ b/ccrestoration/core/type/base_model.py @@ -0,0 +1,83 @@ +import sys +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch + +from ccrestoration.utils.device import DEFAULT_DEVICE + + +class BaseModelInterface(ABC): + """ + Base model interface + + :param config: config of the model + :param device: inference device + :param fp16: use fp16 or not + :param compile: use torch.compile or not + :param compile_backend: backend of torch.compile + """ + + def __init__( + self, + config: Any, + device: Optional[torch.device] = None, + fp16: bool = True, + compile: bool = False, + compile_backend: Optional[str] = None, + ) -> None: + self.config = config + self.device: Optional[torch.device] = device + self.fp16: bool = fp16 + self.compile: bool = compile + self.compile_backend: Optional[str] = compile_backend + + if device is None: + self.device = DEFAULT_DEVICE + + self.model: torch.nn.Module = self.load_model() + + # fp16 + if self.fp16: + try: + self.model = self.model.half() + except Exception as e: + print(f"Error: {e}, fp16 is not supported on this model.") + self.fp16 = False + self.model = self.load_model() + + # compile + if self.compile: + try: + if self.compile_backend is None: + if sys.platform == "darwin": + self.compile_backend = "aot_eager" + else: + self.compile_backend = "inductor" + self.model = torch.compile(self.model, backend=self.compile_backend) + except Exception as e: + print(f"Error: {e}, compile is not supported on this model.") + + def get_state_dict(self) -> Any: + raise NotImplementedError + + def load_model(self) -> Any: + raise NotImplementedError + + @abstractmethod + @torch.inference_mode() # type: ignore + def inference(self, img: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @torch.inference_mode() # type: ignore + def inference_video(self, clip: Any) -> Any: + """ + Run inference on a video with the model; the clip should be a VapourSynth clip + + :param clip: vs.VideoNode + :return: + """ + raise NotImplementedError + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + return self.inference(img) diff --git a/ccrestoration/core/type/config.py b/ccrestoration/core/type/config.py new file mode 100644 index 0000000..83ec42c --- /dev/null +++ b/ccrestoration/core/type/config.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class ConfigType(str, Enum): + # RealESRGAN + RealESRGAN_RealESRGAN_x4plus_4x = "RealESRGAN_RealESRGAN_x4plus_4x.pth" + RealESRGAN_RealESRGAN_x4plus_anime_6B_4x = "RealESRGAN_RealESRGAN_x4plus_anime_6B_4x.pth" + RealESRGAN_RealESRGAN_x2plus_2x = "RealESRGAN_RealESRGAN_x2plus_2x.pth" + RealESRGAN_realesr_animevideov3_4x = "RealESRGAN_realesr_animevideov3_4x.pth" + + RealESRGAN_AnimeJaNai_HD_V3_Compact_2x = "RealESRGAN_AnimeJaNai_HD_V3_Compact_2x.pth" + RealESRGAN_AniScale_2_Compact_2x = "RealESRGAN_AniScale_2_Compact_2x.pth" + RealESRGAN_Ani4Kv2_Compact_2x = "RealESRGAN_Ani4Kv2_Compact_2x.pth" diff --git a/ccrestoration/core/type/model.py b/ccrestoration/core/type/model.py new file mode 100644 index 0000000..ef0b111 --- /dev/null
+++ b/ccrestoration/core/type/model.py @@ -0,0 +1,5 @@ +from enum import Enum + + +class ModelType(str, Enum): + RealESRGAN = "RealESRGAN" diff --git a/ccrestoration/utils/__init__.py b/ccrestoration/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ccrestoration/utils/device.py b/ccrestoration/utils/device.py new file mode 100644 index 0000000..aa98e38 --- /dev/null +++ b/ccrestoration/utils/device.py @@ -0,0 +1,17 @@ +import sys + +import torch + + +def default_device() -> torch.device: + if sys.platform != "darwin": + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + try: + return torch.device("mps" if torch.backends.mps.is_available() else "cpu") + except Exception as e: + print(f"Error: {e}, MPS is not available, using CPU instead.") + return torch.device("cpu") + + +DEFAULT_DEVICE = default_device() diff --git a/ccrestoration/utils/misc.py b/ccrestoration/utils/misc.py new file mode 100644 index 0000000..cd9c7c4 --- /dev/null +++ b/ccrestoration/utils/misc.py @@ -0,0 +1,57 @@ +import os +import random +from os import path as osp +from typing import Generator, Optional, Tuple, Union + +import numpy as np +import torch + + +def set_random_seed(seed: int) -> None: + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def scandir( + dir_path: str, suffix: Optional[Union[str, Tuple]] = None, recursive: bool = False, full_path: bool = False +) -> Generator[str, None, None]: + """Scan a directory to find files of interest. + + Args: + dir_path: Path of the directory. + suffix: File suffix that we are interested in. Default: None. + recursive: If set to True, recursively scan the directory. Default: False. + full_path: If set to True, include the dir_path. Default: False. + + Returns: + A generator for all files of interest, with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(_dir_path: str, _suffix: Optional[Union[str, Tuple]], _recursive: bool) -> Generator[str, None, None]: + for entry in os.scandir(_dir_path): + if not entry.name.startswith(".") and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if _suffix is None: + yield return_path + elif return_path.endswith(_suffix): + yield return_path + else: + if _recursive: + yield from _scandir(entry.path, _suffix=_suffix, _recursive=_recursive) + else: + continue + + return _scandir(dir_path, _suffix=suffix, _recursive=recursive) diff --git a/ccrestoration/utils/registry.py b/ccrestoration/utils/registry.py new file mode 100644 index 0000000..5a246bc --- /dev/null +++ b/ccrestoration/utils/registry.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +# pyre-strict +# pyre-ignore-all-errors[2,3] +from typing import Any, Dict, Iterable, Iterator, Optional, Tuple + + +class Registry(Iterable[Tuple[str, Any]]): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + ..
code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name: str) -> None: + """ + Args: + name (str): the name of this registry + """ + self._name: str = name + self._obj_map: Dict[str, Any] = {} + + def _do_register(self, name: str, obj: Any) -> None: + if name in self._obj_map: + print("An object named '{}' was already registered in '{}' registry!".format(name, self._name)) + else: + self._obj_map[name] = obj + + def register(self, obj: Any = None, name: Optional[str] = None) -> Any: + """ + Register the given object under the name `obj.__name__` or the given name. + Can be used as either a decorator or not. See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class: Any) -> Any: + _name = name + if _name is None: + _name = func_or_class.__name__ + self._do_register(_name, func_or_class) + return func_or_class + + return deco + + # used as a function call + if name is None: + name = obj.__name__ + self._do_register(name, obj) + + def get(self, name: str) -> Any: + ret = self._obj_map.get(name) + if ret is None: + raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name)) + return ret + + def __contains__(self, name: str) -> bool: + return name in self._obj_map + + def __repr__(self) -> str: + return "Registry of {}\n".format(self._name) + + def __iter__(self) -> Iterator[Tuple[str, Any]]: + return iter(self._obj_map.items()) + + # pyre-fixme[4]: Attribute must be annotated. + __str__ = __repr__ + + +class RegistryConfigInstance(Registry): + def register(self, obj: Any = None, name: Optional[str] = None) -> Any: + """ + Register the given config class instance under the name BaseConfig.name or the given name. + Can be used as a function call. See docstring of this class for usage. 
+ """ + # used as a function call + if name is None: + name = obj.name + self._do_register(name, obj) diff --git a/ccrestoration/vs/__init__.py b/ccrestoration/vs/__init__.py new file mode 100644 index 0000000..5dff30f --- /dev/null +++ b/ccrestoration/vs/__init__.py @@ -0,0 +1 @@ +from ccrestoration.vs.vs import frame_to_tensor, tensor_to_frame, inference_sr # noqa diff --git a/ccrestoration/vs/vs.py b/ccrestoration/vs/vs.py new file mode 100644 index 0000000..9e6fa81 --- /dev/null +++ b/ccrestoration/vs/vs.py @@ -0,0 +1,55 @@ +from typing import Any, Callable, Union + +import numpy as np +import torch +import vapoursynth as vs + + +def frame_to_tensor(frame: vs.VideoFrame, device: torch.device) -> torch.Tensor: + return torch.stack( + [torch.from_numpy(np.asarray(frame[plane])).to(device) for plane in range(frame.format.num_planes)] + ).clamp(0.0, 1.0) + + +def tensor_to_frame(tensor: torch.Tensor, frame: vs.VideoFrame) -> vs.VideoFrame: + array = tensor.squeeze(0).detach().cpu().numpy() + for plane in range(frame.format.num_planes): + np.copyto(np.asarray(frame[plane]), array[plane]) + return frame + + +@torch.inference_mode() # type: ignore +def inference_sr( + inference: Callable[[torch.Tensor], torch.Tensor], + clip: vs.VideoNode, + scale: Union[float, int, Any], + device: torch.device, +) -> vs.VideoNode: + """ + Inference the video with the model, the clip should be a vapoursynth clip + + :param inference: The inference function + :param clip: vs.VideoNode + :param scale: The scale factor + :param device: The device + :return: + """ + + if not isinstance(clip, vs.VideoNode): + raise vs.Error("Only vapoursynth clip is supported") + + if clip.format.id not in [vs.RGBH, vs.RGBS]: + raise vs.Error("Only vs.RGBH and vs.RGBS formats are supported") + + @torch.inference_mode() # type: ignore + def _inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame: + img = frame_to_tensor(f[0], device).unsqueeze(0) + + output = inference(img) + + return tensor_to_frame(output, f[1].copy()) + + new_clip = clip.std.BlankClip(width=clip.width * scale, height=clip.height * scale, keep=True) + return new_clip.std.FrameEval( + lambda n: new_clip.std.ModifyFrame([clip, new_clip], _inference), clip_src=[clip, new_clip] + ) diff --git a/poetry.toml b/poetry.toml new file mode 100644 index 0000000..029a4a5 --- /dev/null +++ b/poetry.toml @@ -0,0 +1,2 @@ +virtualenvs.create = false +virtualenvs.options.system-site-packages = true diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..50e7976 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,89 @@ +[build-system] +build-backend = "poetry-core.masonry.api" +requires = ["poetry-core"] + +[tool.coverage.report] +exclude_also = [ + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "except Exception as e" +] + +[tool.coverage.run] + +[tool.mypy] +disable_error_code = "attr-defined" +disallow_any_generics = false +disallow_subclassing_any = false +ignore_missing_imports = true +plugins = ["pydantic.mypy"] +strict = true +warn_return_any = false + +[tool.poetry] +authors = ["Tohrusky"] +classifiers = [ + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12" +] +description = "A library for image restoration with VapourSynth support." 
+homepage = "https://github.com/TensoRaws/ccrestoration" +license = "MIT" +name = "ccrestoration" +readme = "README.md" +repository = "https://github.com/TensoRaws/ccrestoration" +version = "0.0.2" + +# Requirements +[tool.poetry.dependencies] +opencv-python = "*" +pydantic = "*" +python = "^3.9" +tenacity = "*" + +[tool.poetry.group.dev.dependencies] +numpy = "*" +pre-commit = "^3.7.0" +scikit-image = "*" +torch = "*" +torchvision = "*" + +[tool.poetry.group.test.dependencies] +coverage = "^7.2.0" +pytest = "^8.0" +pytest-cov = "^4.0" + +[tool.poetry.group.typing.dependencies] +mypy = "^1.8.0" +ruff = "^0.3.7" +types-requests = "^2.28.8" + +[tool.ruff] +extend-ignore = ["B018", "B019", "RUF001", "PGH003", "PGH004", "RUF003", "E402"] +extend-select = [ + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "PGH", # pygrep-hooks + "RUF", # ruff + "W", # pycodestyle + "YTT" # flake8-2020 ] +fixable = ["ALL"] +line-length = 120 + +[tool.ruff.format] +indent-style = "space" +line-ending = "auto" +quote-style = "double" +skip-magic-trailing-comma = false + +[tool.ruff.isort] +combine-as-imports = true + +[tool.ruff.mccabe] +max-complexity = 10 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..fe90734 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,14 @@ +# from vapoursynth import core +# +# from ccrestoration import AutoModel +# +# # AutoModel.list_models() +# # AutoModel.register(model_name, model_class_object) +# clip = core.bs.VideoSource(source="s.mkv") +# clip = AutoModel.from_pretrained("REAL-ESRGAN/RealESRGANx4plus/x4").InferenceVideo(clip) # download only once +# # clip = AutoModel.from_pretrained_url(url, key, model, **args).InferenceVideo(clip) +# # clip = AutoModel.from_pretrained_path(path, model, **args).InferenceVideo(clip) +# clip.set_output() +# Design the AutoConfig module to automatically configure different models, e.g. animejanai maps to {arch: rrdb, scale: 4, noise: 0.1} +# config = AutoConfig.from_pretrained("animejanai") +# AutoModel.from_config(config).InferenceVideo(clip) diff --git a/tests/test_auto_class.py b/tests/test_auto_class.py new file mode 100644 index 0000000..78561ec --- /dev/null +++ b/tests/test_auto_class.py @@ -0,0 +1,29 @@ +from typing import Any + +from ccrestoration import AutoConfig, AutoModel +from ccrestoration.core.config import RealESRGANConfig +from ccrestoration.core.model import RealESRGANModel + + +def test_auto_class_register() -> None: + cfg_name = "TESTCONFIG.pth" + model_name = "TESTMODEL" + + cfg = RealESRGANConfig( + name=cfg_name, + url="https://github.com/HolyWu/vs-realesrgan/releases/download/model/RealESRGAN_x4plus_anime_6B.pth", + arch="RRDB", + model=model_name, + scale=4, + num_block=6, + ) + + AutoConfig.register(cfg) + + @AutoModel.register(name=model_name) + class TESTMODEL(RealESRGANModel): + def get_cfg(self) -> Any: + return self.config + + model: TESTMODEL = AutoModel.from_pretrained(cfg_name) + assert model.get_cfg() == cfg diff --git a/tests/test_cache_models.py b/tests/test_cache_models.py new file mode 100644 index 0000000..324d7de --- /dev/null +++ b/tests/test_cache_models.py @@ -0,0 +1,6 @@ +from ccrestoration import CONFIG_REGISTRY, ConfigType +from ccrestoration.cache_models import load_file_from_url + + +def test_cache_models() -> None: + load_file_from_url(CONFIG_REGISTRY.get(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x)) diff --git a/tests/test_realesrgan.py b/tests/test_realesrgan.py new file mode 100644 index 0000000..9cb56a6 --- /dev/null +++ b/tests/test_realesrgan.py @@ -0,0 +1,47 @@ +import cv2 +
+from ccrestoration import AutoConfig, AutoModel, BaseConfig, ConfigType +from ccrestoration.core.model import SRBaseModel + +from .util import ASSETS_PATH, calculate_image_similarity, compare_image_size, get_device, load_image + + +class Test_RealESRGAN: + def test_official(self) -> None: + img1 = load_image() + + for k in [ + ConfigType.RealESRGAN_RealESRGAN_x4plus_4x, + ConfigType.RealESRGAN_RealESRGAN_x4plus_anime_6B_4x, + ConfigType.RealESRGAN_RealESRGAN_x2plus_2x, + ConfigType.RealESRGAN_realesr_animevideov3_4x, + ]: + print(f"Testing {k}") + cfg: BaseConfig = AutoConfig.from_pretrained(k) + model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device()) + print(model.device) + + img2 = model.inference_image(img1) + + assert calculate_image_similarity(img1, img2) + assert compare_image_size(img1, img2, cfg.scale) + + def test_custom(self) -> None: + img1 = load_image() + + for k in [ + ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, + ConfigType.RealESRGAN_AniScale_2_Compact_2x, + ConfigType.RealESRGAN_Ani4Kv2_Compact_2x, + ]: + print(f"Testing {k}") + cfg: BaseConfig = AutoConfig.from_pretrained(k) + model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device()) + print(model.device) + + img2 = model.inference_image(img1) + + cv2.imwrite(str(ASSETS_PATH / f"test_{k}_out.jpg"), img2) + + assert calculate_image_similarity(img1, img2) + assert compare_image_size(img1, img2, cfg.scale) diff --git a/tests/test_sr.py b/tests/test_sr.py new file mode 100644 index 0000000..432522b --- /dev/null +++ b/tests/test_sr.py @@ -0,0 +1,51 @@ +import sys + +import cv2 +import pytest +import torch + +from ccrestoration import AutoConfig, AutoModel, BaseConfig, ConfigType +from ccrestoration.core.model import SRBaseModel + +from .util import ASSETS_PATH, calculate_image_similarity, compare_image_size, get_device, load_image + + +def test_inference() -> None: + tensor1 = torch.rand(1, 3, 256, 256).to(get_device()) + + k = ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x + + model: SRBaseModel = AutoModel.from_pretrained(pretrained_model_name=k, fp16=False, device=get_device()) + + t2 = model(tensor1) + t3 = model.inference(tensor1) + assert t2.equal(t3) + + +def test_sr_fp16() -> None: + img1 = load_image() + k = ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x + + cfg: BaseConfig = AutoConfig.from_pretrained(k) + model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=True, device=get_device()) + + img2 = model.inference_image(img1) + + cv2.imwrite(str(ASSETS_PATH / f"test_fp16_{k}_out.jpg"), img2) + + assert calculate_image_similarity(img1, img2) + assert compare_image_size(img1, img2, cfg.scale) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test torch.compile on Windows") +def test_sr_compile() -> None: + img1 = load_image() + k = ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x + + model: SRBaseModel = AutoModel.from_pretrained( + pretrained_model_name=k, fp16=True, compile=True, device=get_device() + ) + + img2 = model.inference_image(img1) + + assert calculate_image_similarity(img1, img2) diff --git a/tests/test_type.py b/tests/test_type.py new file mode 100644 index 0000000..1ffa113 --- /dev/null +++ b/tests/test_type.py @@ -0,0 +1,8 @@ +import pytest + +from ccrestoration import BaseModelInterface + + +def test_base_class() -> None: + with pytest.raises(TypeError): + BaseModelInterface() # type: ignore diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 0000000..ebbf307 --- /dev/null +++ 
b/tests/util.py @@ -0,0 +1,58 @@ +import math +import os +from pathlib import Path + +import cv2 +import numpy as np +import torch +from skimage.metrics import structural_similarity + +from ccrestoration.utils.device import DEFAULT_DEVICE + +ASSETS_PATH = Path(__file__).resolve().parent.parent.absolute() / "assets" +TEST_IMG_PATH = ASSETS_PATH / "test.jpg" + + +def get_device() -> torch.device: + if os.environ.get("GITHUB_ACTIONS") == "true": + return torch.device("cpu") + return DEFAULT_DEVICE + + +def load_image() -> np.ndarray: + img = cv2.imdecode(np.fromfile(str(TEST_IMG_PATH), dtype=np.uint8), cv2.IMREAD_COLOR) + return img + + +def calculate_image_similarity(image1: np.ndarray, image2: np.ndarray) -> bool: + """ + Calculate image similarity to check that the SR result is correct + + :param image1: original image + :param image2: upscaled image + :return: + """ + # Resize the two images to the same size + height, width = image1.shape[:2] + image2 = cv2.resize(image2, (width, height)) + # Convert the images to grayscale + grayscale_image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY) + grayscale_image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY) + # Calculate the Structural Similarity Index (SSIM) between the two images + (score, diff) = structural_similarity(grayscale_image1, grayscale_image2, full=True) + print("SSIM: {}".format(score)) + return score > 0.9 + + +def compare_image_size(image1: np.ndarray, image2: np.ndarray, scale: int) -> bool: + """ + Compare the original image size with the upscaled image size to check that the target scale is correct + + :param image1: original image + :param image2: upscaled image + :param scale: upscale ratio + :return: + """ + target_size = (math.ceil(image1.shape[0] * scale), math.ceil(image1.shape[1] * scale)) + + return image2.shape[0] == target_size[0] and image2.shape[1] == target_size[1] diff --git a/vs.py b/vs.py new file mode 100644 index 0000000..5795662 --- /dev/null +++ b/vs.py @@ -0,0 +1,16 @@ +import sys + +sys.path.append(".") + +import vapoursynth as vs +from vapoursynth import core + +from ccrestoration import AutoModel, BaseModelInterface, ConfigType + +model: BaseModelInterface = AutoModel.from_pretrained(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x) + +clip = core.bs.VideoSource(source="s.mp4") +clip = core.resize.Bicubic(clip=clip, format=vs.RGBH) +clip = model.inference_video(clip) +clip = core.resize.Bicubic(clip=clip, matrix_s="709", format=vs.YUV420P16) +clip.set_output()
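The vs.py script above covers the VapourSynth video path end to end. For still images, the pieces of this diff compose the same way: AutoConfig resolves a registered configuration, AutoModel builds the matching arch and loads the hash-verified weights, and SRBaseModel.inference_image handles the BGR/RGB and tensor conversions. A minimal sketch of that flow, mirroring tests/test_realesrgan.py (the image paths are placeholders):

```python
import cv2

from ccrestoration import AutoConfig, AutoModel, ConfigType

# Resolve the registered config, fetch and sha256-check the weights on first use,
# then build the matching arch (RRDBNet here) on the default device.
cfg = AutoConfig.from_pretrained(ConfigType.RealESRGAN_RealESRGAN_x4plus_anime_6B_4x)
model = AutoModel.from_config(config=cfg, fp16=False)

img = cv2.imread("input.jpg")  # placeholder path; inference_image expects BGR
out = model.inference_image(img)  # BGR np.ndarray, upscaled by cfg.scale (4x here)
cv2.imwrite("output.jpg", out)  # placeholder path
```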
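At the tensor level, BaseModelInterface.__call__ forwards to inference(), so a loaded model drops straight into an existing torch pipeline. A small sketch, following tests/test_sr.py (CPU is chosen explicitly so the input tensor and the weights share a device; shapes are illustrative):

```python
import torch

from ccrestoration import AutoModel, ConfigType

model = AutoModel.from_pretrained(
    ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x,
    fp16=False,
    device=torch.device("cpu"),
)

x = torch.rand(1, 3, 64, 64)  # NCHW, RGB, values in [0, 1]
y = model(x)  # __call__ forwards to inference(), which runs under inference_mode
assert y.shape == (1, 3, 128, 128)  # a 2x model doubles H and W
```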
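Weight caching can also be driven directly, as tests/test_cache_models.py does: load_file_from_url takes a registered config and returns the local path of the verified file. A sketch:

```python
from ccrestoration import CONFIG_REGISTRY, ConfigType
from ccrestoration.cache_models import load_file_from_url

# Downloads into ccrestoration/cache_models/ (CACHE_PATH) unless model_dir is set,
# retries via tenacity, and raises ValueError on a sha256 mismatch with the config.
path = load_file_from_url(CONFIG_REGISTRY.get(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x))
print(path)  # absolute path to the cached .pth file
```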
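Finally, both registries are open to user code, as exercised by tests/test_auto_class.py: register a config instance under a new name, then register a model class under the config's model key. A sketch with hypothetical names and a placeholder URL:

```python
from ccrestoration import AutoConfig, AutoModel
from ccrestoration.core.config import RealESRGANConfig
from ccrestoration.core.model import RealESRGANModel

# Hypothetical name, URL, and model key, for illustration only.
cfg = RealESRGANConfig(
    name="MyCompact_2x.pth",
    url="https://example.com/MyCompact_2x.pth",  # placeholder; hash=None skips verification
    arch="SRVGG",
    model="MyCompact",
    scale=2,
)
AutoConfig.register(cfg)  # stored in CONFIG_REGISTRY under cfg.name

@AutoModel.register(name="MyCompact")  # stored in MODEL_REGISTRY under cfg.model
class MyCompactModel(RealESRGANModel):
    pass

model = AutoModel.from_pretrained("MyCompact_2x.pth")  # would fetch the placeholder URL
```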