-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support SRCNN #25
Conversation
Reviewer's Guide by SourceryThis PR adds support for the SRCNN (Super-Resolution Convolutional Neural Network) model by implementing the architecture, configuration, and model classes. The implementation includes YUV color space conversion utilities and supports multiple scaling factors (2x, 3x, 4x). Sequence diagram for SRCNN model loadingsequenceDiagram
participant User
participant AutoConfig
participant AutoModel
participant SRCNNModel
participant SRCNN
User->>AutoConfig: from_pretrained(ConfigType)
AutoConfig-->>User: Return BaseConfig
User->>AutoModel: from_config(config)
AutoModel-->>SRCNNModel: Initialize
SRCNNModel->>SRCNNModel: load_model()
SRCNNModel->>SRBaseModel: get_state_dict()
SRBaseModel-->>SRCNNModel: Return state_dict
SRCNNModel->>SRCNN: Initialize with state_dict
SRCNN-->>SRCNNModel: Return model
SRCNNModel-->>User: Return model
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @Tohrusky - I've reviewed your changes - here's some feedback:
Overall Comments:
- Consider adding documentation about SRCNN architecture and its recommended use cases to help users choose between different models
Here's what I looked at during the review
- 🟡 General issues: 1 issue found
- 🟢 Security: all looks good
- 🟡 Testing: 3 issues found
- 🟢 Complexity: all looks good
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
|
||
@ARCH_REGISTRY.register(name=ArchType.SRCNN) | ||
class SRCNN(nn.Module): | ||
def __init__(self, num_channels: int = 1, scale: int = 2) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: Add input validation for the scale parameter
Consider adding validation to ensure the scale factor is a positive integer within an expected range to prevent potential runtime errors.
def __init__(self, num_channels: int = 1, scale: int = 2) -> None:
if not isinstance(scale, int) or scale < 1:
raise ValueError("Scale factor must be a positive integer")
if scale > 8:
raise ValueError("Scale factor must be 8 or less")
super(SRCNN, self).__init__()
def test_official(self) -> None: | ||
img1 = load_image() | ||
|
||
for k in [ | ||
ConfigType.SRCNN_2x, | ||
ConfigType.SRCNN_3x, | ||
ConfigType.SRCNN_4x, | ||
]: | ||
print(f"Testing {k}") | ||
cfg: BaseConfig = AutoConfig.from_pretrained(k) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion (testing): Test case needs more assertions to verify SRCNN functionality
The current test only checks image similarity and size. Consider adding assertions to verify:
- The YUV color space conversion is working correctly
- The model's behavior with edge cases (e.g., very small or large images)
- The expected output values for a known input image
- Memory usage is within expected bounds for large images
def test_official(self) -> None:
img1 = load_image()
small_img = cv2.resize(img1, (32, 32))
large_img = cv2.resize(img1, (1024, 1024))
for k in [ConfigType.SRCNN_2x, ConfigType.SRCNN_3x, ConfigType.SRCNN_4x]:
cfg: BaseConfig = AutoConfig.from_pretrained(k)
model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
for test_img in [img1, small_img, large_img]:
img2 = model.inference_image(test_img)
yuv_img = cv2.cvtColor(img2, cv2.COLOR_BGR2YUV)
assert calculate_image_similarity(test_img, img2)
assert compare_image_size(test_img, img2, cfg.scale)
assert yuv_img.shape[2] == 3
assert torch.cuda.max_memory_allocated() < 1024 * 1024 * 1024 # 1GB limit
|
||
class Test_SRCNN: | ||
def test_official(self) -> None: | ||
img1 = load_image() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion (testing): Test should include multiple test images with different characteristics
Consider testing with multiple images that have different characteristics (e.g., high frequency details, smooth areas, different color patterns) to ensure the SRCNN model handles various scenarios correctly.
test_images = [
load_image("natural.png"),
load_image("text.png"),
load_image("pattern.png"),
load_image("gradient.png")
]
]: | ||
print(f"Testing {k}") | ||
cfg: BaseConfig = AutoConfig.from_pretrained(k) | ||
model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion (testing): Missing test for fp16 mode
Add test cases to verify the model works correctly with fp16=True, as this is an important configuration option that could affect model behavior and performance.
model_fp16: SRBaseModel = AutoModel.from_config(config=cfg, fp16=True, device=get_device())
model_fp32: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
print(model_fp16.device)
print(model_fp32.device)
Summary by Sourcery
Implement support for the SRCNN model in the ccrestoration library, including new configurations, architecture, and tests. Enhance model loading by renaming variables for clarity and add utility functions for color space conversion.
New Features:
Enhancements:
Documentation:
Tests: