
feat: support HAT #29

Merged · 9 commits merged into TensoRaws:main on Dec 12, 2024

Conversation

routineLife1 (Member) commented on Dec 12, 2024

Add support for HAT (Hybrid Attention Transformer).

Summary by Sourcery

Add support for the Hybrid Attention Transformer (HAT) model, including new configurations, model types, and associated tests.

New Features:

  • Introduce support for the Hybrid Attention Transformer (HAT) model in the ccrestoration project, including various configurations and model types.

Tests:

  • Add tests for the HAT model to ensure its functionality and performance, including tests for different configurations and image scales.
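
For reviewers trying the branch locally, a minimal usage sketch assembled from the AutoConfig / AutoModel calls that appear in the tests below (the top-level import path, device string, and cv2-style image I/O are assumptions, not confirmed by this PR):

import cv2
from ccrestoration import AutoConfig, AutoModel, ConfigType

# build a HAT 2x model from one of the newly registered configs
cfg = AutoConfig.from_pretrained(ConfigType.HAT_S_2x)
model = AutoModel.from_config(config=cfg, fp16=False, device="cpu")

# upscale a single image and write the result
img = cv2.imread("input.png")
out = model.inference_image(img)
cv2.imwrite("output.png", out)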

# Conflicts:
#	ccrestoration/arch/dat_arch.py
# Conflicts:
#	ccrestoration/arch/__init__.py
#	ccrestoration/config/__init__.py
#	ccrestoration/config/dat_config.py
#	ccrestoration/model/__init__.py
#	ccrestoration/model/dat_model.py
#	ccrestoration/type/arch.py
#	ccrestoration/type/config.py
#	ccrestoration/type/model.py
#	tests/test_dat.py

sourcery-ai bot commented Dec 12, 2024

Reviewer's Guide by Sourcery

This PR adds support for the HAT (Hybrid Attention Transformer) model for image super-resolution. The implementation includes the core HAT architecture, configuration handling, and model loading functionality along with various pre-trained model configurations.

ER diagram for HAT configurations

erDiagram
    HATConfig {
        string name
        string url
        string hash
        int scale
        int in_chans
        int window_size
        int compress_ratio
        int squeeze_factor
        float conv_scale
        float overlap_ratio
        float img_range
        list depth
        int embed_dim
        list num_heads
        float mlp_ratio
        string upsampler
        string resi_connection
    }

    CONFIG_REGISTRY ||--o{ HATConfig : registers
    HATConfig ||--o{ HATConfigs : contains

Class diagram for the HAT model

classDiagram
    class HAT {
        +int img_size
        +int patch_size
        +int in_chans
        +int embed_dim
        +tuple depth
        +tuple num_heads
        +int window_size
        +float mlp_ratio
        +bool qkv_bias
        +float qk_scale
        +float drop_rate
        +float attn_drop_rate
        +float drop_path_rate
        +nn.Module norm_layer
        +bool ape
        +bool patch_norm
        +int upscale
        +float img_range
        +str upsampler
        +str resi_connection
        +forward(x)
    }

    class HATConfig {
        +ArchType arch
        +ModelType model
        +int scale
        +int patch_size
        +int in_chans
        +Union img_size
        +float img_range
        +Union depth
        +int embed_dim
        +Union num_heads
        +int window_size
        +int compress_ratio
        +int squeeze_factor
        +float conv_scale
        +float overlap_ratio
        +float mlp_ratio
        +str resi_connection
        +bool qkv_bias
        +Optional qk_scale
        +float drop_rate
        +float attn_drop_rate
        +float drop_path_rate
        +bool ape
        +bool patch_norm
        +Any act_layer
        +Any norm_layer
        +str upsampler
    }

    HATModel --|> SRBaseModel
    HATModel : +load_model()

File-Level Changes

Change Details Files
Added HAT model architecture implementation
  • Implemented HAT (Hybrid Attention Transformer) model architecture with support for different scales (2x, 3x, 4x)
  • Added window-based self-attention mechanism with relative position bias
  • Implemented overlapping cross-attention block (OCAB)
  • Added channel attention mechanism (see the sketch after this entry)
  • Implemented residual hybrid attention groups (RHAG)
ccrestoration/arch/hat_arch.py
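
As a rough illustration of the channel attention bullet above, here is a minimal CAB-style block in the shape the HAT paper describes, parameterized by the compress_ratio and squeeze_factor fields from HATConfig; a sketch for orientation, not the PR's exact hat_arch.py code:

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel attention."""

    def __init__(self, num_feat: int, squeeze_factor: int = 16) -> None:
        super().__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # squeeze: global spatial average per channel
            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.attention(x)  # excite: rescale each channel


class CAB(nn.Module):
    """Convolution block with channel attention, used inside each hybrid attention block."""

    def __init__(self, num_feat: int, compress_ratio: int = 3, squeeze_factor: int = 30) -> None:
        super().__init__()
        self.cab = nn.Sequential(
            nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
            ChannelAttention(num_feat, squeeze_factor),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.cab(x)
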
Added HAT model configuration handling
  • Created HATConfig class with model parameters and validation
  • Added configurations for different HAT variants (HAT-S, HAT-M, HAT-L)
  • Added pre-trained model configurations with URLs and hash values (registration sketched below)
  • Implemented configuration validation for scale factors and upsampler options
ccrestoration/config/hat_config.py
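
For illustration, a hedged sketch of how one pre-trained variant might be registered, mirroring the HATConfig fields in the diagrams above; the register call, config values, URL, and hash are placeholders, not the PR's real entries:

# hypothetical registration of a HAT 2x variant in hat_config.py;
# all values below are illustrative placeholders
CONFIG_REGISTRY.register(
    HATConfig(
        name=ConfigType.HAT_S_2x,
        url="https://example.com/HAT-S_SRx2.pth",  # placeholder URL
        hash="<sha256 of the checkpoint>",         # placeholder hash
        scale=2,
        in_chans=3,
        window_size=16,
        compress_ratio=24,
        squeeze_factor=24,
        conv_scale=0.01,
        overlap_ratio=0.5,
        img_range=1.0,
        depth=[6, 6, 6, 6, 6, 6],
        embed_dim=144,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2.0,
        upsampler="pixelshuffle",
        resi_connection="1conv",
    )
)
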
Added HAT model type registration and loading
  • Added HAT model type to enums
  • Implemented HAT model loading functionality
  • Added state dictionary handling for different model formats (see the sketch below)
ccrestoration/type/config.py
ccrestoration/type/arch.py
ccrestoration/type/model.py
ccrestoration/model/hat_model.py
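
A hedged sketch of the state-dict handling such a loader typically needs: official HAT releases (BasicSR-style) often wrap weights under "params_ema" or "params" keys. That convention, the checkpoint path attribute, and the depth -> depths argument mapping are assumptions here, not quotes of the PR's hat_model.py:

import torch

def load_model(self) -> torch.nn.Module:
    state_dict = torch.load(self.checkpoint_path, map_location="cpu")  # hypothetical attribute

    # unwrap BasicSR-style checkpoints if needed
    if "params_ema" in state_dict:
        state_dict = state_dict["params_ema"]
    elif "params" in state_dict:
        state_dict = state_dict["params"]

    model = HAT(
        upscale=self.config.scale,
        in_chans=self.config.in_chans,
        window_size=self.config.window_size,
        embed_dim=self.config.embed_dim,
        depths=self.config.depth,
        num_heads=self.config.num_heads,
        mlp_ratio=self.config.mlp_ratio,
        upsampler=self.config.upsampler,
        resi_connection=self.config.resi_connection,
    )
    model.load_state_dict(state_dict)
    return model.eval()
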
Added HAT model tests
  • Added tests for official HAT model variants
  • Implemented tests for custom HAT configurations
  • Added image similarity and size comparison tests
tests/test_hat.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time. You can also use
    this command to specify where the summary should be inserted.

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

sourcery-ai bot left a comment

Hey @routineLife1 - I've reviewed your changes - here's some feedback:

Overall Comments:

  • Please expand the PR description to explain what HAT is, its key benefits, and typical use cases compared to other models in the codebase.

Here's what I looked at during the review:
  • 🟢 General issues: all looks good
  • 🟢 Security: all looks good
  • 🟡 Testing: 1 issue found
  • 🟡 Complexity: 1 issue found
  • 🟢 Documentation: all looks good

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment on lines +13 to +22
    def test_official_light(self) -> None:
        img1 = load_image()

        for k in [
            ConfigType.HAT_S_2x,
            ConfigType.HAT_S_3x,
            ConfigType.HAT_S_4x,
        ]:
            print(f"Testing {k}")
            cfg: BaseConfig = AutoConfig.from_pretrained(k)

suggestion (testing): The test only checks a subset of HAT models and basic functionality

Consider adding more specific assertions to verify the quality of the upscaled images, such as checking PSNR/SSIM metrics against expected values for known test images. Also consider testing edge cases like images with different aspect ratios or extreme dimensions.

Suggested implementation:

from .util import ASSETS_PATH, calculate_image_similarity, compare_image_size, get_device, load_image, calculate_psnr, calculate_ssim
            assert calculate_image_similarity(img1, img2), "Image similarity check failed"
            assert calculate_psnr(img1, img2) > 30, "PSNR is below acceptable threshold"
            assert calculate_ssim(img1, img2) > 0.9, "SSIM is below acceptable threshold"
    def test_edge_cases(self) -> None:
        # Test with an image with a different aspect ratio
        img1 = load_image(aspect_ratio="16:9")
        for k in [
            ConfigType.HAT_S_2x,
            ConfigType.HAT_S_3x,
            ConfigType.HAT_S_4x,
        ]:
            cfg: BaseConfig = AutoConfig.from_pretrained(k)
            model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
            img2 = model.inference_image(img1)
            assert calculate_image_similarity(img1, img2), "Image similarity check failed for aspect ratio"
            assert calculate_psnr(img1, img2) > 30, "PSNR is below acceptable threshold for aspect ratio"
            assert calculate_ssim(img1, img2) > 0.9, "SSIM is below acceptable threshold for aspect ratio"

        # Test with an image with extreme dimensions
        img1 = load_image(dimensions=(4096, 4096))
        for k in [
            ConfigType.HAT_S_2x,
            ConfigType.HAT_S_3x,
            ConfigType.HAT_S_4x,
        ]:
            cfg: BaseConfig = AutoConfig.from_pretrained(k)
            model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
            img2 = model.inference_image(img1)
            assert calculate_image_similarity(img1, img2), "Image similarity check failed for extreme dimensions"
            assert calculate_psnr(img1, img2) > 30, "PSNR is below acceptable threshold for extreme dimensions"
            assert calculate_ssim(img1, img2) > 0.9, "SSIM is below acceptable threshold for extreme dimensions"

    @pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skip on CI test")
  1. Implement calculate_psnr and calculate_ssim functions in the util module if they do not already exist (a sketch follows below).
  2. Ensure load_image can handle loading images with specified aspect ratios and dimensions.
  3. Adjust the PSNR and SSIM threshold values as needed based on the quality requirements.
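
If those helpers are missing, a minimal sketch of what they might look like, assuming uint8 numpy arrays of identical shape (resize first when the scales differ) and scikit-image for SSIM:

import numpy as np
from skimage.metrics import structural_similarity

def calculate_psnr(img1: np.ndarray, img2: np.ndarray) -> float:
    # peak signal-to-noise ratio for 8-bit images of the same shape
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10 * np.log10(255.0 ** 2 / mse)

def calculate_ssim(img1: np.ndarray, img2: np.ndarray) -> float:
    # delegate to scikit-image; channel_axis=-1 for HWC color images
    return structural_similarity(img1, img2, channel_axis=-1)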

        return x


class HAB(nn.Module):

issue (complexity): Consider splitting the HAB class into separate AttentionBlock and ConvolutionBlock components to clarify responsibilities and simplify maintenance.

The HAB (Hybrid Attention Block) class combines too many responsibilities, making it harder to maintain and understand. Consider splitting it into focused components:

class AttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size, shift_size=0):
        super().__init__()
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm = nn.LayerNorm(dim)
        self.shift_size = shift_size

    def forward(self, x, x_size, rpi_sa, attn_mask):
        # Handle window attention with optional shift
        if self.shift_size > 0:
            x = self._apply_shift(x, x_size)
        return self.attn(self.norm(x), rpi_sa, attn_mask)

class ConvolutionBlock(nn.Module):
    def __init__(self, dim, compress_ratio, squeeze_factor):
        super().__init__()
        self.conv = CAB(dim, compress_ratio, squeeze_factor)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.conv(self.norm(x))

class HAB(nn.Module):
    def __init__(self, dim, num_heads, window_size, compress_ratio, squeeze_factor):
        super().__init__()
        self.attn_block = AttentionBlock(dim, num_heads, window_size)
        self.conv_block = ConvolutionBlock(dim, compress_ratio, squeeze_factor)
        self.mlp = Mlp(dim)

    def forward(self, x, x_size, rpi_sa, attn_mask):
        x = x + self.attn_block(x, x_size, rpi_sa, attn_mask)
        x = x + self.conv_block(x)
        return x + self.mlp(x)

This refactoring:

  1. Separates attention and convolution into distinct blocks
  2. Makes each component's responsibility clear
  3. Simplifies testing and maintenance
  4. Maintains all existing functionality

Tohrusky merged commit 731bade into TensoRaws:main on Dec 12, 2024
12 checks passed