diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 97a1fc1757..a98fe99529 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -7,8 +7,6 @@ import kornia.augmentation as K import torch -from einops import rearrange -from kornia.contrib import extract_tensor_patches from kornia.geometry import crop_by_indices from torch import Tensor @@ -103,80 +101,6 @@ def forward( } -class _ExtractPatches(K.GeometricAugmentationBase2D): - """Extract patches from an image or mask.""" - - def __init__( - self, - window_size: int | tuple[int, int], - stride: int | tuple[int, int] | None = None, - padding: int | tuple[int, int] | None = 0, - keepdim: bool = True, - ) -> None: - """Initialize a new _ExtractPatches instance. - - Args: - window_size: desired output size (out_h, out_w) of the crop - stride: stride of window to extract patches. Defaults to non-overlapping - patches (stride=window_size) - padding: zero padding added to the height and width dimensions - keepdim: Combine the patch dimension into the batch dimension - """ - super().__init__(p=1) - self.flags = { - 'window_size': window_size, - 'stride': stride if stride is not None else window_size, - 'padding': padding, - 'keepdim': keepdim, - } - - def compute_transformation( - self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any] - ) -> Tensor: - """Compute the transformation. - - Args: - input: the input tensor - params: generated parameters - flags: static parameters - - Returns: - the transformation - """ - out: Tensor = self.identity_matrix(input) - return out - - def apply_transform( - self, - input: Tensor, - params: dict[str, Tensor], - flags: dict[str, Any], - transform: Tensor | None = None, - ) -> Tensor: - """Apply the transform. - - Args: - input: the input tensor - params: generated parameters - flags: static parameters - transform: the geometric transformation tensor - - Returns: - the augmented input - """ - out = extract_tensor_patches( - input, - window_size=flags['window_size'], - stride=flags['stride'], - padding=flags['padding'], - ) - - if flags['keepdim']: - out = rearrange(out, 'b t c h w -> (b t) c h w') - - return out - - class _Clamp(K.IntensityAugmentationBase2D): """Clamp images to a specific range."""