Added substation segmentation dataset #2352
base: main
Changes from all commits
@@ -0,0 +1,83 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np

# Parameters
SIZE = 228  # Image dimensions
NUM_SAMPLES = 5  # Number of samples
np.random.seed(0)

# Define the directory hierarchy
FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str]

filenames: FILENAME_HIERARCHY = {'image_stack': ['image'], 'mask': ['mask']}


def create_file(path: str, value: str) -> None:
    """Generate .npz files for images or masks based on the path.

    Args:
        path: Base path for saving files, without the sample index suffix.
        value: Kind of data to generate, either 'image' or 'mask'.
    """
    for i in range(NUM_SAMPLES):
        new_path = f'{path}_{i}.npz'

        if value == 'image':
            # Image data with shape (4, 13, SIZE, SIZE): 4 timepoints, 13 channels
            data = np.random.rand(4, 13, SIZE, SIZE).astype(np.float32)
        elif value == 'mask':
            # Mask data with shape (SIZE, SIZE) and 4 classes
            data = np.random.randint(0, 4, size=(SIZE, SIZE)).astype(np.uint8)

        np.savez_compressed(new_path, arr_0=data)


def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    """Recursively create the directory structure and populate it with data files.

    Args:
        directory: Base directory for the dataset.
        hierarchy: Directory and file structure.
    """
    if isinstance(hierarchy, dict):
        # Recursive case
        for key, value in hierarchy.items():
            path = os.path.join(directory, key)
            os.makedirs(path, exist_ok=True)
            create_directory(path, value)
    else:
        # Base case: every file is named 'image_<i>.npz', even under 'mask/',
        # so masks share filenames with the images they correspond to
        for value in hierarchy:
            path = os.path.join(directory, 'image')
            create_file(path, value)


if __name__ == '__main__':
    # Generate the directory structure and data
    create_directory('.', filenames)

    # Create .tar.gz archives of the dataset folders
    filename_images = 'image_stack.tar.gz'
    filename_masks = 'mask.tar.gz'
    shutil.make_archive('image_stack', 'gztar', '.', 'image_stack')
    shutil.make_archive('mask', 'gztar', '.', 'mask')

    # Compute and print MD5 checksums for download validation
    with open(filename_images, 'rb') as f:
        md5_images = hashlib.md5(f.read()).hexdigest()
    print(f'{filename_images}: {md5_images}')

    with open(filename_masks, 'rb') as f:
        md5_masks = hashlib.md5(f.read()).hexdigest()
    print(f'{filename_masks}: {md5_masks}')
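
A quick sanity check of the fixtures this script produces (a sketch, assuming the script was run from the fixture directory; note that the files under mask/ reuse the 'image_*' filename stem):

import numpy as np

image = np.load('image_stack/image_0.npz')['arr_0']
mask = np.load('mask/image_0.npz')['arr_0']
print(image.shape, image.dtype)  # (4, 13, 228, 228) float32
print(mask.shape, mask.dtype)  # (228, 228) uint8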
@@ -0,0 +1,157 @@

Review comment (on the file header): Needs the same copyright header.

import os
import shutil
from collections.abc import Generator
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import matplotlib.pyplot as plt
import pytest
import torch

from torchgeo.datasets import SubstationDataset


Review comment (on the Args class below): Let's get rid of this.

class Args:
    """Mocked arguments for testing SubstationDataset."""

    def __init__(self) -> None:
        self.data_dir: str = os.path.join(os.getcwd(), 'tests', 'data')
        self.in_channels: int = 13
        self.use_timepoints: bool = True
        self.mask_2d: bool = True
        self.timepoint_aggregation: str = 'median'
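
One way to address that review comment is to pass keyword arguments directly instead of going through a mock-args object. A sketch using the field values above (the keyword names are taken from how the tests below call SubstationDataset):

# Hypothetical direct construction, with no Args helper:
dataset = SubstationDataset(
    data_dir=os.path.join(os.getcwd(), 'tests', 'data'),
    in_channels=13,
    use_timepoints=True,
    mask_2d=True,
    timepoint_aggregation='median',
    image_files=['image_0.npz', 'image_1.npz'],
)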

Review comment (on the fixtures): Can we put all of these in a class like our other tests?

@pytest.fixture
def dataset(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> Generator[SubstationDataset, None, None]:
    """Fixture for the SubstationDataset."""
    args = Args()
    image_files = ['image_0.npz', 'image_1.npz']

    yield SubstationDataset(**vars(args), image_files=image_files)
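
A sketch of the class-based layout the reviewer is asking for, following the pattern of other torchgeo test modules (the fixture body and the single test shown are illustrative, not the final implementation):

class TestSubstationDataset:
    @pytest.fixture
    def dataset(self) -> SubstationDataset:
        # Fixture as a method, so all tests for this dataset live in one class;
        # constructed directly, per the review comment about dropping Args
        return SubstationDataset(
            data_dir=os.path.join(os.getcwd(), 'tests', 'data'),
            in_channels=13,
            use_timepoints=True,
            mask_2d=True,
            timepoint_aggregation='median',
            image_files=['image_0.npz', 'image_1.npz'],
        )

    def test_len(self, dataset: SubstationDataset) -> None:
        assert len(dataset) == 2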

@pytest.mark.parametrize(
    'config',
    [
        {'in_channels': 3, 'use_timepoints': False, 'mask_2d': True},
        {
            'in_channels': 9,
            'use_timepoints': True,
            'timepoint_aggregation': 'concat',
            'mask_2d': False,
        },
        {
            'in_channels': 12,
            'use_timepoints': True,
            'timepoint_aggregation': 'median',
            'mask_2d': True,
        },
        {
            'in_channels': 5,
            'use_timepoints': True,
            'timepoint_aggregation': 'first',
            'mask_2d': False,
        },
        {
            'in_channels': 4,
            'use_timepoints': True,
            'timepoint_aggregation': 'random',
            'mask_2d': True,
        },
        {'in_channels': 2, 'use_timepoints': False, 'mask_2d': False},
        {
            'in_channels': 5,
            'use_timepoints': False,
            'timepoint_aggregation': 'first',
            'mask_2d': False,
        },
        {
            'in_channels': 4,
            'use_timepoints': False,
            'timepoint_aggregation': 'random',
            'mask_2d': True,
        },
    ],
)
def test_getitem_semantic(config: dict[str, Any]) -> None:
    """Test __getitem__ across several channel and timepoint configurations."""
    args = Args()
    for key, value in config.items():
        setattr(args, key, value)  # Dynamically set arguments for each config

    # Set mock paths and create the dataset instance
    image_files = ['image_0.npz', 'image_1.npz']
    dataset = SubstationDataset(**vars(args), image_files=image_files)

    x = dataset[0]
    assert isinstance(x, dict), f'Expected dict, got {type(x)}'
    assert isinstance(x['image'], torch.Tensor), 'Expected image to be a torch.Tensor'
    assert isinstance(x['mask'], torch.Tensor), 'Expected mask to be a torch.Tensor'

def test_len(dataset: SubstationDataset) -> None:
    """Test the length of the dataset."""
    assert len(dataset) == 2


def test_output_shape(dataset: SubstationDataset) -> None:
    """Test the output shapes of the image and mask tensors."""
    x = dataset[0]
    assert x['image'].shape == torch.Size([13, 228, 228])
    assert x['mask'].shape == torch.Size([2, 228, 228])


def test_plot(dataset: SubstationDataset) -> None:
    """Test plotting a sample with titles, without titles, and with a prediction."""
    sample = dataset[0]
    dataset.plot(sample, suptitle='Test')
    plt.close()
    dataset.plot(sample, show_titles=False)
    plt.close()
    sample['prediction'] = sample['mask'].clone()
    dataset.plot(sample)
    plt.close()

def test_already_downloaded(
    dataset: SubstationDataset, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Test that the dataset doesn't re-download if files are already present."""
    # Copy the fixture archives into the target directory to simulate
    # already-downloaded files
    image_archive = os.path.join('tests', 'data', 'substation', 'image_stack.tar.gz')
    mask_archive = os.path.join('tests', 'data', 'substation', 'mask.tar.gz')
    shutil.copy(image_archive, tmp_path)
    shutil.copy(mask_archive, tmp_path)

    # No download should be attempted, since the files are already present;
    # mock the _download method to verify the behavior
    monkeypatch.setattr(dataset, '_download', MagicMock())
    dataset._download()  # This now calls the mocked method

def test_download(
    dataset: SubstationDataset, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Test the _download method of the dataset."""
    # Mock the download_url and extract_archive functions
    mock_download_url = MagicMock()
    mock_extract_archive = MagicMock()
    monkeypatch.setattr('torchgeo.datasets.substation.download_url', mock_download_url)
    monkeypatch.setattr(
        'torchgeo.datasets.substation.extract_archive', mock_extract_archive
    )

    # Call the _download method
    dataset._download()

    # Check that download_url and extract_archive were each called twice:
    # once for the image archive and once for the mask archive
    assert mock_download_url.call_count == 2
    assert mock_extract_archive.call_count == 2

Review comment (on the test data): We usually use much smaller fake images (32 x 32) to make the tests run faster.
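
A minimal sketch of that suggestion as it would apply to the data-generation script above (hypothetical values; the expected shapes in test_output_shape would need to shrink to match):

# In the data generation script:
SIZE = 32  # Image dimensions (was 228)

# And the corresponding expectations in test_output_shape:
# assert x['image'].shape == torch.Size([13, 32, 32])
# assert x['mask'].shape == torch.Size([2, 32, 32])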