-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support SRCNN #25
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import torch | ||
from torch import nn | ||
from torch.nn import functional as F | ||
|
||
from ccrestoration.arch import ARCH_REGISTRY | ||
from ccrestoration.type import ArchType | ||
from ccrestoration.util.color import rgb_to_yuv, yuv_to_rgb | ||
|
||
|
||
@ARCH_REGISTRY.register(name=ArchType.SRCNN) | ||
class SRCNN(nn.Module): | ||
def __init__(self, num_channels: int = 1, scale: int = 2) -> None: | ||
super(SRCNN, self).__init__() | ||
self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2) | ||
self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2) | ||
self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2) | ||
self.relu = nn.ReLU(inplace=True) | ||
self.num_channels = num_channels | ||
self.scale = scale | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = F.interpolate(x, scale_factor=self.scale, mode="bilinear") | ||
|
||
if self.num_channels == 1: | ||
# RGB -> YUV | ||
x = rgb_to_yuv(x) | ||
y, u, v = x[:, 0:1, ...], x[:, 1:2, ...], x[:, 2:3, ...] | ||
|
||
y = self.relu(self.conv1(y)) | ||
y = self.relu(self.conv2(y)) | ||
y = self.conv3(y) | ||
|
||
x = torch.cat([y, u, v], dim=1) | ||
# YUV -> RGB | ||
x = yuv_to_rgb(x) | ||
else: | ||
x = self.relu(self.conv1(x)) | ||
x = self.relu(self.conv2(x)) | ||
x = self.conv3(x) | ||
|
||
return x |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Union | ||
|
||
from ccrestoration.config import CONFIG_REGISTRY | ||
from ccrestoration.type import ArchType, BaseConfig, ConfigType, ModelType | ||
|
||
|
||
class SRCNNConfig(BaseConfig): | ||
arch: ArchType = ArchType.SRCNN | ||
model: Union[ModelType, str] = ModelType.SRCNN | ||
scale: int = 2 | ||
num_channels: int = 1 | ||
|
||
|
||
SRCNNConfigs = [ | ||
SRCNNConfig( | ||
name=ConfigType.SRCNN_2x, | ||
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/SRCNN_2x.pth", | ||
hash="e803ec6e0230ae12b1fa7fd1c67bd57d2e744b4f4fbbc861bf9790070fc4d19e", | ||
scale=2, | ||
), | ||
SRCNNConfig( | ||
name=ConfigType.SRCNN_3x, | ||
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/SRCNN_3x.pth", | ||
hash="364ec936313d0fd1052c641b20cefd8153a2c1d89712f357f804f0119ab7ab90", | ||
scale=3, | ||
), | ||
SRCNNConfig( | ||
name=ConfigType.SRCNN_4x, | ||
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/SRCNN_4x.pth", | ||
hash="f07978e521ede367d55ef7ca83f4f4979e2339c594bead101cfbb9611023f70e", | ||
scale=4, | ||
), | ||
] | ||
|
||
for cfg in SRCNNConfigs: | ||
CONFIG_REGISTRY.register(cfg) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Any | ||
|
||
from ccrestoration.arch import SRCNN | ||
from ccrestoration.config import SRCNNConfig | ||
from ccrestoration.model import MODEL_REGISTRY | ||
from ccrestoration.model.sr_base_model import SRBaseModel | ||
from ccrestoration.type import ModelType | ||
|
||
|
||
@MODEL_REGISTRY.register(name=ModelType.SRCNN) | ||
class SRCNNModel(SRBaseModel): | ||
def load_model(self) -> Any: | ||
cfg: SRCNNConfig = self.config | ||
state_dict = self.get_state_dict() | ||
|
||
if "params_ema" in state_dict: | ||
state_dict = state_dict["params_ema"] | ||
elif "params" in state_dict: | ||
state_dict = state_dict["params"] | ||
|
||
model = SRCNN( | ||
num_channels=cfg.num_channels, | ||
scale=cfg.scale, | ||
) | ||
|
||
model.load_state_dict(state_dict) | ||
model.eval().to(self.device) | ||
return model |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def rgb_to_yuv(image: Tensor) -> Tensor: | ||
r"""Convert an RGB image to YUV. | ||
|
||
.. image:: _static/img/rgb_to_yuv.png | ||
|
||
The image data is assumed to be in the range of :math:`(0, 1)`. The range of the output is of | ||
:math:`(0, 1)` to luma and the ranges of U and V are :math:`(-0.436, 0.436)` and :math:`(-0.615, 0.615)`, | ||
respectively. | ||
|
||
The YUV model adopted here follows M/PAL values (see | ||
`BT.470-5 <https://www.itu.int/dms_pubrec/itu-r/rec/bt/R-REC-BT.470-5-199802-S!!PDF-E.pdf>`_, Table 2, | ||
items 2.5 and 2.6). | ||
|
||
Args: | ||
image: RGB Image to be converted to YUV with shape :math:`(*, 3, H, W)`. | ||
|
||
Returns: | ||
YUV version of the image with shape :math:`(*, 3, H, W)`. | ||
|
||
Example: | ||
>>> input = torch.rand(2, 3, 4, 5) | ||
>>> output = rgb_to_yuv(input) # 2x3x4x5 | ||
""" | ||
if not isinstance(image, Tensor): | ||
raise TypeError(f"Input type is not a Tensor. Got {type(image)}") | ||
|
||
if len(image.shape) < 3 or image.shape[-3] != 3: | ||
raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") | ||
|
||
r: Tensor = image[..., 0, :, :] | ||
g: Tensor = image[..., 1, :, :] | ||
b: Tensor = image[..., 2, :, :] | ||
|
||
y: Tensor = 0.299 * r + 0.587 * g + 0.114 * b | ||
u: Tensor = -0.147 * r - 0.289 * g + 0.436 * b | ||
v: Tensor = 0.615 * r - 0.515 * g - 0.100 * b | ||
|
||
out: Tensor = torch.stack([y, u, v], -3) | ||
|
||
return out | ||
|
||
|
||
def yuv_to_rgb(image: Tensor) -> Tensor: | ||
r"""Convert an YUV image to RGB. | ||
|
||
The image data is assumed to be in the range of :math:`(0, 1)` for luma (Y). The ranges of U and V are | ||
:math:`(-0.436, 0.436)` and :math:`(-0.615, 0.615)`, respectively. | ||
|
||
YUV formula follows M/PAL values (see | ||
`BT.470-5 <https://www.itu.int/dms_pubrec/itu-r/rec/bt/R-REC-BT.470-5-199802-S!!PDF-E.pdf>`_, Table 2, | ||
items 2.5 and 2.6). | ||
|
||
Args: | ||
image: YUV Image to be converted to RGB with shape :math:`(*, 3, H, W)`. | ||
|
||
Returns: | ||
RGB version of the image with shape :math:`(*, 3, H, W)`. | ||
|
||
Example: | ||
>>> input = torch.rand(2, 3, 4, 5) | ||
>>> output = yuv_to_rgb(input) # 2x3x4x5 | ||
""" | ||
if not isinstance(image, Tensor): | ||
raise TypeError(f"Input type is not a Tensor. Got {type(image)}") | ||
|
||
if image.dim() < 3 or image.shape[-3] != 3: | ||
raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") | ||
|
||
y: Tensor = image[..., 0, :, :] | ||
u: Tensor = image[..., 1, :, :] | ||
v: Tensor = image[..., 2, :, :] | ||
|
||
r: Tensor = y + 1.14 * v # coefficient for g is 0 | ||
g: Tensor = y + -0.396 * u - 0.581 * v | ||
b: Tensor = y + 2.029 * u # coefficient for b is 0 | ||
|
||
out: Tensor = torch.stack([r, g, b], -3) | ||
|
||
return out |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import cv2 | ||
|
||
from ccrestoration import AutoConfig, AutoModel, BaseConfig, ConfigType | ||
from ccrestoration.model import SRBaseModel | ||
|
||
from .util import ASSETS_PATH, calculate_image_similarity, compare_image_size, get_device, load_image | ||
|
||
|
||
class Test_SRCNN: | ||
def test_official(self) -> None: | ||
img1 = load_image() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Test should include multiple test images with different characteristics Consider testing with multiple images that have different characteristics (e.g., high frequency details, smooth areas, different color patterns) to ensure the SRCNN model handles various scenarios correctly. test_images = [
load_image("natural.png"),
load_image("text.png"),
load_image("pattern.png"),
load_image("gradient.png")
] |
||
|
||
for k in [ | ||
ConfigType.SRCNN_2x, | ||
ConfigType.SRCNN_3x, | ||
ConfigType.SRCNN_4x, | ||
]: | ||
print(f"Testing {k}") | ||
cfg: BaseConfig = AutoConfig.from_pretrained(k) | ||
Comment on lines
+10
to
+19
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Test case needs more assertions to verify SRCNN functionality The current test only checks image similarity and size. Consider adding assertions to verify:
def test_official(self) -> None:
img1 = load_image()
small_img = cv2.resize(img1, (32, 32))
large_img = cv2.resize(img1, (1024, 1024))
for k in [ConfigType.SRCNN_2x, ConfigType.SRCNN_3x, ConfigType.SRCNN_4x]:
cfg: BaseConfig = AutoConfig.from_pretrained(k)
model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
for test_img in [img1, small_img, large_img]:
img2 = model.inference_image(test_img)
yuv_img = cv2.cvtColor(img2, cv2.COLOR_BGR2YUV)
assert calculate_image_similarity(test_img, img2)
assert compare_image_size(test_img, img2, cfg.scale)
assert yuv_img.shape[2] == 3
assert torch.cuda.max_memory_allocated() < 1024 * 1024 * 1024 # 1GB limit |
||
model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Missing test for fp16 mode Add test cases to verify the model works correctly with fp16=True, as this is an important configuration option that could affect model behavior and performance. model_fp16: SRBaseModel = AutoModel.from_config(config=cfg, fp16=True, device=get_device())
model_fp32: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
print(model_fp16.device)
print(model_fp32.device) |
||
print(model.device) | ||
|
||
img2 = model.inference_image(img1) | ||
cv2.imwrite(str(ASSETS_PATH / f"test_{k}_out.jpg"), img2) | ||
|
||
assert calculate_image_similarity(img1, img2) | ||
assert compare_image_size(img1, img2, cfg.scale) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: Add input validation for the scale parameter
Consider adding validation to ensure the scale factor is a positive integer within an expected range to prevent potential runtime errors.