Added substation segmentation dataset #2352

Open: wants to merge 38 commits into main

Commits (38)
7dff61c  Added substation segementation dataset (rijuld, Oct 17, 2024)
10637af  resolved bugs (rijuld, Oct 21, 2024)
2cb0842  a (rijuld, Oct 21, 2024)
608f76a  Resolved error (rijuld, Oct 21, 2024)
288e8b1  fixed ruff errors (rijuld, Oct 21, 2024)
2e9bf83  fixed mypy errors for substation seg py file (rijuld, Oct 21, 2024)
78c494d  removed more errors (rijuld, Oct 21, 2024)
75ca32c  resolved ruff errors and mypy errors (rijuld, Oct 24, 2024)
e2326cc  fixed length and data size along with ruff and mypy errors (rijuld, Oct 25, 2024)
9832db4  resolved float error (rijuld, Oct 25, 2024)
ef79cd7  organized imports (rijuld, Oct 25, 2024)
83f2eb4  changed to float (rijuld, Oct 25, 2024)
69f5815  resolved mypy errors (rijuld, Oct 27, 2024)
898e6b3  resolved further tests (rijuld, Oct 27, 2024)
d14eca6  sorted imports (rijuld, Oct 27, 2024)
d6ae700  more test coverage (rijuld, Oct 30, 2024)
8892f0d  ruff format (rijuld, Oct 30, 2024)
3f135b4  increased test code coverage (rijuld, Oct 30, 2024)
9a05811  added formatting (rijuld, Oct 30, 2024)
4e65b04  removed transformations so that I can add them in data module (rijuld, Oct 30, 2024)
9a9d555  increased underline length (rijuld, Oct 30, 2024)
3e12e7e  corrected csv row length (rijuld, Oct 30, 2024)
bbba17b  Update datasets.rst (zijinyin, Nov 24, 2024)
4fffc1f  Update non_geo_datasets.csv (zijinyin, Nov 24, 2024)
598c4be  Merge pull request #3 from zijinyin/patch-4 (rijuld, Nov 25, 2024)
15a8881  Merge pull request #1 from zijinyin/patch-2 (rijuld, Nov 25, 2024)
095b7dd  added comment for dataset (rijuld, Nov 25, 2024)
b503817  changed name to substation (rijuld, Nov 25, 2024)
f28e30c  added copyright (rijuld, Nov 25, 2024)
fe1761d  corrected issues (rijuld, Nov 25, 2024)
c4c3545  added plot and tests (rijuld, Nov 25, 2024)
1817132  removed pytest (rijuld, Nov 25, 2024)
28377f8  ruff format (rijuld, Nov 25, 2024)
5af4e0f  Merge branch 'main' into main (rijuld, Nov 26, 2024)
a3b95ba  added extract function (rijuld, Dec 2, 2024)
1216da4  added import (rijuld, Dec 2, 2024)
b0c3c90  Merge branch 'main' into main (rijuld, Dec 2, 2024)
545ff66  added datamodule (rijuld, Dec 5, 2024)
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
@@ -469,6 +469,11 @@ SSL4EO-L Benchmark

.. autoclass:: SSL4EOLBenchmark

SubstationDataset
^^^^^^^^^^^^^^^^^

.. autoclass:: SubstationDataset

SustainBench Crop Yield
^^^^^^^^^^^^^^^^^^^^^^^

1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
@@ -52,6 +52,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`SSL4EO`_-S12,T,Sentinel-1/2,"CC-BY-4.0",1M,-,264x264,10,"SAR, MSI"
`SSL4EO-L Benchmark`_,S,Landsat & CDL,"CC0-1.0",25K,134,264x264,30,MSI
`SSL4EO-L Benchmark`_,S,Landsat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI
`SubstationDataset`_,S,OpenStreetMap & Sentinel-2,"CC-BY-SA-2.0",27K,2,228x228,10,MSI
`SustainBench Crop Yield`_,R,MODIS,"CC-BY-SA-4.0",11k,-,32x32,-,MSI
`TreeSatAI`_,"C, R, S","Aerial, Sentinel-1/2",CC-BY-4.0,50K,"12, 15, 20","6, 20, 304","0.2, 10","CIR, MSI, SAR"
`Tropical Cyclone`_,R,GOES 8--16,"CC-BY-4.0","108,110",-,256x256,4K--8K,MSI
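For orientation, the tests later in this PR construct the dataset with keyword arguments like the following. This is a sketch based on the test fixture below; the actual constructor is defined in the PR's torchgeo/datasets/substation.py, which is not shown on this page:

from torchgeo.datasets import SubstationDataset

# Arguments mirror the Args helper in tests/datasets/test_substation.py.
# The image_{i}.npz names follow the fixtures under tests/data/substation.
ds = SubstationDataset(
    data_dir='tests/data',
    in_channels=13,
    use_timepoints=True,
    mask_2d=True,
    timepoint_aggregation='median',
    image_files=['image_0.npz', 'image_1.npz'],
)
sample = ds[0]  # dict with 'image' and 'mask' torch.Tensors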
83 changes: 83 additions & 0 deletions tests/data/substation/data.py
@@ -0,0 +1,83 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np

# Parameters
SIZE = 228 # Image dimensions
Review comment (Collaborator): We usually use much smaller fake images (32 x 32) to make the tests run faster.

NUM_SAMPLES = 5 # Number of samples
np.random.seed(0)

# Define directory hierarchy
FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str]

filenames: FILENAME_HIERARCHY = {'image_stack': ['image'], 'mask': ['mask']}


def create_file(path: str, value: str) -> None:
    """Generate .npz files for images or masks.

    Args:
        path: Base path for saving files, without the '_{i}.npz' suffix.
        value: Kind of array to generate, either 'image' or 'mask'.
    """
    for i in range(NUM_SAMPLES):
        new_path = f'{path}_{i}.npz'

        if value == 'image':
            # Image data: 4 timepoints, 13 channels, SIZE x SIZE pixels
            data = np.random.rand(4, 13, SIZE, SIZE).astype(np.float32)
        elif value == 'mask':
            # Mask data: SIZE x SIZE pixels with 4 classes
            data = np.random.randint(0, 4, size=(SIZE, SIZE)).astype(np.uint8)

        np.savez_compressed(new_path, arr_0=data)


def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    """Recursively create the directory structure and populate it with data files.

    Args:
        directory: Base directory for the dataset.
        hierarchy: Directory and file structure to create.
    """
    if isinstance(hierarchy, dict):
        # Recursive case
        for key, value in hierarchy.items():
            path = os.path.join(directory, key)
            os.makedirs(path, exist_ok=True)
            create_directory(path, value)
    else:
        # Base case: files in both 'image_stack' and 'mask' share the
        # 'image_{i}.npz' naming, so masks are written as mask/image_{i}.npz
        for value in hierarchy:
            path = os.path.join(directory, 'image')
            create_file(path, value)


if __name__ == '__main__':
    # Generate directory structure and data
    create_directory('.', filenames)

    # Create gzipped tar archives of the dataset folders
    filename_images = 'image_stack.tar.gz'
    filename_masks = 'mask.tar.gz'
    shutil.make_archive('image_stack', 'gztar', '.', 'image_stack')
    shutil.make_archive('mask', 'gztar', '.', 'mask')

    # Compute and print MD5 checksums for data validation
    with open(filename_images, 'rb') as f:
        md5_images = hashlib.md5(f.read()).hexdigest()
    print(f'{filename_images}: {md5_images}')

    with open(filename_masks, 'rb') as f:
        md5_masks = hashlib.md5(f.read()).hexdigest()
    print(f'{filename_masks}: {md5_masks}')
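As a quick sanity check on the generated fixtures (not part of the PR diff), the arrays can be loaded back and their shapes and dtypes verified. This is a minimal sketch assuming the script above was run from tests/data/substation/ and that the arrays live under the default 'arr_0' key:

import numpy as np

# Expected shapes from data.py above (SIZE = 228)
image = np.load('image_stack/image_0.npz')['arr_0']
mask = np.load('mask/image_0.npz')['arr_0']

assert image.shape == (4, 13, 228, 228) and image.dtype == np.float32
assert mask.shape == (228, 228) and mask.dtype == np.uint8
print('fixture shapes and dtypes look correct')

Running `python data.py` from that directory regenerates both .tar.gz archives and prints fresh MD5 checksums for use in the dataset class.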
Binary files added (contents not shown):
tests/data/substation/image_stack.tar.gz
tests/data/substation/image_stack/image_0.npz
tests/data/substation/image_stack/image_1.npz
tests/data/substation/image_stack/image_2.npz
tests/data/substation/image_stack/image_3.npz
tests/data/substation/image_stack/image_4.npz
tests/data/substation/mask.tar.gz
tests/data/substation/mask/image_0.npz
tests/data/substation/mask/image_1.npz
tests/data/substation/mask/image_2.npz
tests/data/substation/mask/image_3.npz
tests/data/substation/mask/image_4.npz
157 changes: 157 additions & 0 deletions tests/datasets/test_substation.py
@@ -0,0 +1,157 @@
import os
Review comment (Collaborator): Needs the same copyright header.

import shutil
from collections.abc import Generator
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import matplotlib.pyplot as plt
import pytest
import torch

from torchgeo.datasets import SubstationDataset


class Args:
Review comment (Collaborator): Let's get rid of this.

"""Mocked arguments for testing SubstationDataset."""

def __init__(self) -> None:
self.data_dir: str = os.path.join(os.getcwd(), 'tests', 'data')
self.in_channels: int = 13
self.use_timepoints: bool = True
self.mask_2d: bool = True
self.timepoint_aggregation: str = 'median'


@pytest.fixture
Review comment (Collaborator): Can we put all of these in a class like our other tests?

def dataset(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> Generator[SubstationDataset, None, None]:
    """Fixture for the SubstationDataset."""
    args = Args()
    image_files = ['image_0.npz', 'image_1.npz']

    yield SubstationDataset(**vars(args), image_files=image_files)


@pytest.mark.parametrize(
    'config',
    [
        {'in_channels': 3, 'use_timepoints': False, 'mask_2d': True},
        {
            'in_channels': 9,
            'use_timepoints': True,
            'timepoint_aggregation': 'concat',
            'mask_2d': False,
        },
        {
            'in_channels': 12,
            'use_timepoints': True,
            'timepoint_aggregation': 'median',
            'mask_2d': True,
        },
        {
            'in_channels': 5,
            'use_timepoints': True,
            'timepoint_aggregation': 'first',
            'mask_2d': False,
        },
        {
            'in_channels': 4,
            'use_timepoints': True,
            'timepoint_aggregation': 'random',
            'mask_2d': True,
        },
        {'in_channels': 2, 'use_timepoints': False, 'mask_2d': False},
        {
            'in_channels': 5,
            'use_timepoints': False,
            'timepoint_aggregation': 'first',
            'mask_2d': False,
        },
        {
            'in_channels': 4,
            'use_timepoints': False,
            'timepoint_aggregation': 'random',
            'mask_2d': True,
        },
    ],
)
def test_getitem_semantic(config: dict[str, Any]) -> None:
    args = Args()
    for key, value in config.items():
        setattr(args, key, value)  # Dynamically set arguments for each config

    # Set mock paths and create the dataset instance
    image_files = ['image_0.npz', 'image_1.npz']
    dataset = SubstationDataset(**vars(args), image_files=image_files)

    x = dataset[0]
    assert isinstance(x, dict), f'Expected dict, got {type(x)}'
    assert isinstance(x['image'], torch.Tensor), 'Expected image to be a torch.Tensor'
    assert isinstance(x['mask'], torch.Tensor), 'Expected mask to be a torch.Tensor'


def test_len(dataset: SubstationDataset) -> None:
    """Test the length of the dataset."""
    assert len(dataset) == 2


def test_output_shape(dataset: SubstationDataset) -> None:
    """Test the output shape of the dataset."""
    x = dataset[0]
    assert x['image'].shape == torch.Size([13, 228, 228])
    assert x['mask'].shape == torch.Size([2, 228, 228])


def test_plot(dataset: SubstationDataset) -> None:
    sample = dataset[0]
    dataset.plot(sample, suptitle='Test')
    plt.close()
    dataset.plot(sample, show_titles=False)
    plt.close()
    sample['prediction'] = sample['mask'].clone()
    dataset.plot(sample)
    plt.close()


def test_already_downloaded(
    dataset: SubstationDataset, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Test that the dataset doesn't re-download if already present."""
    # Simulate that the files are already present by copying them to the target directory
    url_for_images = os.path.join('tests', 'data', 'substation', 'image_stack.tar.gz')
    url_for_masks = os.path.join('tests', 'data', 'substation', 'mask.tar.gz')

    # Copy files to the temporary directory to simulate already-downloaded files
    shutil.copy(url_for_images, tmp_path)
    shutil.copy(url_for_masks, tmp_path)

    # No download should be attempted, since the files are already present
    # Mock the _download method to simulate the behavior
    monkeypatch.setattr(dataset, '_download', MagicMock())
    dataset._download()  # This will now call the mocked method


def test_download(
    dataset: SubstationDataset, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Test the _download method of the dataset."""
    # Mock the download_url and extract_archive functions
    mock_download_url = MagicMock()
    mock_extract_archive = MagicMock()
    monkeypatch.setattr('torchgeo.datasets.substation.download_url', mock_download_url)
    monkeypatch.setattr(
        'torchgeo.datasets.substation.extract_archive', mock_extract_archive
    )

    # Call the _download method
    dataset._download()

    # Check that download_url was called twice
    mock_download_url.assert_called()
    assert mock_download_url.call_count == 2

    # Check that extract_archive was called twice
    mock_extract_archive.assert_called()
    assert mock_extract_archive.call_count == 2
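The test above only asserts call counts, so for context here is a hedged sketch of what a _download implementation consistent with those mocks might look like. The self.url attribute and the exact download_url/extract_archive arguments are assumptions, not the PR's actual code:

import os

from torchgeo.datasets.utils import download_url, extract_archive


def _download(self) -> None:  # hypothetical sketch, not the PR's implementation
    """Download and extract the image and mask archives."""
    for filename in ('image_stack.tar.gz', 'mask.tar.gz'):
        # One download per archive: download_url is called twice, as the test expects
        download_url(self.url + filename, self.data_dir, filename=filename)
        # One extraction per archive: extract_archive is called twice
        extract_archive(os.path.join(self.data_dir, filename), self.data_dir)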