diff --git a/.gitignore b/.gitignore index 8d18d04..4046166 100644 --- a/.gitignore +++ b/.gitignore @@ -161,9 +161,9 @@ cython_debug/ .idea/ /.ruff_cache/ -/ccrestoration/cache_models/*.pth -/ccrestoration/cache_models/*.pt -/ccrestoration/cache_models/*.pkl +*.pth +*.pt +*.pkl *_out.jpg *_out.png diff --git a/README.md b/README.md index ee5c48b..6794814 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ pip install ccrestoration #### cv2 -A simple example to use the sisr model (APISR) to process an image +A simple example to use the SISR (Single Image Super-Resolution) model to process an image (APISR) ```python import cv2 @@ -39,7 +39,7 @@ cv2.imwrite("test_out.jpg", img) #### VapourSynth -A simple example to use the vsr model (AnimeSR) to process a video +A simple example to use the VSR (Video Super-Resolution) model to process a video (AnimeSR) ```python import vapoursynth as vs @@ -70,6 +70,12 @@ It still in development, the following models are supported: - [Weight(Config)](./ccrestoration/type/config.py) +### Notice + +- All the architectures have been edited to normalize input and output, and automatic padding is applied. The input and output tensor shapes may differ from the original architectures. For SR models, the input and output are both 4D tensors in the shape of `(b, c, h, w)`. For VSR models, the input and output are both 5D tensors in the shape of `(b, l, c, h, w)`. + +- For VSR models with the same number of frames `l` in input and output `(f1, f2, f3, f4 -> f1', f2', f3', f4')`, you can directly extend from `class VSRBaseModel`. For VSR models that output only one frame `(f-2, f-1, f0, f1, f2 -> f0')`, you also need to set `self.one_frame_out = True`. 
+ ### Reference - [PyTorch](https://github.com/pytorch/pytorch) diff --git a/ccrestoration/auto/model.py b/ccrestoration/auto/model.py index 27f3506..7bb443e 100644 --- a/ccrestoration/auto/model.py +++ b/ccrestoration/auto/model.py @@ -18,6 +18,8 @@ def from_pretrained( tile: Optional[Tuple[int, int]] = (128, 128), tile_pad: int = 8, pad_img: Optional[Tuple[int, int]] = None, + model_dir: Optional[str] = None, + gh_proxy: Optional[str] = None, ) -> Any: """ Get a model instance from a pretrained model name. @@ -30,6 +32,8 @@ def from_pretrained( :param tile: tile size for tile inference, tile[0] is width, tile[1] is height, None for disable :param tile_pad: The padding size for each tile :param pad_img: The size for the padded image, pad[0] is width, pad[1] is height, None for auto calculate + :param model_dir: The path to cache the downloaded model. Should be a full path. If None, use default cache path. + :param gh_proxy: The proxy for downloading from github release. Example: https://github.abskoop.workers.dev/ :return: """ @@ -43,6 +47,8 @@ def from_pretrained( tile=tile, tile_pad=tile_pad, pad_img=pad_img, + model_dir=model_dir, + gh_proxy=gh_proxy, ) @staticmethod @@ -55,6 +61,8 @@ def from_config( tile: Optional[Tuple[int, int]] = (128, 128), tile_pad: int = 8, pad_img: Optional[Tuple[int, int]] = None, + model_dir: Optional[str] = None, + gh_proxy: Optional[str] = None, ) -> Any: """ Get a model instance from a config. @@ -67,6 +75,8 @@ def from_config( :param tile: tile size for tile inference, tile[0] is width, tile[1] is height, None for disable :param tile_pad: The padding size for each tile :param pad_img: The size for the padded image, pad[0] is width, pad[1] is height, None for auto calculate + :param model_dir: The path to cache the downloaded model. Should be a full path. If None, use default cache path. + :param gh_proxy: The proxy for downloading from github release. 
Example: https://github.abskoop.workers.dev/ :return: """ @@ -80,6 +90,8 @@ def from_config( tile=tile, tile_pad=tile_pad, pad_img=pad_img, + model_dir=model_dir, + gh_proxy=gh_proxy, ) return model diff --git a/ccrestoration/cache_models/__init__.py b/ccrestoration/cache_models/__init__.py index 114d2b2..61a9613 100644 --- a/ccrestoration/cache_models/__init__.py +++ b/ccrestoration/cache_models/__init__.py @@ -23,7 +23,11 @@ def get_file_sha256(file_path: str, blocksize: int = 1 << 20) -> str: def load_file_from_url( - config: BaseConfig, force_download: bool = False, progress: bool = True, model_dir: Optional[str] = None + config: BaseConfig, + force_download: bool = False, + progress: bool = True, + model_dir: Optional[str] = None, + gh_proxy: Optional[str] = None, ) -> str: """ Load file form http url, will download models if necessary. @@ -34,6 +38,7 @@ def load_file_from_url( :param force_download: Whether to force download the file. :param progress: Whether to show the download progress. :param model_dir: The path to save the downloaded model. Should be a full path. If None, use default cache path. + :param gh_proxy: The proxy for downloading from github release. 
Example: https://github.abskoop.workers.dev/ :return: """ @@ -42,13 +47,22 @@ def load_file_from_url( cached_file_path = os.path.abspath(os.path.join(model_dir, config.name)) + _url: str = str(config.url) + _gh_proxy = gh_proxy + if _gh_proxy is not None and _url.startswith("https://github.com"): + if not _gh_proxy.endswith("/"): + _gh_proxy += "/" + _url = _gh_proxy + _url + if not os.path.exists(cached_file_path) or force_download: - print(f"Downloading: {config.url} to {cached_file_path}\n") + if _gh_proxy is not None: + print(f"Using github proxy: {_gh_proxy}") + print(f"Downloading: {_url} to {cached_file_path}\n") @retry(wait=wait_random(min=3, max=5), stop=stop_after_delay(10) | stop_after_attempt(30)) def _download() -> None: try: - download_url_to_file(str(config.url), cached_file_path, hash_prefix=None, progress=progress) + download_url_to_file(url=_url, dst=cached_file_path, hash_prefix=None, progress=progress) except Exception as e: print(f"Download failed: {e}, retrying...") raise e diff --git a/ccrestoration/model/basicvsr_model.py b/ccrestoration/model/basicvsr_model.py index c404230..a8c8059 100644 --- a/ccrestoration/model/basicvsr_model.py +++ b/ccrestoration/model/basicvsr_model.py @@ -28,6 +28,8 @@ def load_model(self) -> Any: tile=self.tile, tile_pad=self.tile_pad, pad_img=self.pad_img, + model_dir=self.model_dir, + gh_proxy=self.gh_proxy, ) model = BasicVSR( diff --git a/ccrestoration/model/iconvsr_model.py b/ccrestoration/model/iconvsr_model.py index 9eb6e94..f0fa9eb 100644 --- a/ccrestoration/model/iconvsr_model.py +++ b/ccrestoration/model/iconvsr_model.py @@ -29,6 +29,8 @@ def load_model(self) -> Any: tile=self.tile, tile_pad=self.tile_pad, pad_img=self.pad_img, + model_dir=self.model_dir, + gh_proxy=self.gh_proxy, ) edvr_feature_extractor = EDVRFeatureExtractorModel( @@ -40,6 +42,8 @@ def load_model(self) -> Any: tile=self.tile, tile_pad=self.tile_pad, pad_img=self.pad_img, + model_dir=self.model_dir, + gh_proxy=self.gh_proxy, ) model = 
IconVSR( diff --git a/ccrestoration/model/sr_base_model.py b/ccrestoration/model/sr_base_model.py index c1b157a..97b18b0 100644 --- a/ccrestoration/model/sr_base_model.py +++ b/ccrestoration/model/sr_base_model.py @@ -23,10 +23,14 @@ def get_state_dict(self) -> Any: model_path = str(cfg.path) else: try: - model_path = load_file_from_url(cfg) + model_path = load_file_from_url( + config=cfg, force_download=False, model_dir=self.model_dir, gh_proxy=self.gh_proxy + ) except Exception as e: print(f"Error: {e}, try force download the model...") - model_path = load_file_from_url(cfg, force_download=True) + model_path = load_file_from_url( + config=cfg, force_download=True, model_dir=self.model_dir, gh_proxy=self.gh_proxy + ) return torch.load(model_path, map_location=self.device, weights_only=True) diff --git a/ccrestoration/type/base_model.py b/ccrestoration/type/base_model.py index bfcc40f..7966dc3 100644 --- a/ccrestoration/type/base_model.py +++ b/ccrestoration/type/base_model.py @@ -19,6 +19,8 @@ class BaseModelInterface(ABC): :param tile: tile size for tile inference, tile[0] is width, tile[1] is height, None for disable :param tile_pad: The padding size for each tile :param pad_img: The size for the padded image, pad[0] is width, pad[1] is height, None for auto calculate + :param model_dir: The path to cache the downloaded model. Should be a full path. If None, use default cache path. + :param gh_proxy: The proxy for downloading from github release. 
Example: https://github.abskoop.workers.dev/ """ def __init__( @@ -31,6 +33,8 @@ def __init__( tile: Optional[Tuple[int, int]] = (128, 128), tile_pad: int = 8, pad_img: Optional[Tuple[int, int]] = None, + model_dir: Optional[str] = None, + gh_proxy: Optional[str] = None, ) -> None: # extra config self.one_frame_out: bool = False # for vsr model type @@ -44,6 +48,8 @@ def __init__( self.tile: Optional[Tuple[int, int]] = tile self.tile_pad: int = tile_pad self.pad_img: Optional[Tuple[int, int]] = pad_img + self.model_dir: Optional[str] = model_dir + self.gh_proxy: Optional[str] = gh_proxy if device is None: self.device = DEFAULT_DEVICE diff --git a/example/sisr.py b/example/sisr.py index 677b472..58636aa 100644 --- a/example/sisr.py +++ b/example/sisr.py @@ -4,7 +4,7 @@ from ccrestoration import ArchType, AutoConfig, AutoModel, BaseConfig, ConfigType, SRBaseModel from ccrestoration.config import RealESRGANConfig -example = 3 +example = 4 if example == 0: # fast load a pre-trained model @@ -28,6 +28,14 @@ num_block=6, ) model: SRBaseModel = AutoModel.from_config(config=config) +elif example == 4: + # use custom model dir and gh proxy + model: SRBaseModel = AutoModel.from_pretrained( + pretrained_model_name=ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_2x, + model_dir="./", + gh_proxy="https://github.abskoop.workers.dev/", + ) + else: raise ValueError("example not found") diff --git a/pyproject.toml b/pyproject.toml index 5900575..e5fd64f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ license = "MIT" name = "ccrestoration" readme = "README.md" repository = "https://github.com/TensoRaws/ccrestoration" -version = "0.0.11" +version = "0.0.12" # Requirements [tool.poetry.dependencies] diff --git a/tests/test_cache_models.py b/tests/test_cache_models.py index 324d7de..fdef670 100644 --- a/tests/test_cache_models.py +++ b/tests/test_cache_models.py @@ -4,3 +4,16 @@ def test_cache_models() -> None: 
load_file_from_url(CONFIG_REGISTRY.get(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x)) + + +def test_cache_models_with_gh_proxy() -> None: + load_file_from_url( + config=CONFIG_REGISTRY.get(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x), + force_download=True, + gh_proxy="https://github.abskoop.workers.dev/", + ) + load_file_from_url( + config=CONFIG_REGISTRY.get(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x), + force_download=True, + gh_proxy="https://github.abskoop.workers.dev", + )