Merge pull request #10 from soda-inria/tests
Tests ✅
gaetanbrison authored Dec 8, 2024
2 parents 49c0263 + da5ce8a commit 26e38ee
Showing 4 changed files with 449 additions and 222 deletions.
146 changes: 87 additions & 59 deletions tests/tests_data/test_load_data.py
@@ -1,68 +1,96 @@
 import pytest
 import pandas as pd
-import json
-from carte_ai.data.load_data import *
-
-# Parametrize with additional test cases including edge cases
-@pytest.mark.parametrize("data_path, config_path", [
-    ("/mnt/data/wine_vivino_price.parquet", "/mnt/data/config_wine_vivino_price.json"),
-    ("/mnt/data/wine_pl.parquet", "/mnt/data/config_wine_pl.json"),
-    ("/mnt/data/wine_dot_com_prices.parquet", "/mnt/data/config_wine_dot_com_prices.json"),
-    ("/mnt/data/spotify.parquet", "/mnt/data/config_spotify.json"),
-    # Edge case: Non-existent data file
-    ("invalid/path/non_existent.parquet", "/mnt/data/config_wine_vivino_price.json"),
-    # Edge case: Missing config file
-    ("/mnt/data/wine_vivino_price.parquet", "invalid/path/non_existent_config.json"),
-])
-def test_load_data(data_path, config_path):
-    try:
-        # Load the configuration
-        with open(config_path, "r") as f:
-            config = json.load(f)
-    except FileNotFoundError:
-        # Test the behavior when config file is missing
-        pytest.fail(f"Configuration file not found: {config_path}")
-        return
-
-    # Try loading the data
-    try:
-        data = load_data(config)
-    except FileNotFoundError:
-        # Test the behavior when data file is missing
-        pytest.fail(f"Data file not found: {data_path}")
-        return
-    except Exception as e:
-        pytest.fail(f"Exception occurred while loading data: {str(e)}")
-        return
-
-    # Basic assertions to ensure the data is loaded correctly
-    assert isinstance(data, pd.DataFrame), "The loaded data is not a DataFrame."
-    assert config["entity_name"] in data.columns, f"{config['entity_name']} not found in DataFrame columns."
-    assert config["target_name"] in data.columns, f"{config['target_name']} not found in DataFrame columns."
-    assert not data.empty, "The DataFrame is empty."
-
-    # Check for missing values
-    assert data.isna().sum().sum() == 0, "There are missing values in the DataFrame."
-
-    # Add task-specific checks
-    if config["task"] == "regression":
-        assert pd.api.types.is_numeric_dtype(data[config["target_name"]]), "Target for regression is not numeric."
-    elif config["task"] == "classification":
-        assert data[config["target_name"]].nunique() > 1, "Target for classification has only one class."
-
-    # Check for duplicates if specified in config
-    if not config.get("repeated", False):
-        assert not data.duplicated(subset=[config["entity_name"]]).any(), "There are duplicate entries in the entity column."
-
-@pytest.mark.parametrize("config,expected_exception", [
-    # Malformed config: missing entity_name
-    ({"target_name": "Price", "task": "regression"}, KeyError),
-    # Malformed config: missing target_name
-    ({"entity_name": "Name", "task": "regression"}, KeyError),
-    # Malformed config: invalid task type
-    ({"entity_name": "Name", "target_name": "Price", "task": "invalid_task"}, ValueError),
-])
-def test_load_data_invalid_configs(config, expected_exception):
-    # Test behavior with malformed configuration
-    with pytest.raises(expected_exception):
-        load_data(config)
+import numpy as np
+from unittest.mock import patch, MagicMock
+from carte_ai.data.load_data import (
+    load_parquet_config,
+    set_split,
+    set_split_hf,
+    spotify,
+    wina_pl,
+    wine_dot_com_prices,
+    wine_vivino_price,
+)
+
+# Mocked DataFrame
+mock_data = pd.DataFrame({
+    "entity_name": ["A", "B", "C", "D"],
+    "target_name": [1, 0, 1, 0],
+    "feature1": [10, 20, 30, 40],
+    "feature2": [100, 200, 300, 400]
+})
+
+# Mocked Configurations
+mock_config = {
+    "entity_name": "entity_name",
+    "target_name": "target_name",
+    "task": "regression",
+    "repeated": False
+}
+
+### Test `load_parquet_config` ###
+@patch("carte_ai.data.load_data.requests.get")
+@patch("pandas.read_parquet")
+def test_load_parquet_config(mock_read_parquet, mock_requests):
+    # Mock parquet loading
+    mock_read_parquet.return_value = mock_data
+
+    # Mock config JSON
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = mock_config
+    mock_requests.return_value = mock_response
+
+    data, config = load_parquet_config("mock_dataset")
+
+    # Assertions
+    mock_read_parquet.assert_called_once()
+    mock_requests.assert_called_once()
+    assert isinstance(data, pd.DataFrame), "Data should be a DataFrame"
+    assert isinstance(config, dict), "Config should be a dictionary"
+    assert config == mock_config, "Config does not match expected values"
+
+### Test `set_split` ###
+def test_set_split():
+    X_train, X_test, y_train, y_test = set_split(mock_data, mock_config, num_train=2)
+
+    # Assertions
+    assert isinstance(X_train, pd.DataFrame), "X_train should be a DataFrame"
+    assert isinstance(X_test, pd.DataFrame), "X_test should be a DataFrame"
+    assert isinstance(y_train, np.ndarray), "y_train should be a NumPy array"
+    assert isinstance(y_test, np.ndarray), "y_test should be a NumPy array"
+    assert len(X_train) > 0 and len(X_test) > 0, "Train and test splits should not be empty"
+
+### Test `set_split_hf` ###
+def test_set_split_hf():
+    X_train, X_test, y_train, y_test = set_split_hf(
+        mock_data, target_name="target_name", entity_name="entity_name", num_train=2
+    )
+
+    # Assertions
+    assert isinstance(X_train, pd.DataFrame), "X_train should be a DataFrame"
+    assert isinstance(X_test, pd.DataFrame), "X_test should be a DataFrame"
+    assert isinstance(y_train, np.ndarray), "y_train should be a NumPy array"
+    assert isinstance(y_test, np.ndarray), "y_test should be a NumPy array"
+    assert len(X_train) > 0 and len(X_test) > 0, "Train and test splits should not be empty"
+
+### Test dataset-specific methods ###
+@patch("carte_ai.data.load_data.load_parquet_config", return_value=(mock_data, mock_config))
+def test_spotify(mock_load_parquet_config):
+    X_train, X_test, y_train, y_test = spotify(num_train=2)
+    assert X_train.shape[0] > 0 and X_test.shape[0] > 0, "Spotify train/test splits should not be empty"
+
+@patch("carte_ai.data.load_data.load_parquet_config", return_value=(mock_data, mock_config))
+def test_wina_pl(mock_load_parquet_config):
+    X_train, X_test, y_train, y_test = wina_pl(num_train=2)
+    assert X_train.shape[0] > 0 and X_test.shape[0] > 0, "Wina_PL train/test splits should not be empty"
+
+@patch("carte_ai.data.load_data.load_parquet_config", return_value=(mock_data, mock_config))
+def test_wine_dot_com_prices(mock_load_parquet_config):
+    X_train, X_test, y_train, y_test = wine_dot_com_prices(num_train=2)
+    assert X_train.shape[0] > 0 and X_test.shape[0] > 0, "Wine.com Prices train/test splits should not be empty"
+
+@patch("carte_ai.data.load_data.load_parquet_config", return_value=(mock_data, mock_config))
+def test_wine_vivino_price(mock_load_parquet_config):
+    X_train, X_test, y_train, y_test = wine_vivino_price(num_train=2)
+    assert X_train.shape[0] > 0 and X_test.shape[0] > 0, "Vivino Wine Prices train/test splits should not be empty"
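
The new tests replace on-disk fixtures under /mnt/data with module-level mocks, so it may help to spell out the call pattern they assume: test_load_parquet_config patches carte_ai.data.load_data.requests.get and pandas.read_parquet, which implies that load_parquet_config fetches a JSON config over HTTP and reads a parquet table, returning (data, config). The snippet below is a minimal sketch of that pattern, not the actual carte_ai implementation; the name load_parquet_config_sketch and the base_url layout are hypothetical, and only the two external calls mirror what the tests patch.

# Minimal sketch of the loader shape the mocks assume -- hypothetical names and
# URL layout, NOT the real carte_ai code. Only the two external calls
# (requests.get for the JSON config, pandas.read_parquet for the table)
# correspond to what test_load_parquet_config patches.
import pandas as pd
import requests


def load_parquet_config_sketch(dataset_name, base_url="https://example.org/carte-datasets"):
    """Fetch a JSON config and a parquet table for dataset_name (assumed layout)."""
    # Patched in the test as carte_ai.data.load_data.requests.get
    response = requests.get(f"{base_url}/config_{dataset_name}.json")
    response.raise_for_status()
    config = response.json()

    # Patched in the test as pandas.read_parquet
    data = pd.read_parquet(f"{base_url}/{dataset_name}.parquet")
    return data, config

With a loader of this shape, patching the two external calls is enough to exercise the module offline (e.g. pytest tests/tests_data/test_load_data.py), which is why the dataset-specific tests (test_spotify, test_wina_pl, and the wine-price tests) no longer depend on any downloaded files.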