From 6457c712831b9c23cd541d5e09d18876777ef295 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sat, 28 Dec 2024 01:00:30 +0400 Subject: [PATCH] Remove AugmentationSequential wrapper --- tests/transforms/test_transforms.py | 123 +++++++++++----------------- torchgeo/transforms/__init__.py | 2 - torchgeo/transforms/transforms.py | 97 ---------------------- 3 files changed, 46 insertions(+), 176 deletions(-) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 1f2071ae812..1c8e7e274eb 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -4,10 +4,10 @@ import kornia.augmentation as K import pytest import torch +from kornia.contrib import ExtractTensorPatches from torch import Tensor -from torchgeo.transforms import indices, transforms -from torchgeo.transforms.transforms import _ExtractPatches +from torchgeo.transforms import indices # Kornia is very particular about its boxes: # @@ -23,7 +23,7 @@ def batch_gray() -> dict[str, Tensor]: return { 'image': torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } @@ -42,7 +42,7 @@ def batch_rgb() -> dict[str, Tensor]: dtype=torch.float, ), 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } @@ -63,7 +63,7 @@ def batch_multispectral() -> dict[str, Tensor]: dtype=torch.float, ), 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } @@ -79,12 +79,10 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None: expected = { 'image': torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] - ) + augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None) output = augs(batch_gray) assert_matching(output, expected) @@ -102,12 +100,10 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None: dtype=torch.float, ), 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] - ) + augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None) output = augs(batch_rgb) assert_matching(output, expected) @@ -119,22 +115,20 @@ def test_augmentation_sequential_multispectral( 'image': torch.tensor( [ [ - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], ] ], dtype=torch.float, ), - 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'mask': torch.tensor([[[1, 1, 1], [0, 1, 1], [0, 0, 1]]], dtype=torch.long), + 'bbox_xyxy': torch.tensor([[0.0, 0.0, 1.0, 1.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] - ) + augs = K.AugmentationSequential(K.RandomVerticalFlip(p=1.0), data_keys=None) output = augs(batch_multispectral) assert_matching(output, expected) @@ -142,28 +136,22 @@ def test_augmentation_sequential_multispectral( def test_augmentation_sequential_image_only( batch_multispectral: dict[str, Tensor], ) -> None: - expected = { - 'image': torch.tensor( + expected_image = torch.tensor( + [ [ - [ - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - ] - ], - dtype=torch.float, - ), - 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), - 'labels': torch.tensor([[0, 1]]), - } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image'] + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + ] + ], + dtype=torch.float, ) - output = augs(batch_multispectral) - assert_matching(output, expected) + + augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=['image']) + aug_image = augs(batch_multispectral['image']) + assert torch.allclose(aug_image, expected_image) def test_sequential_transforms_augmentations( @@ -188,17 +176,17 @@ def test_sequential_transforms_augmentations( dtype=torch.float, ), 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - train_transforms = transforms.AugmentationSequential( + train_transforms = K.AugmentationSequential( indices.AppendNBR(index_nir=0, index_swir=0), indices.AppendNDBI(index_swir=0, index_nir=0), indices.AppendNDSI(index_green=0, index_swir=0), indices.AppendNDVI(index_red=0, index_nir=0), indices.AppendNDWI(index_green=0, index_nir=0), K.RandomHorizontalFlip(p=1.0), - data_keys=['image', 'mask', 'boxes'], + data_keys=None, ) output = train_transforms(batch_multispectral) assert_matching(output, expected) @@ -215,12 +203,12 @@ def test_extract_patches() -> None: 'image': torch.randn(size=(b, c, h, w)), 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } - train_transforms = transforms.AugmentationSequential( - _ExtractPatches(window_size=p), same_on_batch=True, data_keys=['image', 'mask'] - ) - output = train_transforms(batch) - assert batch['image'].shape == (b * num_patches, c, p, p) - assert batch['mask'].shape == (b * num_patches, p, p) + train_transforms = ExtractTensorPatches(p, s) + output = {} + output['image'] = train_transforms(batch['image']) + output['mask'] = train_transforms(batch['mask'].unsqueeze(1)).squeeze(2) + assert output['image'].shape == (b, num_patches, c, p, p) + assert output['mask'].shape == (b, num_patches, p, p) # Test different stride s = 16 @@ -229,29 +217,10 @@ def test_extract_patches() -> None: 'image': torch.randn(size=(b, c, h, w)), 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } - train_transforms = transforms.AugmentationSequential( - _ExtractPatches(window_size=p, stride=s), - same_on_batch=True, - data_keys=['image', 'mask'], - ) - output = train_transforms(batch) - assert batch['image'].shape == (b * num_patches, c, p, p) - assert batch['mask'].shape == (b * num_patches, p, p) - # Test keepdim=False - s = p - num_patches = ((h - p + s) // s) * ((w - p + s) // s) - batch = { - 'image': torch.randn(size=(b, c, h, w)), - 'mask': torch.randint(low=0, high=2, size=(b, h, w)), - } - train_transforms = transforms.AugmentationSequential( - _ExtractPatches(window_size=p, stride=s, keepdim=False), - same_on_batch=True, - data_keys=['image', 'mask'], - ) - output = train_transforms(batch) - for k, v in output.items(): - print(k, v.shape, v.dtype) - assert batch['image'].shape == (b, num_patches, c, p, p) - assert batch['mask'].shape == (b, num_patches, 1, p, p) + train_transforms = ExtractTensorPatches(p, stride=16) + output = {} + output['image'] = train_transforms(batch['image']) + output['mask'] = train_transforms(batch['mask'].unsqueeze(1)).squeeze(2) + assert output['image'].shape == (b, num_patches, c, p, p) + assert output['mask'].shape == (b, num_patches, p, p) diff --git a/torchgeo/transforms/__init__.py b/torchgeo/transforms/__init__.py index 5a0f9ee3392..34291d71345 100644 --- a/torchgeo/transforms/__init__.py +++ b/torchgeo/transforms/__init__.py @@ -20,7 +20,6 @@ AppendSWI, AppendTriBandNormalizedDifferenceIndex, ) -from .transforms import AugmentationSequential __all__ = ( 'AppendBNDVI', @@ -37,6 +36,5 @@ 'AppendRBNDVI', 'AppendSWI', 'AppendTriBandNormalizedDifferenceIndex', - 'AugmentationSequential', 'RandomGrayscale', ) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index d8f80bdcaac..97a1fc1757f 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -10,104 +10,7 @@ from einops import rearrange from kornia.contrib import extract_tensor_patches from kornia.geometry import crop_by_indices -from kornia.geometry.boxes import Boxes from torch import Tensor -from torch.nn.modules import Module - - -# TODO: contribute these to Kornia and delete this file -class AugmentationSequential(Module): - """Wrapper around kornia AugmentationSequential to handle input dicts. - - .. deprecated:: 0.4 - Use :class:`kornia.augmentation.container.AugmentationSequential` instead. - """ - - def __init__( - self, - *args: K.base._AugmentationBase | K.ImageSequential, - data_keys: list[str], - **kwargs: Any, - ) -> None: - """Initialize a new augmentation sequential instance. - - Args: - *args: Sequence of kornia augmentations - data_keys: List of inputs to augment (e.g., ["image", "mask", "boxes"]) - **kwargs: Keyword arguments passed to ``K.AugmentationSequential`` - - .. versionadded:: 0.5 - The ``**kwargs`` parameter. - """ - super().__init__() - self.data_keys = data_keys - - keys: list[str] = [] - for key in data_keys: - if key.startswith('image'): - keys.append('input') - elif key == 'boxes': - keys.append('bbox') - elif key == 'masks': - keys.append('mask') - else: - keys.append(key) - - self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) - - def forward(self, batch: dict[str, Any]) -> dict[str, Any]: - """Perform augmentations and update data dict. - - Args: - batch: the input - - Returns: - the augmented input - """ - # Kornia augmentations require all inputs to be float - dtype = {} - for key in self.data_keys: - dtype[key] = batch[key].dtype - batch[key] = batch[key].float() - - # Convert shape of boxes from [N, 4] to [N, 4, 2] - if 'boxes' in batch and ( - isinstance(batch['boxes'], list) or batch['boxes'].ndim == 2 - ): - batch['boxes'] = Boxes.from_tensor(batch['boxes']).data - - # Kornia requires masks to have a channel dimension - if 'mask' in batch and batch['mask'].ndim == 3: - batch['mask'] = rearrange(batch['mask'], 'b h w -> b () h w') - - if 'masks' in batch and batch['masks'].ndim == 3: - batch['masks'] = rearrange(batch['masks'], 'c h w -> () c h w') - - inputs = [batch[k] for k in self.data_keys] - outputs_list: Tensor | list[Tensor] = self.augs(*inputs) - outputs_list = ( - outputs_list if isinstance(outputs_list, list) else [outputs_list] - ) - outputs: dict[str, Tensor] = { - k: v for k, v in zip(self.data_keys, outputs_list) - } - batch.update(outputs) - - # Convert all inputs back to their previous dtype - for key in self.data_keys: - batch[key] = batch[key].to(dtype[key]) - - # Convert boxes to default [N, 4] - if 'boxes' in batch: - batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy') - - # Torchmetrics does not support masks with a channel dimension - if 'mask' in batch and batch['mask'].shape[1] == 1: - batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') - if 'masks' in batch and batch['masks'].ndim == 4: - batch['masks'] = rearrange(batch['masks'], '() c h w -> c h w') - - return batch class _RandomNCrop(K.GeometricAugmentationBase2D):