Skip to content

Commit

Permalink
feat: support custom cache folder && use gh-proxy to download (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tohrusky authored Oct 14, 2024
1 parent 86be9ec commit 2d2bdeb
Show file tree
Hide file tree
Showing 11 changed files with 81 additions and 12 deletions.
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ cython_debug/
.idea/
/.ruff_cache/

/ccrestoration/cache_models/*.pth
/ccrestoration/cache_models/*.pt
/ccrestoration/cache_models/*.pkl
*.pth
*.pt
*.pkl

*_out.jpg
*_out.png
Expand Down
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pip install ccrestoration

#### cv2

A simple example to use the sisr model (APISR) to process an image
a simple example to use the SISR (Single Image Super-Resolution) model to process an image (APISR)

```python
import cv2
Expand All @@ -39,7 +39,7 @@ cv2.imwrite("test_out.jpg", img)

#### VapourSynth

A simple example to use the vsr model (AnimeSR) to process a video
a simple example to use the VSR (Video Super-Resolution) model to process a video (AnimeSR)

```python
import vapoursynth as vs
Expand Down Expand Up @@ -70,6 +70,12 @@ It still in development, the following models are supported:

- [Weight(Config)](./ccrestoration/type/config.py)

### Notice

- All the architectures have been edited to normalize input and output, and automatic padding is applied. The input and output tensor shapes may differ from the original architectures. For SR models, the input and output are both 4D tensors in the shape of `(b, c, h, w)`. For VSR models, the input and output are both 5D tensors in the shape of `(b, l, c, h, w)`.

- For VSR models with equal l in input and output `(f1, f2, f3, f4 -> f1', f2', f3', f4')`, you can directly extend from `class VSRBaseModel`. For VSR models that output only one frame `(f-2, f-1, f0, f1, f2 -> f0')`, you also need to set `self.one_frame_out = True`.

### Reference

- [PyTorch](https://github.com/pytorch/pytorch)
Expand Down
12 changes: 12 additions & 0 deletions ccrestoration/auto/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def from_pretrained(
tile: Optional[Tuple[int, int]] = (128, 128),
tile_pad: int = 8,
pad_img: Optional[Tuple[int, int]] = None,
model_dir: Optional[str] = None,
gh_proxy: Optional[str] = None,
) -> Any:
"""
Get a model instance from a pretrained model name.
Expand All @@ -30,6 +32,8 @@ def from_pretrained(
:param tile: tile size for tile inference, tile[0] is width, tile[1] is height, None for disable
:param tile_pad: The padding size for each tile
:param pad_img: The size for the padded image, pad[0] is width, pad[1] is height, None for auto calculate
:param model_dir: The path to cache the downloaded model. Should be a full path. If None, use default cache path.
:param gh_proxy: The proxy for downloading from github release. Example: https://github.abskoop.workers.dev/
:return:
"""

Expand All @@ -43,6 +47,8 @@ def from_pretrained(
tile=tile,
tile_pad=tile_pad,
pad_img=pad_img,
model_dir=model_dir,
gh_proxy=gh_proxy,
)

@staticmethod
Expand All @@ -55,6 +61,8 @@ def from_config(
tile: Optional[Tuple[int, int]] = (128, 128),
tile_pad: int = 8,
pad_img: Optional[Tuple[int, int]] = None,
model_dir: Optional[str] = None,
gh_proxy: Optional[str] = None,
) -> Any:
"""
Get a model instance from a config.
Expand All @@ -67,6 +75,8 @@ def from_config(
:param tile: tile size for tile inference, tile[0] is width, tile[1] is height, None for disable
:param tile_pad: The padding size for each tile
:param pad_img: The size for the padded image, pad[0] is width, pad[1] is height, None for auto calculate
:param model_dir: The path to cache the downloaded model. Should be a full path. If None, use default cache path.
:param gh_proxy: The proxy for downloading from github release. Example: https://github.abskoop.workers.dev/
:return:
"""

Expand All @@ -80,6 +90,8 @@ def from_config(
tile=tile,
tile_pad=tile_pad,
pad_img=pad_img,
model_dir=model_dir,
gh_proxy=gh_proxy,
)

return model
Expand Down
20 changes: 17 additions & 3 deletions ccrestoration/cache_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ def get_file_sha256(file_path: str, blocksize: int = 1 << 20) -> str:


def load_file_from_url(
config: BaseConfig, force_download: bool = False, progress: bool = True, model_dir: Optional[str] = None
config: BaseConfig,
force_download: bool = False,
progress: bool = True,
model_dir: Optional[str] = None,
gh_proxy: Optional[str] = None,
) -> str:
"""
Load file form http url, will download models if necessary.
Expand All @@ -34,6 +38,7 @@ def load_file_from_url(
:param force_download: Whether to force download the file.
:param progress: Whether to show the download progress.
:param model_dir: The path to save the downloaded model. Should be a full path. If None, use default cache path.
:param gh_proxy: The proxy for downloading from github release. Example: https://github.abskoop.workers.dev/
:return:
"""

Expand All @@ -42,13 +47,22 @@ def load_file_from_url(

cached_file_path = os.path.abspath(os.path.join(model_dir, config.name))

_url: str = str(config.url)
_gh_proxy = gh_proxy
if _gh_proxy is not None and _url.startswith("https://github.com"):
if not _gh_proxy.endswith("/"):
_gh_proxy += "/"
_url = _gh_proxy + _url

if not os.path.exists(cached_file_path) or force_download:
print(f"Downloading: {config.url} to {cached_file_path}\n")
if _gh_proxy is not None:
print(f"Using github proxy: {_gh_proxy}")
print(f"Downloading: {_url} to {cached_file_path}\n")

@retry(wait=wait_random(min=3, max=5), stop=stop_after_delay(10) | stop_after_attempt(30))
def _download() -> None:
try:
download_url_to_file(str(config.url), cached_file_path, hash_prefix=None, progress=progress)
download_url_to_file(url=_url, dst=cached_file_path, hash_prefix=None, progress=progress)
except Exception as e:
print(f"Download failed: {e}, retrying...")
raise e
Expand Down
2 changes: 2 additions & 0 deletions ccrestoration/model/basicvsr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def load_model(self) -> Any:
tile=self.tile,
tile_pad=self.tile_pad,
pad_img=self.pad_img,
model_dir=self.model_dir,
gh_proxy=self.gh_proxy,
)

model = BasicVSR(
Expand Down
4 changes: 4 additions & 0 deletions ccrestoration/model/iconvsr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def load_model(self) -> Any:
tile=self.tile,
tile_pad=self.tile_pad,
pad_img=self.pad_img,
model_dir=self.model_dir,
gh_proxy=self.gh_proxy,
)

edvr_feature_extractor = EDVRFeatureExtractorModel(
Expand All @@ -40,6 +42,8 @@ def load_model(self) -> Any:
tile=self.tile,
tile_pad=self.tile_pad,
pad_img=self.pad_img,
model_dir=self.model_dir,
gh_proxy=self.gh_proxy,
)

model = IconVSR(
Expand Down
8 changes: 6 additions & 2 deletions ccrestoration/model/sr_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ def get_state_dict(self) -> Any:
model_path = str(cfg.path)
else:
try:
model_path = load_file_from_url(cfg)
model_path = load_file_from_url(
config=cfg, force_download=False, model_dir=self.model_dir, gh_proxy=self.gh_proxy
)
except Exception as e:
print(f"Error: {e}, try force download the model...")
model_path = load_file_from_url(cfg, force_download=True)
model_path = load_file_from_url(
config=cfg, force_download=True, model_dir=self.model_dir, gh_proxy=self.gh_proxy
)

return torch.load(model_path, map_location=self.device, weights_only=True)

Expand Down
6 changes: 6 additions & 0 deletions ccrestoration/type/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class BaseModelInterface(ABC):
:param tile: tile size for tile inference, tile[0] is width, tile[1] is height, None for disable
:param tile_pad: The padding size for each tile
:param pad_img: The size for the padded image, pad[0] is width, pad[1] is height, None for auto calculate
:param model_dir: The path to cache the downloaded model. Should be a full path. If None, use default cache path.
:param gh_proxy: The proxy for downloading from github release. Example: https://github.abskoop.workers.dev/
"""

def __init__(
Expand All @@ -31,6 +33,8 @@ def __init__(
tile: Optional[Tuple[int, int]] = (128, 128),
tile_pad: int = 8,
pad_img: Optional[Tuple[int, int]] = None,
model_dir: Optional[str] = None,
gh_proxy: Optional[str] = None,
) -> None:
# extra config
self.one_frame_out: bool = False # for vsr model type
Expand All @@ -44,6 +48,8 @@ def __init__(
self.tile: Optional[Tuple[int, int]] = tile
self.tile_pad: int = tile_pad
self.pad_img: Optional[Tuple[int, int]] = pad_img
self.model_dir: Optional[str] = model_dir
self.gh_proxy: Optional[str] = gh_proxy

if device is None:
self.device = DEFAULT_DEVICE
Expand Down
10 changes: 9 additions & 1 deletion example/sisr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ccrestoration import ArchType, AutoConfig, AutoModel, BaseConfig, ConfigType, SRBaseModel
from ccrestoration.config import RealESRGANConfig

example = 3
example = 4

if example == 0:
# fast load a pre-trained model
Expand All @@ -28,6 +28,14 @@
num_block=6,
)
model: SRBaseModel = AutoModel.from_config(config=config)
elif example == 4:
# use custom model dir and gh proxy
model: SRBaseModel = AutoModel.from_pretrained(
pretrained_model_name=ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_2x,
model_dir="./",
gh_proxy="https://github.abskoop.workers.dev/",
)


else:
raise ValueError("example not found")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ license = "MIT"
name = "ccrestoration"
readme = "README.md"
repository = "https://github.com/TensoRaws/ccrestoration"
version = "0.0.11"
version = "0.0.12"

# Requirements
[tool.poetry.dependencies]
Expand Down
13 changes: 13 additions & 0 deletions tests/test_cache_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,16 @@

def test_cache_models() -> None:
load_file_from_url(CONFIG_REGISTRY.get(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x))


def test_cache_models_with_gh_proxy() -> None:
load_file_from_url(
config=CONFIG_REGISTRY.get(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x),
force_download=True,
gh_proxy="https://github.abskoop.workers.dev/",
)
load_file_from_url(
config=CONFIG_REGISTRY.get(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x),
force_download=True,
gh_proxy="https://github.abskoop.workers.dev",
)

0 comments on commit 2d2bdeb

Please sign in to comment.