Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support SRCNN #25

Merged
merged 2 commits into from
Nov 24, 2024
Merged

feat: support SRCNN #25

merged 2 commits into from
Nov 24, 2024

Conversation

Tohrusky
Copy link
Member

@Tohrusky Tohrusky commented Nov 24, 2024

Summary by Sourcery

Implement support for the SRCNN model in the ccrestoration library, including new configurations, architecture, and tests. Enhance model loading by renaming variables for clarity and add utility functions for color space conversion.

New Features:

  • Add support for SRCNN architecture in the ccrestoration library, including configurations for 2x, 3x, and 4x scaling.

Enhancements:

  • Refactor model loading to use 'state_dict_path' instead of 'model_path' for clarity.

Documentation:

  • Add documentation for new utility functions 'rgb_to_yuv' and 'yuv_to_rgb' for color space conversion.

Tests:

  • Introduce tests for the SRCNN model to ensure correct functionality and image similarity after processing.

Copy link

sourcery-ai bot commented Nov 24, 2024

Reviewer's Guide by Sourcery

This PR adds support for the SRCNN (Super-Resolution Convolutional Neural Network) model by implementing the architecture, configuration, and model classes. The implementation includes YUV color space conversion utilities and supports multiple scaling factors (2x, 3x, 4x).

Sequence diagram for SRCNN model loading

sequenceDiagram
    participant User
    participant AutoConfig
    participant AutoModel
    participant SRCNNModel
    participant SRCNN
    User->>AutoConfig: from_pretrained(ConfigType)
    AutoConfig-->>User: Return BaseConfig
    User->>AutoModel: from_config(config)
    AutoModel-->>SRCNNModel: Initialize
    SRCNNModel->>SRCNNModel: load_model()
    SRCNNModel->>SRBaseModel: get_state_dict()
    SRBaseModel-->>SRCNNModel: Return state_dict
    SRCNNModel->>SRCNN: Initialize with state_dict
    SRCNN-->>SRCNNModel: Return model
    SRCNNModel-->>User: Return model
Loading

File-Level Changes

Change Details Files
Added SRCNN architecture implementation with YUV color space support
  • Implemented SRCNN model architecture with three convolutional layers
  • Added RGB to YUV and YUV to RGB conversion functions
  • Integrated YUV color space processing for single-channel operation
ccrestoration/arch/srcnn_arch.py
ccrestoration/util/color.py
Added SRCNN configuration and model classes
  • Created SRCNN configuration class with support for different scaling factors
  • Implemented model loading and initialization logic
  • Added pre-trained model configurations for 2x, 3x, and 4x scaling
ccrestoration/config/srcnn_config.py
ccrestoration/model/srcnn_model.py
Updated type system and registry to support SRCNN
  • Added SRCNN to architecture types
  • Added SRCNN to model types
  • Added SRCNN model configuration types
ccrestoration/type/arch.py
ccrestoration/type/config.py
ccrestoration/type/model.py
Added test suite for SRCNN implementation
  • Created test cases for all supported scaling factors
  • Added image similarity and size comparison tests
tests/test_srcnn.py
Improved model state dictionary loading
  • Renamed variable for better clarity
  • Maintained backward compatibility with existing code
ccrestoration/model/sr_base_model.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time. You can also use
    this command to specify where the summary should be inserted.

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

Copy link

@sourcery-ai sourcery-ai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @Tohrusky - I've reviewed your changes - here's some feedback:

Overall Comments:

  • Consider adding documentation about SRCNN architecture and its recommended use cases to help users choose between different models
Here's what I looked at during the review
  • 🟡 General issues: 1 issue found
  • 🟢 Security: all looks good
  • 🟡 Testing: 3 issues found
  • 🟢 Complexity: all looks good
  • 🟢 Documentation: all looks good

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.


@ARCH_REGISTRY.register(name=ArchType.SRCNN)
class SRCNN(nn.Module):
def __init__(self, num_channels: int = 1, scale: int = 2) -> None:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Add input validation for the scale parameter

Consider adding validation to ensure the scale factor is a positive integer within an expected range to prevent potential runtime errors.

    def __init__(self, num_channels: int = 1, scale: int = 2) -> None:
        if not isinstance(scale, int) or scale < 1:
            raise ValueError("Scale factor must be a positive integer")
        if scale > 8:
            raise ValueError("Scale factor must be 8 or less")
        super(SRCNN, self).__init__()

Comment on lines +10 to +19
def test_official(self) -> None:
img1 = load_image()

for k in [
ConfigType.SRCNN_2x,
ConfigType.SRCNN_3x,
ConfigType.SRCNN_4x,
]:
print(f"Testing {k}")
cfg: BaseConfig = AutoConfig.from_pretrained(k)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Test case needs more assertions to verify SRCNN functionality

The current test only checks image similarity and size. Consider adding assertions to verify:

  1. The YUV color space conversion is working correctly
  2. The model's behavior with edge cases (e.g., very small or large images)
  3. The expected output values for a known input image
  4. Memory usage is within expected bounds for large images
def test_official(self) -> None:
    img1 = load_image()
    small_img = cv2.resize(img1, (32, 32))
    large_img = cv2.resize(img1, (1024, 1024))

    for k in [ConfigType.SRCNN_2x, ConfigType.SRCNN_3x, ConfigType.SRCNN_4x]:
        cfg: BaseConfig = AutoConfig.from_pretrained(k)
        model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())

        for test_img in [img1, small_img, large_img]:
            img2 = model.inference_image(test_img)
            yuv_img = cv2.cvtColor(img2, cv2.COLOR_BGR2YUV)

            assert calculate_image_similarity(test_img, img2)
            assert compare_image_size(test_img, img2, cfg.scale)
            assert yuv_img.shape[2] == 3
            assert torch.cuda.max_memory_allocated() < 1024 * 1024 * 1024  # 1GB limit


class Test_SRCNN:
def test_official(self) -> None:
img1 = load_image()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Test should include multiple test images with different characteristics

Consider testing with multiple images that have different characteristics (e.g., high frequency details, smooth areas, different color patterns) to ensure the SRCNN model handles various scenarios correctly.

        test_images = [
            load_image("natural.png"),
            load_image("text.png"),
            load_image("pattern.png"),
            load_image("gradient.png")
        ]

]:
print(f"Testing {k}")
cfg: BaseConfig = AutoConfig.from_pretrained(k)
model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Missing test for fp16 mode

Add test cases to verify the model works correctly with fp16=True, as this is an important configuration option that could affect model behavior and performance.

            model_fp16: SRBaseModel = AutoModel.from_config(config=cfg, fp16=True, device=get_device())
            model_fp32: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
            print(model_fp16.device)
            print(model_fp32.device)

@Tohrusky Tohrusky merged commit c7e8321 into TensoRaws:main Nov 24, 2024
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant