-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: SRCNN impl and tests #26
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import cv2 | ||
import pytest | ||
import torch | ||
from torchvision import transforms | ||
|
||
from ccrestoration.util.color import rgb_to_yuv, yuv_to_rgb | ||
from ccrestoration.util.device import DEFAULT_DEVICE | ||
|
||
from .util import calculate_image_similarity, load_image | ||
|
||
|
||
def test_device() -> None: | ||
print(DEFAULT_DEVICE) | ||
|
||
|
||
def test_color() -> None: | ||
with pytest.raises(TypeError): | ||
rgb_to_yuv(1) | ||
with pytest.raises(TypeError): | ||
yuv_to_rgb(1) | ||
|
||
with pytest.raises(ValueError): | ||
rgb_to_yuv(torch.zeros(1, 1)) | ||
with pytest.raises(ValueError): | ||
yuv_to_rgb(torch.zeros(1, 1)) | ||
Comment on lines
+16
to
+25
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Test function Consider splitting this test into multiple test cases: one for error conditions, one for RGB->YUV conversion, and one for YUV->RGB conversion. Also, consider adding assertions about the specific values or ranges in the YUV color space to ensure the conversion is correct. @pytest.mark.parametrize("invalid_input", [1, "string", [1, 2, 3]])
def test_color_type_errors(invalid_input) -> None:
with pytest.raises(TypeError):
rgb_to_yuv(invalid_input)
with pytest.raises(TypeError):
yuv_to_rgb(invalid_input)
@pytest.mark.parametrize("invalid_shape", [torch.zeros(1, 1), torch.zeros(2, 3, 4)])
def test_color_value_errors(invalid_shape) -> None:
with pytest.raises(ValueError):
rgb_to_yuv(invalid_shape)
with pytest.raises(ValueError):
yuv_to_rgb(invalid_shape)
def test_color_conversion() -> None:
img = load_image()
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = transforms.ToTensor()(img_rgb).unsqueeze(0).to("cpu")
yuv = rgb_to_yuv(img_tensor)
assert yuv.shape == img_tensor.shape
assert (yuv[:, 0] >= -1).all() and (yuv[:, 0] <= 1).all() # Y channel bounds
rgb = yuv_to_rgb(yuv)
rgb_np = (rgb.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype("uint8")
rgb_np = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2BGR)
assert calculate_image_similarity(rgb_np, img) |
||
|
||
img = load_image() | ||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
|
||
img = transforms.ToTensor()(img).unsqueeze(0).to("cpu") | ||
|
||
img = rgb_to_yuv(img) | ||
img = yuv_to_rgb(img) | ||
|
||
img = img.squeeze(0).permute(1, 2, 0).cpu().numpy() | ||
img = (img * 255).clip(0, 255).astype("uint8") | ||
|
||
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) | ||
|
||
assert calculate_image_similarity(img, load_image()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (testing): Test function
test_device
only prints the device without any assertionsThis test doesn't verify any behavior. Consider adding assertions to check if DEFAULT_DEVICE is set correctly or has expected properties.