-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from soda-inria/tests
Tests ✅
- Loading branch information
Showing
4 changed files
with
449 additions
and
222 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,96 @@ | ||
import pytest | ||
import pandas as pd | ||
import json | ||
from carte_ai.data.load_data import * | ||
import numpy as np | ||
from unittest.mock import patch, MagicMock | ||
from carte_ai.data.load_data import ( | ||
load_parquet_config, | ||
set_split, | ||
set_split_hf, | ||
spotify, | ||
wina_pl, | ||
wine_dot_com_prices, | ||
wine_vivino_price, | ||
) | ||
|
||
# Parametrize with additional test cases including edge cases | ||
@pytest.mark.parametrize("data_path, config_path", [ | ||
("/mnt/data/wine_vivino_price.parquet", "/mnt/data/config_wine_vivino_price.json"), | ||
("/mnt/data/wine_pl.parquet", "/mnt/data/config_wine_pl.json"), | ||
("/mnt/data/wine_dot_com_prices.parquet", "/mnt/data/config_wine_dot_com_prices.json"), | ||
("/mnt/data/spotify.parquet", "/mnt/data/config_spotify.json"), | ||
# Edge case: Non-existent data file | ||
("invalid/path/non_existent.parquet", "/mnt/data/config_wine_vivino_price.json"), | ||
# Edge case: Missing config file | ||
("/mnt/data/wine_vivino_price.parquet", "invalid/path/non_existent_config.json"), | ||
]) | ||
def test_load_data(data_path, config_path): | ||
try: | ||
# Load the configuration | ||
with open(config_path, "r") as f: | ||
config = json.load(f) | ||
except FileNotFoundError: | ||
# Test the behavior when config file is missing | ||
pytest.fail(f"Configuration file not found: {config_path}") | ||
return | ||
# Mocked DataFrame
# Small in-memory frame shared by the tests in this module: one entity
# column, one 0/1 target column and two integer feature columns, four rows.
mock_data = pd.DataFrame({
    "entity_name": ["A", "B", "C", "D"],
    "target_name": [1, 0, 1, 0],
    "feature1": [10, 20, 30, 40],
    "feature2": [100, 200, 300, 400],
})
|
||
# Try loading the data | ||
try: | ||
data = load_data(config) | ||
except FileNotFoundError: | ||
# Test the behavior when data file is missing | ||
pytest.fail(f"Data file not found: {data_path}") | ||
return | ||
except Exception as e: | ||
pytest.fail(f"Exception occurred while loading data: {str(e)}") | ||
return | ||
# Mocked Configurations
# Minimal dataset config used alongside the mocked DataFrame; the
# "entity_name" and "target_name" values are the column names of that frame.
mock_config = {
    "entity_name": "entity_name",
    "target_name": "target_name",
    "task": "regression",
    "repeated": False,
}
|
||
# Basic assertions to ensure the data is loaded correctly | ||
assert isinstance(data, pd.DataFrame), "The loaded data is not a DataFrame." | ||
assert config["entity_name"] in data.columns, f"{config['entity_name']} not found in DataFrame columns." | ||
assert config["target_name"] in data.columns, f"{config['target_name']} not found in DataFrame columns." | ||
assert not data.empty, "The DataFrame is empty." | ||
### Test `load_parquet_config` ### | ||
@patch("carte_ai.data.load_data.requests.get") | ||
@patch("pandas.read_parquet") | ||
def test_load_parquet_config(mock_read_parquet, mock_requests): | ||
# Mock parquet loading | ||
mock_read_parquet.return_value = mock_data | ||
|
||
# Check for missing values | ||
assert data.isna().sum().sum() == 0, "There are missing values in the DataFrame." | ||
# Mock config JSON | ||
mock_response = MagicMock() | ||
mock_response.status_code = 200 | ||
mock_response.json.return_value = mock_config | ||
mock_requests.return_value = mock_response | ||
|
||
# Add task-specific checks | ||
if config["task"] == "regression": | ||
assert pd.api.types.is_numeric_dtype(data[config["target_name"]]), "Target for regression is not numeric." | ||
elif config["task"] == "classification": | ||
assert data[config["target_name"]].nunique() > 1, "Target for classification has only one class." | ||
data, config = load_parquet_config("mock_dataset") | ||
|
||
# Check for duplicates if specified in config | ||
if not config.get("repeated", False): | ||
assert not data.duplicated(subset=[config["entity_name"]]).any(), "There are duplicate entries in the entity column." | ||
# Assertions | ||
mock_read_parquet.assert_called_once() | ||
mock_requests.assert_called_once() | ||
assert isinstance(data, pd.DataFrame), "Data should be a DataFrame" | ||
assert isinstance(config, dict), "Config should be a dictionary" | ||
assert config == mock_config, "Config does not match expected values" | ||
|
||
@pytest.mark.parametrize("config,expected_exception", [
    # Malformed config: missing entity_name
    ({"target_name": "Price", "task": "regression"}, KeyError),
    # Malformed config: missing target_name
    ({"entity_name": "Name", "task": "regression"}, KeyError),
    # Malformed config: invalid task type
    ({"entity_name": "Name", "target_name": "Price", "task": "invalid_task"}, ValueError),
])
def test_load_data_invalid_configs(config, expected_exception):
    """Check that ``load_data`` raises the expected exception for a malformed config.

    Each parametrized case pairs a config dict with the exception type the
    call is expected to raise.
    """
    # Test behavior with malformed configuration
    with pytest.raises(expected_exception):
        load_data(config)
### Test `set_split` ###
def test_set_split():
    """Split the mocked frame with ``set_split`` and check the returned pieces.

    Verifies that the feature splits are DataFrames, the target splits are
    NumPy arrays, and that neither the train nor the test features are empty.
    """
    X_train, X_test, y_train, y_test = set_split(mock_data, mock_config, num_train=2)

    # Assertions
    assert isinstance(X_train, pd.DataFrame), "X_train should be a DataFrame"
    assert isinstance(X_test, pd.DataFrame), "X_test should be a DataFrame"
    assert isinstance(y_train, np.ndarray), "y_train should be a NumPy array"
    assert isinstance(y_test, np.ndarray), "y_test should be a NumPy array"
    assert len(X_train) > 0 and len(X_test) > 0, "Train and test splits should not be empty"
|
||
### Test `set_split_hf` ###
def test_set_split_hf():
    """Split the mocked frame with ``set_split_hf`` and check the returned pieces.

    Unlike ``test_set_split``, the target and entity column names are passed
    as keyword arguments instead of a config dict. Verifies that the feature
    splits are DataFrames, the target splits are NumPy arrays, and that
    neither the train nor the test features are empty.
    """
    X_train, X_test, y_train, y_test = set_split_hf(
        mock_data, target_name="target_name", entity_name="entity_name", num_train=2
    )

    # Assertions
    assert isinstance(X_train, pd.DataFrame), "X_train should be a DataFrame"
    assert isinstance(X_test, pd.DataFrame), "X_test should be a DataFrame"
    assert isinstance(y_train, np.ndarray), "y_train should be a NumPy array"
    assert isinstance(y_test, np.ndarray), "y_test should be a NumPy array"
    assert len(X_train) > 0 and len(X_test) > 0, "Train and test splits should not be empty"
|
||
### Test dataset-specific methods ###
@patch("carte_ai.data.load_data.load_parquet_config", return_value=(mock_data, mock_config))
def test_spotify(mock_load_parquet_config):
    """Check that ``spotify`` yields non-empty train and test features.

    ``load_parquet_config`` is patched to return the mocked frame and config.
    """
    # The target splits are unpacked (the call must return four values) but
    # are not inspected by this test.
    X_train, X_test, _, _ = spotify(num_train=2)
    assert X_train.shape[0] > 0 and X_test.shape[0] > 0, "Spotify train/test splits should not be empty"
|
||
@patch("carte_ai.data.load_data.load_parquet_config", return_value=(mock_data, mock_config))
def test_wina_pl(mock_load_parquet_config):
    """Check that ``wina_pl`` yields non-empty train and test features.

    ``load_parquet_config`` is patched to return the mocked frame and config.
    """
    # The target splits are unpacked (the call must return four values) but
    # are not inspected by this test.
    X_train, X_test, _, _ = wina_pl(num_train=2)
    assert X_train.shape[0] > 0 and X_test.shape[0] > 0, "Wina_PL train/test splits should not be empty"
|
||
@patch("carte_ai.data.load_data.load_parquet_config", return_value=(mock_data, mock_config))
def test_wine_dot_com_prices(mock_load_parquet_config):
    """Check that ``wine_dot_com_prices`` yields non-empty train and test features.

    ``load_parquet_config`` is patched to return the mocked frame and config.
    """
    # The target splits are unpacked (the call must return four values) but
    # are not inspected by this test.
    X_train, X_test, _, _ = wine_dot_com_prices(num_train=2)
    assert X_train.shape[0] > 0 and X_test.shape[0] > 0, "Wine.com Prices train/test splits should not be empty"
|
||
@patch("carte_ai.data.load_data.load_parquet_config", return_value=(mock_data, mock_config))
def test_wine_vivino_price(mock_load_parquet_config):
    """Check that ``wine_vivino_price`` yields non-empty train and test features.

    ``load_parquet_config`` is patched to return the mocked frame and config.
    """
    # The target splits are unpacked (the call must return four values) but
    # are not inspected by this test.
    X_train, X_test, _, _ = wine_vivino_price(num_train=2)
    assert X_train.shape[0] > 0 and X_test.shape[0] > 0, "Vivino Wine Prices train/test splits should not be empty"
Oops, something went wrong.