diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index b8f2137c920..31ed8bdc482 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -469,6 +469,11 @@ SSL4EO-L Benchmark
 
 .. autoclass:: SSL4EOLBenchmark
 
+SubstationDataset
+^^^^^^^^^^^^^^^^^
+
+.. autoclass:: SubstationDataset
+
 SustainBench Crop Yield
 ^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv
index 7d7a17a4b94..620bf884db8 100644
--- a/docs/api/datasets/non_geo_datasets.csv
+++ b/docs/api/datasets/non_geo_datasets.csv
@@ -52,6 +52,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
 `SSL4EO`_-S12,T,Sentinel-1/2,"CC-BY-4.0",1M,-,264x264,10,"SAR, MSI"
 `SSL4EO-L Benchmark`_,S,Lansat & CDL,"CC0-1.0",25K,134,264x264,30,MSI
 `SSL4EO-L Benchmark`_,S,Lansat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI
+`SubstationDataset`_,S,OpenStreetMap & Sentinel-2,"CC-BY-SA-2.0",27K,2,228x228,10,MSI
 `SustainBench Crop Yield`_,R,MODIS,"CC-BY-SA-4.0",11k,-,32x32,-,MSI
 `TreeSatAI`_,"C, R, S","Aerial, Sentinel-1/2",CC-BY-4.0,50K,"12, 15, 20","6, 20, 304","0.2, 10","CIR, MSI, SAR"
 `Tropical Cyclone`_,R,GOES 8--16,"CC-BY-4.0","108,110",-,256x256,4K--8K,MSI
diff --git a/tests/data/substation/data.py b/tests/data/substation/data.py
new file mode 100644
index 00000000000..356eec83850
--- /dev/null
+++ b/tests/data/substation/data.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+
+import numpy as np
+
+# Parameters
+SIZE = 228  # Image dimensions
+NUM_SAMPLES = 5  # Number of samples
+np.random.seed(0)
+
+# Define directory hierarchy
+FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str]
+
+filenames: FILENAME_HIERARCHY = {'image_stack': ['image'], 'mask': ['mask']}
+
+
+def create_file(path: str, value: str) -> None:
+    """Generate .npz files for images or masks.
+
+    Args:
+        path: Base path for saving files, without the sample index or extension.
+        value: Either 'image' or 'mask', controlling which array is generated.
+    """
+    for i in range(NUM_SAMPLES):
+        new_path = f'{path}_{i}.npz'
+
+        if value == 'image':
+            # Image data with shape (4, 13, SIZE, SIZE): 4 timepoints, 13 channels
+            data = np.random.rand(4, 13, SIZE, SIZE).astype(np.float32)
+        elif value == 'mask':
+            # Mask data with shape (SIZE, SIZE) and values in {0, 1, 2, 3}
+            data = np.random.randint(0, 4, size=(SIZE, SIZE)).astype(np.uint8)
+
+        np.savez_compressed(new_path, arr_0=data)
+
+
+def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
+    """Recursively create the directory structure and populate it with data files.
+
+    Args:
+        directory: Base directory for the dataset.
+        hierarchy: Directory and file structure.
+    """
+    if isinstance(hierarchy, dict):
+        # Recursive case
+        for key, value in hierarchy.items():
+            path = os.path.join(directory, key)
+            os.makedirs(path, exist_ok=True)
+            create_directory(path, value)
+    else:
+        # Base case: images and masks both use the 'image_{i}.npz' naming scheme
+        for value in hierarchy:
+            path = os.path.join(directory, 'image')
+            create_file(path, value)
+
+
+if __name__ == '__main__':
+    # Generate directory structure and data
+    create_directory('.', filenames)
+
+    # Create tar.gz archives of the dataset folders
+    filename_images = 'image_stack.tar.gz'
+    filename_masks = 'mask.tar.gz'
+    shutil.make_archive('image_stack', 'gztar', '.', 'image_stack')
+    shutil.make_archive('mask', 'gztar', '.', 'mask')
+
+    # Compute and print MD5 checksums for data validation
+    with open(filename_images, 'rb') as f:
+        md5_images = hashlib.md5(f.read()).hexdigest()
+    print(f'{filename_images}: {md5_images}')
+
+    with open(filename_masks, 'rb') as f:
+        md5_masks = hashlib.md5(f.read()).hexdigest()
+    print(f'{filename_masks}: {md5_masks}')
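A quick way to sanity-check the synthetic fixtures after running the script above (the relative paths assume the layout data.py creates and that it was run from tests/data/substation):

import numpy as np

image = np.load('image_stack/image_0.npz')['arr_0']
mask = np.load('mask/image_0.npz')['arr_0']
assert image.shape == (4, 13, 228, 228) and image.dtype == np.float32
assert mask.shape == (228, 228) and set(np.unique(mask)) <= {0, 1, 2, 3}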
+ """ + if isinstance(hierarchy, dict): + # Recursive case + for key, value in hierarchy.items(): + path = os.path.join(directory, key) + os.makedirs(path, exist_ok=True) + create_directory(path, value) + else: + # Base case + for value in hierarchy: + path = os.path.join(directory, 'image') + create_file(path, value) + + +if __name__ == '__main__': + # Generate directory structure and data + create_directory('.', filenames) + + # Create zip archives of dataset folders + filename_images = 'image_stack.tar.gz' + filename_masks = 'mask.tar.gz' + shutil.make_archive('image_stack', 'gztar', '.', 'image_stack') + shutil.make_archive('mask', 'gztar', '.', 'mask') + + # Compute and print MD5 checksums for data validation + with open(filename_images, 'rb') as f: + md5_images = hashlib.md5(f.read()).hexdigest() + print(f'{filename_images}: {md5_images}') + + with open(filename_masks, 'rb') as f: + md5_masks = hashlib.md5(f.read()).hexdigest() + print(f'{filename_masks}: {md5_masks}') diff --git a/tests/data/substation/image_stack.tar.gz b/tests/data/substation/image_stack.tar.gz new file mode 100644 index 00000000000..23b92374dae Binary files /dev/null and b/tests/data/substation/image_stack.tar.gz differ diff --git a/tests/data/substation/image_stack/image_0.npz b/tests/data/substation/image_stack/image_0.npz new file mode 100644 index 00000000000..d460c779f87 Binary files /dev/null and b/tests/data/substation/image_stack/image_0.npz differ diff --git a/tests/data/substation/image_stack/image_1.npz b/tests/data/substation/image_stack/image_1.npz new file mode 100644 index 00000000000..0f7e31edaaf Binary files /dev/null and b/tests/data/substation/image_stack/image_1.npz differ diff --git a/tests/data/substation/image_stack/image_2.npz b/tests/data/substation/image_stack/image_2.npz new file mode 100644 index 00000000000..4c3504be0e6 Binary files /dev/null and b/tests/data/substation/image_stack/image_2.npz differ diff --git a/tests/data/substation/image_stack/image_3.npz b/tests/data/substation/image_stack/image_3.npz new file mode 100644 index 00000000000..0104c267312 Binary files /dev/null and b/tests/data/substation/image_stack/image_3.npz differ diff --git a/tests/data/substation/image_stack/image_4.npz b/tests/data/substation/image_stack/image_4.npz new file mode 100644 index 00000000000..1adf8f7c3e6 Binary files /dev/null and b/tests/data/substation/image_stack/image_4.npz differ diff --git a/tests/data/substation/mask.tar.gz b/tests/data/substation/mask.tar.gz new file mode 100644 index 00000000000..887debae638 Binary files /dev/null and b/tests/data/substation/mask.tar.gz differ diff --git a/tests/data/substation/mask/image_0.npz b/tests/data/substation/mask/image_0.npz new file mode 100644 index 00000000000..1559f933537 Binary files /dev/null and b/tests/data/substation/mask/image_0.npz differ diff --git a/tests/data/substation/mask/image_1.npz b/tests/data/substation/mask/image_1.npz new file mode 100644 index 00000000000..56a1e5cc97b Binary files /dev/null and b/tests/data/substation/mask/image_1.npz differ diff --git a/tests/data/substation/mask/image_2.npz b/tests/data/substation/mask/image_2.npz new file mode 100644 index 00000000000..9d0094bbff2 Binary files /dev/null and b/tests/data/substation/mask/image_2.npz differ diff --git a/tests/data/substation/mask/image_3.npz b/tests/data/substation/mask/image_3.npz new file mode 100644 index 00000000000..3011ce9dd2b Binary files /dev/null and b/tests/data/substation/mask/image_3.npz differ diff --git 
diff --git a/tests/datasets/test_substation.py b/tests/datasets/test_substation.py
new file mode 100644
index 00000000000..6586ed87cf1
--- /dev/null
+++ b/tests/datasets/test_substation.py
@@ -0,0 +1,157 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import shutil
+from pathlib import Path
+from typing import Any
+from unittest.mock import MagicMock
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+
+from torchgeo.datasets import SubstationDataset
+
+
+class Args:
+    """Mocked arguments for testing SubstationDataset."""
+
+    def __init__(self) -> None:
+        self.data_dir: str = os.path.join(os.getcwd(), 'tests', 'data')
+        self.in_channels: int = 13
+        self.use_timepoints: bool = True
+        self.mask_2d: bool = True
+        self.timepoint_aggregation: str = 'median'
+
+
+@pytest.fixture
+def dataset() -> SubstationDataset:
+    """Fixture for the SubstationDataset."""
+    args = Args()
+    image_files = ['image_0.npz', 'image_1.npz']
+    return SubstationDataset(**vars(args), image_files=image_files)
+
+
+@pytest.mark.parametrize(
+    'config',
+    [
+        {'in_channels': 3, 'use_timepoints': False, 'mask_2d': True},
+        {
+            'in_channels': 9,
+            'use_timepoints': True,
+            'timepoint_aggregation': 'concat',
+            'mask_2d': False,
+        },
+        {
+            'in_channels': 12,
+            'use_timepoints': True,
+            'timepoint_aggregation': 'median',
+            'mask_2d': True,
+        },
+        {
+            'in_channels': 5,
+            'use_timepoints': True,
+            'timepoint_aggregation': 'first',
+            'mask_2d': False,
+        },
+        {
+            'in_channels': 4,
+            'use_timepoints': True,
+            'timepoint_aggregation': 'random',
+            'mask_2d': True,
+        },
+        {'in_channels': 2, 'use_timepoints': False, 'mask_2d': False},
+        {
+            'in_channels': 5,
+            'use_timepoints': False,
+            'timepoint_aggregation': 'first',
+            'mask_2d': False,
+        },
+        {
+            'in_channels': 4,
+            'use_timepoints': False,
+            'timepoint_aggregation': 'random',
+            'mask_2d': True,
+        },
+    ],
+)
+def test_getitem_semantic(config: dict[str, Any]) -> None:
+    args = Args()
+    for key, value in config.items():
+        setattr(args, key, value)  # Dynamically set arguments for each config
+
+    # Create a dataset instance for this configuration
+    image_files = ['image_0.npz', 'image_1.npz']
+    dataset = SubstationDataset(**vars(args), image_files=image_files)
+
+    x = dataset[0]
+    assert isinstance(x, dict), f'Expected dict, got {type(x)}'
+    assert isinstance(x['image'], torch.Tensor), 'Expected image to be a torch.Tensor'
+    assert isinstance(x['mask'], torch.Tensor), 'Expected mask to be a torch.Tensor'
+
+
+def test_len(dataset: SubstationDataset) -> None:
+    """Test the length of the dataset."""
+    assert len(dataset) == 2
+
+
+def test_output_shape(dataset: SubstationDataset) -> None:
+    """Test the output shape of the dataset."""
+    x = dataset[0]
+    assert x['image'].shape == torch.Size([13, 228, 228])
+    assert x['mask'].shape == torch.Size([2, 228, 228])
+
+
+def test_plot(dataset: SubstationDataset) -> None:
+    sample = dataset[0]
+    dataset.plot(sample, suptitle='Test')
+    plt.close()
+    dataset.plot(sample, show_titles=False)
+    plt.close()
+    sample['prediction'] = sample['mask'].clone()
+    dataset.plot(sample)
+    plt.close()
+
+
+def test_already_downloaded(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+    """Test that no download is attempted when the archives are already present."""
+    root = os.path.join('tests', 'data', 'substation')
+
+    # Copy the archives into the target directory to simulate a prior download
+    shutil.copy(os.path.join(root, 'image_stack.tar.gz'), tmp_path)
+    shutil.copy(os.path.join(root, 'mask.tar.gz'), tmp_path)
+
+    # _verify should fall through to extraction, never to a download
+    mock_download = MagicMock()
+    monkeypatch.setattr(SubstationDataset, '_download', mock_download)
+
+    args = Args()
+    args.data_dir = str(tmp_path)
+    SubstationDataset(**vars(args), image_files=['image_0.npz'])
+    mock_download.assert_not_called()
+
+
+def test_download(
+    dataset: SubstationDataset, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Test the _download method of the dataset."""
+    # Mock the download_url and extract_archive functions
+    mock_download_url = MagicMock()
+    mock_extract_archive = MagicMock()
+    monkeypatch.setattr('torchgeo.datasets.substation.download_url', mock_download_url)
+    monkeypatch.setattr(
+        'torchgeo.datasets.substation.extract_archive', mock_extract_archive
+    )
+
+    # Call the _download method
+    dataset._download()
+
+    # Both the image and mask archives should be downloaded and extracted
+    assert mock_download_url.call_count == 2
+    assert mock_extract_archive.call_count == 2
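For reference, the shape algebra behind test_output_shape, derived from the __getitem__ logic added later in this diff (assuming the full 13-channel, 4-timepoint fixtures; the fixture's defaults use 'median', hence the expected (13, 228, 228) image):

# Input stack per sample: (4 timepoints, 13 channels, 228, 228)
# use_timepoints=True,  'concat' -> (52, 228, 228)  # 4 * 13 channels stacked
# use_timepoints=True,  'median' -> (13, 228, 228)  # per-pixel median over time
# use_timepoints=False, 'first'  -> (13, 228, 228)  # first timepoint only
# use_timepoints=False, 'random' -> (13, 228, 228)  # one random timepoint
# mask_2d=True -> mask (2, 228, 228); mask_2d=False -> mask (1, 228, 228)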
diff --git a/torchgeo/datamodules/substation.py b/torchgeo/datamodules/substation.py
new file mode 100644
index 00000000000..8a67d28343d
--- /dev/null
+++ b/torchgeo/datamodules/substation.py
@@ -0,0 +1,169 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Substation Data Module."""
+
+from typing import Any
+
+import numpy as np
+import torch
+from torch.utils.data import Subset
+
+from ..datasets import SubstationDataset
+from .utils import group_shuffle_split
+
+
+class SubstationDataModule:
+    """Substation Data Module with train-test split and transformations.
+
+    .. versionadded:: 0.7
+    """
+
+    def __init__(
+        self,
+        data_dir: str,
+        batch_size: int = 64,
+        num_workers: int = 0,
+        split_ratio: float = 0.8,
+        normalizing_type: str = 'percentile',
+        normalizing_factor: np.ndarray | None = None,
+        means: np.ndarray | None = None,
+        stds: np.ndarray | None = None,
+        in_channels: int = 13,
+        model_type: str = 'default',
+        geo_transforms: Any = None,
+        color_transforms: Any = None,
+        image_resize: Any = None,
+        mask_resize: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a new SubstationDataModule instance.
+
+        Args:
+            data_dir: Path to the dataset directory.
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for data loading.
+            split_ratio: Ratio of data to use for training.
+            normalizing_type: Normalization type ('percentile', 'zscore', or 'default').
+            normalizing_factor: Normalization factor for percentile normalization.
+            means: Mean values for z-score normalization.
+            stds: Standard deviation values for z-score normalization.
+            in_channels: Number of input channels to use.
+            model_type: Type of model being used (e.g., 'swin' for specific channel selection).
+            geo_transforms: Geometric transformations to apply to the data.
+            color_transforms: Color transformations to apply to the image.
+            image_resize: Resizing function for the image.
+            mask_resize: Resizing function for the mask.
+            **kwargs: Additional arguments passed to SubstationDataset.
+        """
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.split_ratio = split_ratio
+        self.normalizing_type = normalizing_type
+        self.normalizing_factor = normalizing_factor
+        self.means = means
+        self.stds = stds
+        self.in_channels = in_channels
+        self.model_type = model_type
+        self.geo_transforms = geo_transforms
+        self.color_transforms = color_transforms
+        self.image_resize = image_resize
+        self.mask_resize = mask_resize
+        self.kwargs = kwargs
+
+        # Placeholders for datasets, populated in ``setup``
+        self.train_dataset = None
+        self.val_dataset = None
+        self.test_dataset = None
+
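+    # Illustrative split arithmetic: with the default split_ratio=0.8 and the
+    # 20% validation carve-out in ``setup``, the full 26,522-sample dataset
+    # splits into 21,217 train+val and 5,305 test indices, i.e. roughly
+    # 16,974 train / 4,243 val / 5,305 test samples.
+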
+ """ + self.data_dir = data_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.split_ratio = split_ratio + self.normalizing_type = normalizing_type + self.normalizing_factor = normalizing_factor + self.means = means + self.stds = stds + self.in_channels = in_channels + self.model_type = model_type + self.geo_transforms = geo_transforms + self.color_transforms = color_transforms + self.image_resize = image_resize + self.mask_resize = mask_resize + self.kwargs = kwargs + + # Placeholder for datasets + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: One of 'fit', 'validate', 'test', or 'predict'. + """ + # Initialize the dataset + dataset = SubstationDataset(data_dir=self.data_dir, **self.kwargs) + + # Train-test split + total_size = len(dataset) + train_size = int(total_size * self.split_ratio) + indices = list(range(total_size)) + train_indices, test_indices = group_shuffle_split( + indices, train_size=train_size, random_state=0 + ) + + if stage in ["fit", "validate"]: + # Further split train set into train and validation sets + val_split_ratio = 0.2 # 20% of the train set for validation + val_size = int(len(train_indices) * val_split_ratio) + train_indices, val_indices = group_shuffle_split( + train_indices, train_size=len(train_indices) - val_size, random_state=42 + ) + self.train_dataset = Subset(dataset, train_indices) + self.val_dataset = Subset(dataset, val_indices) + + # Apply preprocessing to train and validation datasets + self.train_dataset = self._apply_transforms(self.train_dataset) + self.val_dataset = self._apply_transforms(self.val_dataset) + + if stage == "test": + self.test_dataset = Subset(dataset, test_indices) + self.test_dataset = self._apply_transforms(self.test_dataset) + + def _apply_transforms(self, dataset: Subset) -> Subset: + """Apply preprocessing and transformations to the dataset. + + Args: + dataset: A subset of the dataset. + + Returns: + The processed dataset. 
+ """ + for sample in dataset: + image, mask = sample["image"], sample["mask"] + + # Standardizing image + if self.normalizing_type == "percentile": + image = (image - self.normalizing_factor[:, 0].reshape((-1, 1, 1))) / self.normalizing_factor[:, 2].reshape((-1, 1, 1)) + elif self.normalizing_type == "zscore": + image = (image - self.means) / self.stds + else: + image = image / self.normalizing_factor + image = torch.clamp(image, 0, 1) + + # Selecting channels + if self.in_channels == 3: + image = image[:, [3, 2, 1], :, :] + elif self.model_type == "swin": + image = image[:, [3, 2, 1, 4, 5, 6, 7, 10, 11], :, :] + else: + image = image[:, :self.in_channels, :, :] + + # Applying geometric transformations + if self.geo_transforms: + combined = torch.cat((image, mask), 0) + combined = self.geo_transforms(combined) + image, mask = torch.split(combined, [image.shape[0], mask.shape[0]], 0) + + # Applying color transformations + if self.color_transforms: + num_timepoints = image.shape[0] // self.in_channels + for i in range(num_timepoints): + if self.in_channels >= 3: + start = i * self.in_channels + end = start + 3 + image[start:end, :, :] = self.color_transforms(image[start:end, :, :]) + else: + raise ValueError("Input dimensions must support color transformations.") + + # Resizing image and mask + if self.image_resize: + image = self.image_resize(image) + if self.mask_resize: + mask = self.mask_resize(mask) + + sample["image"], sample["mask"] = image, mask + + return dataset diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index f55ef3af22c..0ab72bd1db5 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -131,6 +131,7 @@ ) from .ssl4eo import SSL4EO, SSL4EOL, SSL4EOS12 from .ssl4eo_benchmark import SSL4EOLBenchmark +from .substation import SubstationDataset from .sustainbench_crop_yield import SustainBenchCropYield from .treesatai import TreeSatAI from .ucmerced import UCMerced @@ -276,6 +277,7 @@ 'SpaceNet6', 'SpaceNet7', 'SpaceNet8', + 'SubstationDataset', 'SustainBenchCropYield', 'TreeSatAI', 'TropicalCyclone', diff --git a/torchgeo/datasets/substation.py b/torchgeo/datasets/substation.py new file mode 100644 index 00000000000..4c423b53470 --- /dev/null +++ b/torchgeo/datasets/substation.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""This module handles the Substation segmentation dataset.""" + +import glob +import os + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from .geo import NonGeoDataset +from .utils import download_url, extract_archive + + +class SubstationDataset(NonGeoDataset): + """Base class for Substation Dataset. + + This dataset is responsible for handling the loading and transformation of + substation segmentation datasets. It extends NonGeoDataset, providing methods + for dataset verification, downloading, and transformation. + Dataset Format: + * .npz file for each datapoint + + Dataset Features: + + * 26,522 image-mask pairs stored as numpy files. + * Data from 5 revisits for most locations. 
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index f55ef3af22c..0ab72bd1db5 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -131,6 +131,7 @@
 )
 from .ssl4eo import SSL4EO, SSL4EOL, SSL4EOS12
 from .ssl4eo_benchmark import SSL4EOLBenchmark
+from .substation import SubstationDataset
 from .sustainbench_crop_yield import SustainBenchCropYield
 from .treesatai import TreeSatAI
 from .ucmerced import UCMerced
@@ -276,6 +277,7 @@
     'SpaceNet6',
     'SpaceNet7',
     'SpaceNet8',
+    'SubstationDataset',
     'SustainBenchCropYield',
     'TreeSatAI',
     'TropicalCyclone',
diff --git a/torchgeo/datasets/substation.py b/torchgeo/datasets/substation.py
new file mode 100644
index 00000000000..4c423b53470
--- /dev/null
+++ b/torchgeo/datasets/substation.py
@@ -0,0 +1,241 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Substation segmentation dataset."""
+
+import glob
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib.figure import Figure
+from torch import Tensor
+
+from .geo import NonGeoDataset
+from .utils import download_url, extract_archive
+
+
+class SubstationDataset(NonGeoDataset):
+    """Substation segmentation dataset.
+
+    This dataset contains Sentinel-2 imagery of OpenStreetMap-derived
+    electrical-substation locations, paired with segmentation masks.
+
+    Dataset format:
+
+    * one .npz file per datapoint
+
+    Dataset features:
+
+    * 26,522 image-mask pairs stored as numpy files
+    * data from 5 revisits for most locations
+    * multi-temporal, multi-spectral images (13 channels) paired with masks,
+      with an image size of 228x228 pixels at 10 m resolution
+
+    If you use this dataset in your research, please cite the following paper:
+
+    * https://doi.org/10.48550/arXiv.2409.17363
+
+    .. versionadded:: 0.7
+    """
+
+    filename_images: str = 'image_stack.tar.gz'
+    filename_masks: str = 'mask.tar.gz'
+    url_for_images: str = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/image_stack.tar.gz'
+    url_for_masks: str = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/mask.tar.gz'
+
+    def __init__(
+        self,
+        data_dir: str,
+        in_channels: int,
+        use_timepoints: bool,
+        image_files: list[str],
+        mask_2d: bool,
+        timepoint_aggregation: str,
+        download: bool = False,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize the SubstationDataset.
+
+        Args:
+            data_dir: Path to the directory containing the dataset.
+            in_channels: Number of channels to use from the image.
+            use_timepoints: Whether to use multiple timepoints for each image.
+            image_files: List of filenames for the images.
+            mask_2d: Whether to use a 2D mask.
+            timepoint_aggregation: How to aggregate multiple timepoints
+                ('concat', 'median', 'first', or 'random').
+            download: Whether to download the dataset if it is not found.
+            checksum: Whether to verify the dataset after downloading.
+        """
+        self.data_dir = data_dir
+        self.in_channels = in_channels
+        self.use_timepoints = use_timepoints
+        self.timepoint_aggregation = timepoint_aggregation
+        self.mask_2d = mask_2d
+        self.image_dir = os.path.join(data_dir, 'substation', 'image_stack')
+        self.mask_dir = os.path.join(data_dir, 'substation', 'mask')
+        self.image_filenames = image_files
+        self.download = download
+        self.checksum = checksum
+
+        self._verify()
+
+    def __getitem__(self, index: int) -> dict[str, Tensor]:
+        """Get an item from the dataset by index.
+
+        Args:
+            index: Index of the item to retrieve.
+
+        Returns:
+            A dictionary containing the image and corresponding mask.
+        """
+        image_filename = self.image_filenames[index]
+        image_path = os.path.join(self.image_dir, image_filename)
+        mask_path = os.path.join(self.mask_dir, image_filename)
+
+        image = np.load(image_path)['arr_0']
+
+        # selecting channels
+        image = image[:, : self.in_channels, :, :]
+
+        # handling multiple images across timepoints
+        if self.use_timepoints:
+            image = image[:4, :, :, :]
+            if self.timepoint_aggregation == 'concat':
+                # stack timepoints along the channel axis: (4 * channels, h, w)
+                image = np.reshape(image, (-1, image.shape[2], image.shape[3]))
+            elif self.timepoint_aggregation == 'median':
+                image = np.median(image, axis=0)
+        else:
+            if self.timepoint_aggregation == 'first':
+                image = image[0]
+            elif self.timepoint_aggregation == 'random':
+                image = image[np.random.randint(image.shape[0])]
+
+        # collapse the mask to a binary target: value 3 becomes the positive class
+        mask = np.load(mask_path)['arr_0']
+        mask[mask != 3] = 0
+        mask[mask == 3] = 1
+
+        image = torch.from_numpy(image)
+        mask = torch.from_numpy(mask).float()
+        mask = mask.unsqueeze(dim=0)
+
+        if self.mask_2d:
+            # channel 0 is background, channel 1 the positive class
+            mask_0 = 1.0 - mask
+            mask = torch.cat([mask_0, mask], dim=0)
+
+        return {'image': image, 'mask': mask}
+
+    def __len__(self) -> int:
+        """Return the number of items in the dataset."""
+        return len(self.image_filenames)
+
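+    # NOTE: ``plot`` below treats the first three channels as RGB and scales
+    # them by 1/255 purely for display; this is a rough visualization aid,
+    # not a radiometrically exact rendering of Sentinel-2 reflectances.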
+    def plot(
+        self,
+        sample: dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: str | None = None,
+    ) -> Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            A matplotlib Figure containing the rendered sample.
+        """
+        ncols = 2
+
+        image = sample['image'][:3].permute(1, 2, 0).cpu().numpy()
+        image = image / 255.0  # rough scaling for display
+
+        # the last channel is the positive class for both mask layouts
+        mask = sample['mask'][-1].cpu().numpy()
+
+        showing_predictions = 'prediction' in sample
+        if showing_predictions:
+            prediction = sample['prediction'][-1].cpu().numpy()
+            ncols = 3
+
+        fig, axs = plt.subplots(ncols=ncols, figsize=(4 * ncols, 4))
+        axs[0].imshow(image)
+        axs[0].axis('off')
+        axs[1].imshow(mask, cmap='gray', interpolation='none')
+        axs[1].axis('off')
+
+        if show_titles:
+            axs[0].set_title('Image')
+            axs[1].set_title('Mask')
+
+        if showing_predictions:
+            axs[2].imshow(prediction, cmap='gray', interpolation='none')
+            axs[2].axis('off')
+            if show_titles:
+                axs[2].set_title('Prediction')
+
+        if suptitle:
+            fig.suptitle(suptitle)
+
+        return fig
+
+    def _extract(self) -> None:
+        """Extract the dataset."""
+        img_pathname = os.path.join(self.data_dir, self.filename_images)
+        extract_archive(img_pathname)
+
+        mask_pathname = os.path.join(self.data_dir, self.filename_masks)
+        extract_archive(mask_pathname)
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset."""
+        # Check if the extracted files already exist
+        image_path = os.path.join(self.image_dir, '*.npz')
+        mask_path = os.path.join(self.mask_dir, '*.npz')
+        if glob.glob(image_path) and glob.glob(mask_path):
+            return
+
+        # Check if the tar.gz files for images and masks have already been downloaded
+        image_exists = os.path.exists(os.path.join(self.data_dir, self.filename_images))
+        mask_exists = os.path.exists(os.path.join(self.data_dir, self.filename_masks))
+
+        if image_exists and mask_exists:
+            self._extract()
+            return
+
+        # If dataset files are missing and download is not allowed, raise an error
+        if not self.download:
+            raise FileNotFoundError(
+                f'Dataset files not found in {self.data_dir}. Enable downloading or provide the files.'
+            )
+
+        # Download the dataset (extraction happens inside _download)
+        self._download()
+
+    def _download(self) -> None:
+        """Download the dataset and extract it."""
+        # Download and verify images
+        download_url(
+            self.url_for_images,
+            self.data_dir,
+            filename=self.filename_images,
+            md5='INSERT_IMAGES_MD5_HASH' if self.checksum else None,
+        )
+        extract_archive(
+            os.path.join(self.data_dir, self.filename_images), self.data_dir
+        )
+
+        # Download and verify masks
+        download_url(
+            self.url_for_masks,
+            self.data_dir,
+            filename=self.filename_masks,
+            md5='INSERT_MASKS_MD5_HASH' if self.checksum else None,
+        )
+        extract_archive(os.path.join(self.data_dir, self.filename_masks), self.data_dir)
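The two md5 arguments above are intentionally left as placeholders. For reference, a sketch of how the real checksums could be computed from local copies of the official tarballs before filling them in (the relative paths are assumptions):

import hashlib

for path in ('image_stack.tar.gz', 'mask.tar.gz'):
    with open(path, 'rb') as f:
        print(path, hashlib.md5(f.read()).hexdigest())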