Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support SRCNN #25

Merged
merged 2 commits into from
Nov 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ccrestoration/arch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from ccrestoration.arch.msrswvsr_arch import MSRSWVSR # noqa
from ccrestoration.arch.scunet_arch import SCUNet # noqa
from ccrestoration.arch.dat_arch import DAT # noqa
from ccrestoration.arch.srcnn_arch import SRCNN # noqa
41 changes: 41 additions & 0 deletions ccrestoration/arch/srcnn_arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
from torch import nn
from torch.nn import functional as F

from ccrestoration.arch import ARCH_REGISTRY
from ccrestoration.type import ArchType
from ccrestoration.util.color import rgb_to_yuv, yuv_to_rgb


@ARCH_REGISTRY.register(name=ArchType.SRCNN)
class SRCNN(nn.Module):
def __init__(self, num_channels: int = 1, scale: int = 2) -> None:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Add input validation for the scale parameter

Consider adding validation to ensure the scale factor is a positive integer within an expected range to prevent potential runtime errors.

    def __init__(self, num_channels: int = 1, scale: int = 2) -> None:
        if not isinstance(scale, int) or scale < 1:
            raise ValueError("Scale factor must be a positive integer")
        if scale > 8:
            raise ValueError("Scale factor must be 8 or less")
        super(SRCNN, self).__init__()

super(SRCNN, self).__init__()
self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
self.relu = nn.ReLU(inplace=True)
self.num_channels = num_channels
self.scale = scale

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.interpolate(x, scale_factor=self.scale, mode="bilinear")

if self.num_channels == 1:
# RGB -> YUV
x = rgb_to_yuv(x)
y, u, v = x[:, 0:1, ...], x[:, 1:2, ...], x[:, 2:3, ...]

y = self.relu(self.conv1(y))
y = self.relu(self.conv2(y))
y = self.conv3(y)

x = torch.cat([y, u, v], dim=1)
# YUV -> RGB
x = yuv_to_rgb(x)
else:
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.conv3(x)

return x
1 change: 1 addition & 0 deletions ccrestoration/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
from ccrestoration.config.animesr_config import AnimeSRConfig # noqa
from ccrestoration.config.scunet_config import SCUNetConfig # noqa
from ccrestoration.config.dat_config import DATConfig # noqa
from ccrestoration.config.srcnn_config import SRCNNConfig # noqa
36 changes: 36 additions & 0 deletions ccrestoration/config/srcnn_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Union

from ccrestoration.config import CONFIG_REGISTRY
from ccrestoration.type import ArchType, BaseConfig, ConfigType, ModelType


class SRCNNConfig(BaseConfig):
arch: ArchType = ArchType.SRCNN
model: Union[ModelType, str] = ModelType.SRCNN
scale: int = 2
num_channels: int = 1


SRCNNConfigs = [
SRCNNConfig(
name=ConfigType.SRCNN_2x,
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/SRCNN_2x.pth",
hash="e803ec6e0230ae12b1fa7fd1c67bd57d2e744b4f4fbbc861bf9790070fc4d19e",
scale=2,
),
SRCNNConfig(
name=ConfigType.SRCNN_3x,
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/SRCNN_3x.pth",
hash="364ec936313d0fd1052c641b20cefd8153a2c1d89712f357f804f0119ab7ab90",
scale=3,
),
SRCNNConfig(
name=ConfigType.SRCNN_4x,
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/SRCNN_4x.pth",
hash="f07978e521ede367d55ef7ca83f4f4979e2339c594bead101cfbb9611023f70e",
scale=4,
),
]

for cfg in SRCNNConfigs:
CONFIG_REGISTRY.register(cfg)
1 change: 1 addition & 0 deletions ccrestoration/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from ccrestoration.model.animesr_model import AnimeSRModel # noqa
from ccrestoration.model.scunet_model import SCUNetModel # noqa
from ccrestoration.model.dat_model import DATModel # noqa
from ccrestoration.model.srcnn_model import SRCNNModel # noqa
8 changes: 4 additions & 4 deletions ccrestoration/model/sr_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ def get_state_dict(self) -> Any:
cfg: BaseConfig = self.config

if cfg.path is not None:
model_path = str(cfg.path)
state_dict_path = str(cfg.path)
else:
try:
model_path = load_file_from_url(
state_dict_path = load_file_from_url(
config=cfg, force_download=False, model_dir=self.model_dir, gh_proxy=self.gh_proxy
)
except Exception as e:
print(f"Error: {e}, try force download the model...")
model_path = load_file_from_url(
state_dict_path = load_file_from_url(
config=cfg, force_download=True, model_dir=self.model_dir, gh_proxy=self.gh_proxy
)

return torch.load(model_path, map_location=self.device, weights_only=True)
return torch.load(state_dict_path, map_location=self.device, weights_only=True)

@torch.inference_mode() # type: ignore
def inference(self, img: torch.Tensor) -> torch.Tensor:
Expand Down
28 changes: 28 additions & 0 deletions ccrestoration/model/srcnn_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any

from ccrestoration.arch import SRCNN
from ccrestoration.config import SRCNNConfig
from ccrestoration.model import MODEL_REGISTRY
from ccrestoration.model.sr_base_model import SRBaseModel
from ccrestoration.type import ModelType


@MODEL_REGISTRY.register(name=ModelType.SRCNN)
class SRCNNModel(SRBaseModel):
def load_model(self) -> Any:
cfg: SRCNNConfig = self.config
state_dict = self.get_state_dict()

if "params_ema" in state_dict:
state_dict = state_dict["params_ema"]
elif "params" in state_dict:
state_dict = state_dict["params"]

model = SRCNN(
num_channels=cfg.num_channels,
scale=cfg.scale,
)

model.load_state_dict(state_dict)
model.eval().to(self.device)
return model
1 change: 1 addition & 0 deletions ccrestoration/type/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class ArchType(str, Enum):
SWINIR = "SWINIR"
SCUNET = "SCUNET"
DAT = "DAT"
SRCNN = "SRCNN"

# ------------------------------------- Auxiliary Network ----------------------------------------------------------

Expand Down
5 changes: 5 additions & 0 deletions ccrestoration/type/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class ConfigType(str, Enum):

DAT_APISR_GAN_generator_4x = "DAT_APISR_GAN_generator_4x.pth"

# SRCNN
SRCNN_2x = "SRCNN_2x.pth"
SRCNN_3x = "SRCNN_3x.pth"
SRCNN_4x = "SRCNN_4x.pth"

# ------------------------------------- Auxiliary Network ----------------------------------------------------------

# SpyNet
Expand Down
1 change: 1 addition & 0 deletions ccrestoration/type/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class ModelType(str, Enum):
SwinIR = "SwinIR"
SCUNet = "SCUNet"
DAT = "DAT"
SRCNN = "SRCNN"

# ------------------------------------- Auxiliary Network ----------------------------------------------------------

Expand Down
83 changes: 83 additions & 0 deletions ccrestoration/util/color.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
from torch import Tensor


def rgb_to_yuv(image: Tensor) -> Tensor:
r"""Convert an RGB image to YUV.

.. image:: _static/img/rgb_to_yuv.png

The image data is assumed to be in the range of :math:`(0, 1)`. The range of the output is of
:math:`(0, 1)` to luma and the ranges of U and V are :math:`(-0.436, 0.436)` and :math:`(-0.615, 0.615)`,
respectively.

The YUV model adopted here follows M/PAL values (see
`BT.470-5 <https://www.itu.int/dms_pubrec/itu-r/rec/bt/R-REC-BT.470-5-199802-S!!PDF-E.pdf>`_, Table 2,
items 2.5 and 2.6).

Args:
image: RGB Image to be converted to YUV with shape :math:`(*, 3, H, W)`.

Returns:
YUV version of the image with shape :math:`(*, 3, H, W)`.

Example:
>>> input = torch.rand(2, 3, 4, 5)
>>> output = rgb_to_yuv(input) # 2x3x4x5
"""
if not isinstance(image, Tensor):
raise TypeError(f"Input type is not a Tensor. Got {type(image)}")

if len(image.shape) < 3 or image.shape[-3] != 3:
raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

r: Tensor = image[..., 0, :, :]
g: Tensor = image[..., 1, :, :]
b: Tensor = image[..., 2, :, :]

y: Tensor = 0.299 * r + 0.587 * g + 0.114 * b
u: Tensor = -0.147 * r - 0.289 * g + 0.436 * b
v: Tensor = 0.615 * r - 0.515 * g - 0.100 * b

out: Tensor = torch.stack([y, u, v], -3)

return out


def yuv_to_rgb(image: Tensor) -> Tensor:
r"""Convert an YUV image to RGB.

The image data is assumed to be in the range of :math:`(0, 1)` for luma (Y). The ranges of U and V are
:math:`(-0.436, 0.436)` and :math:`(-0.615, 0.615)`, respectively.

YUV formula follows M/PAL values (see
`BT.470-5 <https://www.itu.int/dms_pubrec/itu-r/rec/bt/R-REC-BT.470-5-199802-S!!PDF-E.pdf>`_, Table 2,
items 2.5 and 2.6).

Args:
image: YUV Image to be converted to RGB with shape :math:`(*, 3, H, W)`.

Returns:
RGB version of the image with shape :math:`(*, 3, H, W)`.

Example:
>>> input = torch.rand(2, 3, 4, 5)
>>> output = yuv_to_rgb(input) # 2x3x4x5
"""
if not isinstance(image, Tensor):
raise TypeError(f"Input type is not a Tensor. Got {type(image)}")

if image.dim() < 3 or image.shape[-3] != 3:
raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

y: Tensor = image[..., 0, :, :]
u: Tensor = image[..., 1, :, :]
v: Tensor = image[..., 2, :, :]

r: Tensor = y + 1.14 * v # coefficient for g is 0
g: Tensor = y + -0.396 * u - 0.581 * v
b: Tensor = y + 2.029 * u # coefficient for b is 0

out: Tensor = torch.stack([r, g, b], -3)

return out
27 changes: 27 additions & 0 deletions tests/test_srcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import cv2

from ccrestoration import AutoConfig, AutoModel, BaseConfig, ConfigType
from ccrestoration.model import SRBaseModel

from .util import ASSETS_PATH, calculate_image_similarity, compare_image_size, get_device, load_image


class Test_SRCNN:
def test_official(self) -> None:
img1 = load_image()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Test should include multiple test images with different characteristics

Consider testing with multiple images that have different characteristics (e.g., high frequency details, smooth areas, different color patterns) to ensure the SRCNN model handles various scenarios correctly.

        test_images = [
            load_image("natural.png"),
            load_image("text.png"),
            load_image("pattern.png"),
            load_image("gradient.png")
        ]


for k in [
ConfigType.SRCNN_2x,
ConfigType.SRCNN_3x,
ConfigType.SRCNN_4x,
]:
print(f"Testing {k}")
cfg: BaseConfig = AutoConfig.from_pretrained(k)
Comment on lines +10 to +19
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Test case needs more assertions to verify SRCNN functionality

The current test only checks image similarity and size. Consider adding assertions to verify:

  1. The YUV color space conversion is working correctly
  2. The model's behavior with edge cases (e.g., very small or large images)
  3. The expected output values for a known input image
  4. Memory usage is within expected bounds for large images
def test_official(self) -> None:
    img1 = load_image()
    small_img = cv2.resize(img1, (32, 32))
    large_img = cv2.resize(img1, (1024, 1024))

    for k in [ConfigType.SRCNN_2x, ConfigType.SRCNN_3x, ConfigType.SRCNN_4x]:
        cfg: BaseConfig = AutoConfig.from_pretrained(k)
        model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())

        for test_img in [img1, small_img, large_img]:
            img2 = model.inference_image(test_img)
            yuv_img = cv2.cvtColor(img2, cv2.COLOR_BGR2YUV)

            assert calculate_image_similarity(test_img, img2)
            assert compare_image_size(test_img, img2, cfg.scale)
            assert yuv_img.shape[2] == 3
            assert torch.cuda.max_memory_allocated() < 1024 * 1024 * 1024  # 1GB limit

model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Missing test for fp16 mode

Add test cases to verify the model works correctly with fp16=True, as this is an important configuration option that could affect model behavior and performance.

            model_fp16: SRBaseModel = AutoModel.from_config(config=cfg, fp16=True, device=get_device())
            model_fp32: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
            print(model_fp16.device)
            print(model_fp32.device)

print(model.device)

img2 = model.inference_image(img1)
cv2.imwrite(str(ASSETS_PATH / f"test_{k}_out.jpg"), img2)

assert calculate_image_similarity(img1, img2)
assert compare_image_size(img1, img2, cfg.scale)