feat: support HAT #29
Conversation
Merge-commit messages on this branch record the resolved conflicts:

```
# Conflicts:
#	ccrestoration/arch/dat_arch.py
```

```
# Conflicts:
#	ccrestoration/arch/__init__.py
#	ccrestoration/config/__init__.py
#	ccrestoration/config/dat_config.py
#	ccrestoration/model/__init__.py
#	ccrestoration/model/dat_model.py
#	ccrestoration/type/arch.py
#	ccrestoration/type/config.py
#	ccrestoration/type/model.py
#	tests/test_dat.py
```
Reviewer's Guide by Sourcery

This PR adds support for the HAT (Hybrid Attention Transformer) model for image super-resolution. The implementation includes the core HAT architecture, configuration handling, and model loading functionality, along with various pre-trained model configurations.

ER diagram for HAT configurations

```mermaid
erDiagram
    HATConfig {
        string name
        string url
        string hash
        int scale
        int in_chans
        int window_size
        int compress_ratio
        int squeeze_factor
        float conv_scale
        float overlap_ratio
        float img_range
        list depth
        int embed_dim
        list num_heads
        float mlp_ratio
        string upsampler
        string resi_connection
    }
    CONFIG_REGISTRY ||--o{ HATConfig : registers
    HATConfig ||--o{ HATConfigs : contains
```
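To make the diagram concrete, here is a minimal sketch of such a configuration as a plain dataclass. The field names mirror the ER diagram above; the default values and the `HATConfigSketch` name are illustrative only, and the actual package uses its own config base class rather than this standalone type.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class HATConfigSketch:
    # Field set mirrors the ER diagram; defaults here are illustrative.
    name: str
    url: str
    hash: str
    scale: int = 4
    in_chans: int = 3
    window_size: int = 16
    compress_ratio: int = 3
    squeeze_factor: int = 30
    conv_scale: float = 0.01
    overlap_ratio: float = 0.5
    img_range: float = 1.0
    depth: List[int] = field(default_factory=lambda: [6, 6, 6, 6, 6, 6])
    embed_dim: int = 180
    num_heads: List[int] = field(default_factory=lambda: [6, 6, 6, 6, 6, 6])
    mlp_ratio: float = 2.0
    upsampler: str = "pixelshuffle"
    resi_connection: str = "1conv"


# Example: a 2x-scale config overriding only the scale factor.
cfg = HATConfigSketch(name="HAT_S_2x", url="", hash="", scale=2)
```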
Class diagram for the HAT model

```mermaid
classDiagram
    class HAT {
        +int img_size
        +int patch_size
        +int in_chans
        +int embed_dim
        +tuple depth
        +tuple num_heads
        +int window_size
        +float mlp_ratio
        +bool qkv_bias
        +float qk_scale
        +float drop_rate
        +float attn_drop_rate
        +float drop_path_rate
        +nn.Module norm_layer
        +bool ape
        +bool patch_norm
        +int upscale
        +float img_range
        +str upsampler
        +str resi_connection
        +forward(x)
    }
    class HATConfig {
        +ArchType arch
        +ModelType model
        +int scale
        +int patch_size
        +int in_chans
        +Union img_size
        +float img_range
        +Union depth
        +int embed_dim
        +Union num_heads
        +int window_size
        +int compress_ratio
        +int squeeze_factor
        +float conv_scale
        +float overlap_ratio
        +float mlp_ratio
        +str resi_connection
        +bool qkv_bias
        +Optional qk_scale
        +float drop_rate
        +float attn_drop_rate
        +float drop_path_rate
        +bool ape
        +bool patch_norm
        +Any act_layer
        +Any norm_layer
        +str upsampler
    }
    HATModel --|> SRBaseModel
    HATModel : +load_model()
```
File-Level Changes
Hey @routineLife1 - I've reviewed your changes - here's some feedback:
Overall Comments:
- Please expand the PR description to explain what HAT is, its key benefits, and typical use cases compared to other models in the codebase.
Here's what I looked at during the review
- 🟢 General issues: all looks good
- 🟢 Security: all looks good
- 🟡 Testing: 1 issue found
- 🟡 Complexity: 1 issue found
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
```python
def test_official_light(self) -> None:
    img1 = load_image()

    for k in [
        ConfigType.HAT_S_2x,
        ConfigType.HAT_S_3x,
        ConfigType.HAT_S_4x,
    ]:
        print(f"Testing {k}")
        cfg: BaseConfig = AutoConfig.from_pretrained(k)
```
suggestion (testing): The test only checks a subset of HAT models and basic functionality
Consider adding more specific assertions to verify the quality of the upscaled images, such as checking PSNR/SSIM metrics against expected values for known test images. Also consider testing edge cases like images with different aspect ratios or extreme dimensions.
Suggested implementation:

```python
from .util import ASSETS_PATH, calculate_image_similarity, compare_image_size, get_device, load_image, calculate_psnr, calculate_ssim

        assert calculate_image_similarity(img1, img2), "Image similarity check failed"
        assert calculate_psnr(img1, img2) > 30, "PSNR is below acceptable threshold"
        assert calculate_ssim(img1, img2) > 0.9, "SSIM is below acceptable threshold"

    def test_edge_cases(self) -> None:
        # Test with an image with a different aspect ratio
        img1 = load_image(aspect_ratio="16:9")
        for k in [
            ConfigType.HAT_S_2x,
            ConfigType.HAT_S_3x,
            ConfigType.HAT_S_4x,
        ]:
            cfg: BaseConfig = AutoConfig.from_pretrained(k)
            model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
            img2 = model.inference_image(img1)
            assert calculate_image_similarity(img1, img2), "Image similarity check failed for aspect ratio"
            assert calculate_psnr(img1, img2) > 30, "PSNR is below acceptable threshold for aspect ratio"
            assert calculate_ssim(img1, img2) > 0.9, "SSIM is below acceptable threshold for aspect ratio"

        # Test with an image with extreme dimensions
        img1 = load_image(dimensions=(4096, 4096))
        for k in [
            ConfigType.HAT_S_2x,
            ConfigType.HAT_S_3x,
            ConfigType.HAT_S_4x,
        ]:
            cfg: BaseConfig = AutoConfig.from_pretrained(k)
            model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
            img2 = model.inference_image(img1)
            assert calculate_image_similarity(img1, img2), "Image similarity check failed for extreme dimensions"
            assert calculate_psnr(img1, img2) > 30, "PSNR is below acceptable threshold for extreme dimensions"
            assert calculate_ssim(img1, img2) > 0.9, "SSIM is below acceptable threshold for extreme dimensions"

    @pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skip on CI test")
```
- Implement `calculate_psnr` and `calculate_ssim` functions in the `util` module if they do not already exist.
- Ensure `load_image` can handle loading images with specified aspect ratios and dimensions.
- Adjust the PSNR and SSIM threshold values as needed based on the quality requirements.
```python
        return x


class HAB(nn.Module):
```
issue (complexity): Consider splitting the HAB class into separate AttentionBlock and ConvolutionBlock components to clarify responsibilities and simplify maintenance.
The HAB (Hybrid Attention Block) class combines too many responsibilities, making it harder to maintain and understand. Consider splitting it into focused components:
```python
class AttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size, shift_size=0):
        super().__init__()
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm = nn.LayerNorm(dim)
        self.shift_size = shift_size

    def forward(self, x, x_size, rpi_sa, attn_mask):
        # Handle window attention with optional shift
        if self.shift_size > 0:
            x = self._apply_shift(x, x_size)
        return self.attn(self.norm(x), rpi_sa, attn_mask)


class ConvolutionBlock(nn.Module):
    def __init__(self, dim, compress_ratio, squeeze_factor):
        super().__init__()
        self.conv = CAB(dim, compress_ratio, squeeze_factor)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.conv(self.norm(x))


class HAB(nn.Module):
    def __init__(self, dim, num_heads, window_size, compress_ratio, squeeze_factor):
        super().__init__()
        self.attn_block = AttentionBlock(dim, num_heads, window_size)
        self.conv_block = ConvolutionBlock(dim, compress_ratio, squeeze_factor)
        self.mlp = Mlp(dim)

    def forward(self, x, x_size, rpi_sa, attn_mask):
        x = x + self.attn_block(x, x_size, rpi_sa, attn_mask)
        x = x + self.conv_block(x)
        return x + self.mlp(x)
```
This refactoring:
- Separates attention and convolution into distinct blocks
- Makes each component's responsibility clear
- Simplifies testing and maintenance
- Maintains all existing functionality
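As background for the attention side of this refactor: HAT's window attention (as in Swin Transformer) operates on non-overlapping windows of the feature map, so every block depends on a partition/reverse round trip. A minimal NumPy sketch of that round trip is below; the function names are illustrative, not the PR's actual helpers.

```python
import numpy as np


def window_partition(x, window_size):
    # (H, W, C) feature map -> (num_windows, window_size, window_size, C)
    h, w, c = x.shape
    x = x.reshape(h // window_size, window_size, w // window_size, window_size, c)
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, window_size, window_size, c)


def window_reverse(windows, window_size, h, w):
    # Inverse of window_partition: reassemble windows into the feature map.
    c = windows.shape[-1]
    x = windows.reshape(h // window_size, w // window_size, window_size, window_size, c)
    return x.transpose(0, 2, 1, 3, 4).reshape(h, w, c)


feat = np.arange(8 * 8 * 2, dtype=np.float32).reshape(8, 8, 2)
wins = window_partition(feat, 4)          # four 4x4x2 windows
restored = window_reverse(wins, 4, 8, 8)  # lossless round trip
```

Because the partition is lossless, an `AttentionBlock` split out as suggested above only needs to own the shift logic and the per-window attention call; reassembly stays symmetric.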
add support for HAT
Summary by Sourcery
Add support for the Hybrid Attention Transformer (HAT) model, including new configurations, model types, and associated tests.
New Features:
Tests: