diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 489c1f98b5..1da0fdd54c 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -72,6 +72,13 @@ jobs: env: GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }} + - name: Run initializer_v2 e2e for Python 3.11+ + if: ${{ matrix.python-version == '3.11' }} + run: | + pip install urllib3 huggingface_hub + pip install -U './sdk_v2' + pytest ./test/e2e/initializer_v2 + - name: Collect volcano logs if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }} run: | diff --git a/.github/workflows/test-python.yaml b/.github/workflows/test-python.yaml index 9a706461b7..854995e065 100644 --- a/.github/workflows/test-python.yaml +++ b/.github/workflows/test-python.yaml @@ -32,4 +32,13 @@ jobs: pip install -U './sdk/python[huggingface]' - name: Run unit test for training sdk - run: pytest ./sdk/python/kubeflow/training/api/training_client_test.py + run: | + pytest ./sdk/python/kubeflow/training/api/training_client_test.py + + - name: Run Python unit tests for v2 + run: | + pip install -U './sdk_v2' + export PYTHONPATH="${{ github.workspace }}:$PYTHONPATH" + pytest ./pkg/initializer_v2/model + pytest ./pkg/initializer_v2/dataset + pytest ./pkg/initializer_v2/utils diff --git a/pkg/initializer_v2/dataset/__main__.py b/pkg/initializer_v2/dataset/__main__.py index 2be2dd9cb8..7f65577ea8 100644 --- a/pkg/initializer_v2/dataset/__main__.py +++ b/pkg/initializer_v2/dataset/__main__.py @@ -11,7 +11,8 @@ level=logging.INFO, ) -if __name__ == "__main__": + +def main(): logging.info("Starting dataset initialization") try: @@ -29,3 +30,7 @@ case _: logging.error("STORAGE_URI must have the valid dataset provider") raise Exception + + +if __name__ == "__main__": + main() diff --git a/pkg/initializer_v2/dataset/huggingface_test.py b/pkg/initializer_v2/dataset/huggingface_test.py new file mode 100644 index 0000000000..00f15c9229 --- /dev/null +++ b/pkg/initializer_v2/dataset/huggingface_test.py @@ -0,0 +1,144 @@ +from unittest.mock import MagicMock, patch + +import pytest +from kubeflow.training import DATASET_PATH + +import pkg.initializer_v2.utils.utils as utils + + +@pytest.fixture +def huggingface_dataset_instance(): + """Fixture for HuggingFace Dataset instance""" + from pkg.initializer_v2.dataset.huggingface import HuggingFace + + return HuggingFace() + + +# Test cases for config loading +@pytest.mark.parametrize( + "test_name, test_config, expected", + [ + ( + "Full config with token", + {"storage_uri": "hf://dataset/path", "access_token": "test_token"}, + {"storage_uri": "hf://dataset/path", "access_token": "test_token"}, + ), + ( + "Minimal config without token", + {"storage_uri": "hf://dataset/path"}, + {"storage_uri": "hf://dataset/path", "access_token": None}, + ), + ], +) +def test_load_config(test_name, test_config, expected, huggingface_dataset_instance): + """Test config loading with different configurations""" + print(f"Running test: {test_name}") + + with patch.object(utils, "get_config_from_env", return_value=test_config): + huggingface_dataset_instance.load_config() + assert ( + huggingface_dataset_instance.config.storage_uri == expected["storage_uri"] + ) + assert ( + huggingface_dataset_instance.config.access_token == expected["access_token"] + ) + + print("Test execution completed") + + +@pytest.mark.parametrize( + "test_name, test_case", + [ + ( + "Successful download with token", + { + "config": { + "storage_uri": "hf://username/dataset-name", + "access_token": "test_token", + }, + "should_login": True, + "expected_repo_id": "username/dataset-name", + "mock_login_side_effect": None, + "mock_download_side_effect": None, + "expected_error": None, + }, + ), + ( + "Successful download without token", + { + "config": {"storage_uri": "hf://org/dataset-v1", "access_token": None}, + "should_login": False, + "expected_repo_id": "org/dataset-v1", + "mock_login_side_effect": None, + "mock_download_side_effect": None, + "expected_error": None, + }, + ), + ( + "Login failure", + { + "config": { + "storage_uri": "hf://username/dataset-name", + "access_token": "test_token", + }, + "should_login": True, + "expected_repo_id": "username/dataset-name", + "mock_login_side_effect": Exception, + "mock_download_side_effect": None, + "expected_error": Exception, + }, + ), + ( + "Download failure", + { + "config": { + "storage_uri": "hf://invalid/repo/name", + "access_token": None, + }, + "should_login": False, + "expected_repo_id": "invalid/repo/name", + "mock_login_side_effect": None, + "mock_download_side_effect": Exception, + "expected_error": Exception, + }, + ), + ], +) +def test_download_dataset(test_name, test_case, huggingface_dataset_instance): + """Test dataset download with different configurations""" + + print(f"Running test: {test_name}") + + huggingface_dataset_instance.config = MagicMock(**test_case["config"]) + + with patch("huggingface_hub.login") as mock_login, patch( + "huggingface_hub.snapshot_download" + ) as mock_download: + + # Configure mock behavior + if test_case["mock_login_side_effect"]: + mock_login.side_effect = test_case["mock_login_side_effect"] + if test_case["mock_download_side_effect"]: + mock_download.side_effect = test_case["mock_download_side_effect"] + + # Execute test + if test_case["expected_error"]: + with pytest.raises(test_case["expected_error"]): + huggingface_dataset_instance.download_dataset() + else: + huggingface_dataset_instance.download_dataset() + + # Verify login behavior + if test_case["should_login"]: + mock_login.assert_called_once_with(test_case["config"]["access_token"]) + else: + mock_login.assert_not_called() + + # Verify download parameters + if test_case["expected_repo_id"]: + mock_download.assert_called_once_with( + repo_id=test_case["expected_repo_id"], + local_dir=DATASET_PATH, + repo_type="dataset", + ) + print("Test execution completed") diff --git a/pkg/initializer_v2/dataset/main_test.py b/pkg/initializer_v2/dataset/main_test.py new file mode 100644 index 0000000000..bf1b269083 --- /dev/null +++ b/pkg/initializer_v2/dataset/main_test.py @@ -0,0 +1,122 @@ +import os +from unittest.mock import MagicMock, patch + +import pytest + +from pkg.initializer_v2.dataset.__main__ import main + + +@pytest.fixture +def mock_env_vars(): + """Fixture to set and clean up environment variables""" + original_env = dict(os.environ) + + def _set_env_vars(**kwargs): + for key, value in kwargs.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = str(value) + return os.environ + + yield _set_env_vars + + # Cleanup + os.environ.clear() + os.environ.update(original_env) + + +@pytest.mark.parametrize( + "test_name, test_case", + [ + ( + "Successful download with HuggingFace provider", + { + "storage_uri": "hf://dataset/path", + "access_token": "test_token", + "mock_config_error": False, + "mock_download_error": False, + "expected_error": None, + }, + ), + ( + "Missing storage URI environment variable", + { + "storage_uri": None, + "access_token": None, + "mock_config_error": False, + "mock_download_error": False, + "expected_error": Exception, + }, + ), + ( + "Invalid storage URI scheme", + { + "storage_uri": "invalid://dataset/path", + "access_token": None, + "mock_config_error": False, + "mock_download_error": False, + "expected_error": Exception, + }, + ), + ( + "Config loading failure", + { + "storage_uri": "hf://dataset/path", + "access_token": None, + "mock_config_error": True, + "mock_download_error": False, + "expected_error": Exception, + }, + ), + ( + "Dataset download failure", + { + "storage_uri": "hf://dataset/path/error", + "access_token": None, + "mock_config_error": False, + "mock_download_error": True, + "expected_error": Exception, + }, + ), + ], +) +def test_dataset_main(test_name, test_case, mock_env_vars): + """Test main script with different scenarios""" + print(f"Running test: {test_name}") + + # Setup mock environment variables + env_vars = { + "STORAGE_URI": test_case["storage_uri"], + "ACCESS_TOKEN": test_case["access_token"], + } + mock_env_vars(**env_vars) + + # Setup mock HuggingFace instance + mock_hf_instance = MagicMock() + if test_case["mock_config_error"]: + mock_hf_instance.load_config.side_effect = Exception + if test_case["mock_download_error"]: + mock_hf_instance.download_dataset.side_effect = Exception + + with patch( + "pkg.initializer_v2.dataset.__main__.HuggingFace", + return_value=mock_hf_instance, + ) as mock_hf: + + # Execute test + if test_case["expected_error"]: + with pytest.raises(test_case["expected_error"]): + main() + else: + main() + + # Verify HuggingFace instance methods were called + mock_hf_instance.load_config.assert_called_once() + mock_hf_instance.download_dataset.assert_called_once() + + # Verify HuggingFace class instantiation + if test_case["storage_uri"] and test_case["storage_uri"].startswith("hf://"): + mock_hf.assert_called_once() + + print("Test execution completed") diff --git a/pkg/initializer_v2/model/__main__.py b/pkg/initializer_v2/model/__main__.py index eb3126385a..2ff3636873 100644 --- a/pkg/initializer_v2/model/__main__.py +++ b/pkg/initializer_v2/model/__main__.py @@ -11,7 +11,8 @@ level=logging.INFO, ) -if __name__ == "__main__": + +def main(): logging.info("Starting pre-trained model initialization") try: @@ -31,3 +32,7 @@ f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}" ) raise Exception + + +if __name__ == "__main__": + main() diff --git a/pkg/initializer_v2/model/huggingface_test.py b/pkg/initializer_v2/model/huggingface_test.py new file mode 100644 index 0000000000..59d2ab1f20 --- /dev/null +++ b/pkg/initializer_v2/model/huggingface_test.py @@ -0,0 +1,143 @@ +from unittest.mock import MagicMock, patch + +import pytest +from kubeflow.training import MODEL_PATH + +import pkg.initializer_v2.utils.utils as utils + + +@pytest.fixture +def huggingface_model_instance(): + """Fixture for HuggingFace Model instance""" + from pkg.initializer_v2.model.huggingface import HuggingFace + + return HuggingFace() + + +# Test cases for config loading +@pytest.mark.parametrize( + "test_name, test_config, expected", + [ + ( + "Full config with token", + {"storage_uri": "hf://model/path", "access_token": "test_token"}, + {"storage_uri": "hf://model/path", "access_token": "test_token"}, + ), + ( + "Minimal config without token", + {"storage_uri": "hf://model/path"}, + {"storage_uri": "hf://model/path", "access_token": None}, + ), + ], +) +def test_load_config(test_name, test_config, expected, huggingface_model_instance): + """Test config loading with different configurations""" + print(f"Running test: {test_name}") + + with patch.object(utils, "get_config_from_env", return_value=test_config): + huggingface_model_instance.load_config() + assert huggingface_model_instance.config.storage_uri == expected["storage_uri"] + assert ( + huggingface_model_instance.config.access_token == expected["access_token"] + ) + + print("Test execution completed") + + +@pytest.mark.parametrize( + "test_name, test_case", + [ + ( + "Successful download with token", + { + "config": { + "storage_uri": "hf://username/model-name", + "access_token": "test_token", + }, + "should_login": True, + "expected_repo_id": "username/model-name", + "mock_login_side_effect": None, + "mock_download_side_effect": None, + "expected_error": None, + }, + ), + ( + "Successful download without token", + { + "config": {"storage_uri": "hf://org/model-v1", "access_token": None}, + "should_login": False, + "expected_repo_id": "org/model-v1", + "mock_login_side_effect": None, + "mock_download_side_effect": None, + "expected_error": None, + }, + ), + ( + "Login failure", + { + "config": { + "storage_uri": "hf://username/model-name", + "access_token": "test_token", + }, + "should_login": True, + "expected_repo_id": "username/model-name", + "mock_login_side_effect": Exception, + "mock_download_side_effect": None, + "expected_error": Exception, + }, + ), + ( + "Download failure", + { + "config": { + "storage_uri": "hf://invalid/repo/name", + "access_token": None, + }, + "should_login": False, + "expected_repo_id": "invalid/repo/name", + "mock_login_side_effect": None, + "mock_download_side_effect": Exception, + "expected_error": Exception, + }, + ), + ], +) +def test_download_model(test_name, test_case, huggingface_model_instance): + """Test model download with different configurations""" + + print(f"Running test: {test_name}") + + huggingface_model_instance.config = MagicMock(**test_case["config"]) + + with patch("huggingface_hub.login") as mock_login, patch( + "huggingface_hub.snapshot_download" + ) as mock_download: + + # Configure mock behavior + if test_case["mock_login_side_effect"]: + mock_login.side_effect = test_case["mock_login_side_effect"] + if test_case["mock_download_side_effect"]: + mock_download.side_effect = test_case["mock_download_side_effect"] + + # Execute test + if test_case["expected_error"]: + with pytest.raises(test_case["expected_error"]): + huggingface_model_instance.download_model() + else: + huggingface_model_instance.download_model() + + # Verify login behavior + if test_case["should_login"]: + mock_login.assert_called_once_with(test_case["config"]["access_token"]) + else: + mock_login.assert_not_called() + + # Verify download parameters + if test_case["expected_repo_id"]: + mock_download.assert_called_once_with( + repo_id=test_case["expected_repo_id"], + local_dir=MODEL_PATH, + allow_patterns=["*.json", "*.safetensors", "*.model"], + ignore_patterns=["*.msgpack", "*.h5", "*.bin", ".pt", ".pth"], + ) + print("Test execution completed") diff --git a/pkg/initializer_v2/model/main_test.py b/pkg/initializer_v2/model/main_test.py new file mode 100644 index 0000000000..6e44600c67 --- /dev/null +++ b/pkg/initializer_v2/model/main_test.py @@ -0,0 +1,122 @@ +import os +from unittest.mock import MagicMock, patch + +import pytest + +from pkg.initializer_v2.model.__main__ import main + + +@pytest.fixture +def mock_env_vars(): + """Fixture to set and clean up environment variables""" + original_env = dict(os.environ) + + def _set_env_vars(**kwargs): + for key, value in kwargs.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = str(value) + return os.environ + + yield _set_env_vars + + # Cleanup + os.environ.clear() + os.environ.update(original_env) + + +@pytest.mark.parametrize( + "test_name, test_case", + [ + ( + "Successful download with HuggingFace provider", + { + "storage_uri": "hf://model/path", + "access_token": "test_token", + "mock_config_error": False, + "mock_download_error": False, + "expected_error": None, + }, + ), + ( + "Missing storage URI environment variable", + { + "storage_uri": None, + "access_token": None, + "mock_config_error": False, + "mock_download_error": False, + "expected_error": Exception, + }, + ), + ( + "Invalid storage URI scheme", + { + "storage_uri": "invalid://model/path", + "access_token": None, + "mock_config_error": False, + "mock_download_error": False, + "expected_error": Exception, + }, + ), + ( + "Config loading failure", + { + "storage_uri": "hf://model/path", + "access_token": None, + "mock_config_error": True, + "mock_download_error": False, + "expected_error": Exception, + }, + ), + ( + "Model download failure", + { + "storage_uri": "hf://model/path/error", + "access_token": None, + "mock_config_error": False, + "mock_download_error": True, + "expected_error": Exception, + }, + ), + ], +) +def test_model_main(test_name, test_case, mock_env_vars): + """Test main script with different scenarios""" + print(f"Running test: {test_name}") + + # Setup mock environment variables + env_vars = { + "STORAGE_URI": test_case["storage_uri"], + "ACCESS_TOKEN": test_case["access_token"], + } + mock_env_vars(**env_vars) + + # Setup mock HuggingFace instance + mock_hf_instance = MagicMock() + if test_case["mock_config_error"]: + mock_hf_instance.load_config.side_effect = Exception + if test_case["mock_download_error"]: + mock_hf_instance.download_model.side_effect = Exception + + with patch( + "pkg.initializer_v2.model.__main__.HuggingFace", + return_value=mock_hf_instance, + ) as mock_hf: + + # Execute test + if test_case["expected_error"]: + with pytest.raises(test_case["expected_error"]): + main() + else: + main() + + # Verify HuggingFace instance methods were called + mock_hf_instance.load_config.assert_called_once() + mock_hf_instance.download_model.assert_called_once() + + # Verify HuggingFace class instantiation + if test_case["storage_uri"] and test_case["storage_uri"].startswith("hf://"): + mock_hf.assert_called_once() + + print("Test execution completed") diff --git a/pkg/initializer_v2/utils/utils_test.py b/pkg/initializer_v2/utils/utils_test.py new file mode 100644 index 0000000000..1ddf538aed --- /dev/null +++ b/pkg/initializer_v2/utils/utils_test.py @@ -0,0 +1,57 @@ +import os + +import pytest +from kubeflow.training import HuggingFaceDatasetConfig, HuggingFaceModelInputConfig + +import pkg.initializer_v2.utils.utils as utils + + +@pytest.fixture +def mock_env_vars(): + """Fixture to set and clean up environment variables""" + original_env = dict(os.environ) + + def _set_env_vars(**kwargs): + for key, value in kwargs.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = str(value) + return os.environ + + yield _set_env_vars + + # Cleanup + os.environ.clear() + os.environ.update(original_env) + + +@pytest.mark.parametrize( + "config_class,env_vars,expected", + [ + ( + HuggingFaceModelInputConfig, + {"STORAGE_URI": "hf://test", "ACCESS_TOKEN": "token"}, + {"storage_uri": "hf://test", "access_token": "token"}, + ), + ( + HuggingFaceModelInputConfig, + {"STORAGE_URI": "hf://test"}, + {"storage_uri": "hf://test", "access_token": None}, + ), + ( + HuggingFaceDatasetConfig, + {"STORAGE_URI": "hf://test", "ACCESS_TOKEN": "token"}, + {"storage_uri": "hf://test", "access_token": "token"}, + ), + ( + HuggingFaceDatasetConfig, + {"STORAGE_URI": "hf://test"}, + {"storage_uri": "hf://test", "access_token": None}, + ), + ], +) +def test_get_config_from_env(mock_env_vars, config_class, env_vars, expected): + mock_env_vars(**env_vars) + result = utils.get_config_from_env(config_class) + assert result == expected diff --git a/sdk_v2/kubeflow/training/types/config_test.py b/sdk_v2/kubeflow/training/types/config_test.py new file mode 100644 index 0000000000..23c5c70e32 --- /dev/null +++ b/sdk_v2/kubeflow/training/types/config_test.py @@ -0,0 +1,38 @@ +import pytest +from kubeflow.training import HuggingFaceDatasetConfig, HuggingFaceModelInputConfig + + +@pytest.mark.parametrize( + "storage_uri, access_token, expected_storage_uri, expected_access_token", + [ + ("hf://dataset/path", None, "hf://dataset/path", None), + ("hf://dataset/path", "dummy_token", "hf://dataset/path", "dummy_token"), + ], +) +def test_huggingface_dataset_config_creation( + storage_uri, access_token, expected_storage_uri, expected_access_token +): + """Test HuggingFaceDatasetConfig creation with different parameters""" + config = HuggingFaceDatasetConfig( + storage_uri=storage_uri, access_token=access_token + ) + assert config.storage_uri == expected_storage_uri + assert config.access_token == expected_access_token + + +@pytest.mark.parametrize( + "storage_uri, access_token, expected_storage_uri, expected_access_token", + [ + ("hf://model/path", None, "hf://model/path", None), + ("hf://model/path", "dummy_token", "hf://model/path", "dummy_token"), + ], +) +def test_huggingface_model_config_creation( + storage_uri, access_token, expected_storage_uri, expected_access_token +): + """Test HuggingFaceModelInputConfig creation with different parameters""" + config = HuggingFaceModelInputConfig( + storage_uri=storage_uri, access_token=access_token + ) + assert config.storage_uri == expected_storage_uri + assert config.access_token == expected_access_token diff --git a/test/e2e/initializer_v2/__init__.py b/test/e2e/initializer_v2/__init__.py new file mode 100644 index 0000000000..cae8b68a7e --- /dev/null +++ b/test/e2e/initializer_v2/__init__.py @@ -0,0 +1,6 @@ +import os +import sys + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) +) diff --git a/test/e2e/initializer_v2/test_dataset.py b/test/e2e/initializer_v2/test_dataset.py new file mode 100644 index 0000000000..887b77bc6b --- /dev/null +++ b/test/e2e/initializer_v2/test_dataset.py @@ -0,0 +1,92 @@ +import os +import runpy +import shutil +import tempfile + +import pytest +from kubeflow.training import DATASET_PATH + +import pkg.initializer_v2.utils.utils as utils + + +class TestDatasetE2E: + """E2E tests for dataset initialization""" + + @pytest.fixture(autouse=True) + def setup_teardown(self, monkeypatch): + """Setup and teardown for each test""" + # Create temporary directory for dataset downloads + current_dir = os.path.dirname(os.path.abspath(__file__)) + self.temp_dir = tempfile.mkdtemp(dir=current_dir) + os.environ[DATASET_PATH] = self.temp_dir + + # Store original environment + self.original_env = dict(os.environ) + + # Monkeypatch the constant in the module + import kubeflow.training as training + + monkeypatch.setattr(training, "DATASET_PATH", self.temp_dir) + + yield + + # Cleanup + shutil.rmtree(self.temp_dir, ignore_errors=True) + os.environ.clear() + os.environ.update(self.original_env) + + def verify_dataset_files(self, expected_files): + """Verify downloaded dataset files""" + if expected_files: + actual_files = set(os.listdir(self.temp_dir)) + missing_files = set(expected_files) - actual_files + assert not missing_files, f"Missing expected files: {missing_files}" + + @pytest.mark.parametrize( + "test_name, provider, test_case", + [ + # Public HuggingFace dataset test + ( + "HuggingFace - Public dataset", + "huggingface", + { + "storage_uri": "hf://karpathy/tiny_shakespeare", + "access_token": None, + "expected_files": ["tiny_shakespeare.py"], + "expected_error": None, + }, + ), + ( + "HuggingFace - Invalid dataset", + "huggingface", + { + "storage_uri": "hf://invalid/nonexistent-dataset", + "access_token": None, + "expected_files": None, + "expected_error": Exception, + }, + ), + ], + ) + def test_dataset_download(self, test_name, provider, test_case): + """Test end-to-end dataset download for different providers""" + print(f"Running E2E test for {provider}: {test_name}") + + # Setup environment variables based on test case + os.environ[utils.STORAGE_URI_ENV] = test_case["storage_uri"] + expected_files = test_case.get("expected_files") + + if test_case.get("access_token"): + os.environ["ACCESS_TOKEN"] = test_case["access_token"] + + # Run the main script + if test_case["expected_error"]: + with pytest.raises(test_case["expected_error"]): + runpy.run_module( + "pkg.initializer_v2.dataset.__main__", run_name="__main__" + ) + else: + runpy.run_module("pkg.initializer_v2.dataset.__main__", run_name="__main__") + self.verify_dataset_files(expected_files) + + print("Test execution completed") diff --git a/test/e2e/initializer_v2/test_model.py b/test/e2e/initializer_v2/test_model.py new file mode 100644 index 0000000000..61928b3b8e --- /dev/null +++ b/test/e2e/initializer_v2/test_model.py @@ -0,0 +1,98 @@ +import os +import runpy +import shutil +import tempfile + +import pytest +from kubeflow.training import MODEL_PATH + +import pkg.initializer_v2.utils.utils as utils + + +class TestModelE2E: + """E2E tests for model initialization""" + + @pytest.fixture(autouse=True) + def setup_teardown(self, monkeypatch): + """Setup and teardown for each test""" + # Create temporary directory for model downloads + current_dir = os.path.dirname(os.path.abspath(__file__)) + self.temp_dir = tempfile.mkdtemp(dir=current_dir) + os.environ[MODEL_PATH] = self.temp_dir + + # Store original environment + self.original_env = dict(os.environ) + + # Monkeypatch the constant in the module + import kubeflow.training as training + + monkeypatch.setattr(training, "MODEL_PATH", self.temp_dir) + + yield + + # Cleanup + shutil.rmtree(self.temp_dir, ignore_errors=True) + os.environ.clear() + os.environ.update(self.original_env) + + def verify_model_files(self, expected_files): + """Verify downloaded model files""" + if expected_files: + actual_files = set(os.listdir(self.temp_dir)) + missing_files = set(expected_files) - actual_files + assert not missing_files, f"Missing expected files: {missing_files}" + + @pytest.mark.parametrize( + "test_name, provider, test_case", + [ + # Public HuggingFace model test + ( + "HuggingFace - Public model", + "huggingface", + { + "storage_uri": "hf://hf-internal-testing/tiny-random-bert", + "access_token": None, + "expected_files": [ + "config.json", + "model.safetensors", + "tokenizer.json", + "tokenizer_config.json", + ], + "expected_error": None, + }, + ), + ( + "HuggingFace - Invalid model", + "huggingface", + { + "storage_uri": "hf://invalid/nonexistent-model", + "access_token": None, + "expected_files": None, + "expected_error": Exception, + }, + ), + ], + ) + def test_model_download(self, test_name, provider, test_case): + """Test end-to-end model download for different providers""" + print(f"Running E2E test for {provider}: {test_name}") + + # Setup environment variables based on test case + os.environ[utils.STORAGE_URI_ENV] = test_case["storage_uri"] + expected_files = test_case.get("expected_files") + + # Handle token/credentials + if test_case.get("access_token"): + os.environ["ACCESS_TOKEN"] = test_case["access_token"] + + # Run the main script + if test_case["expected_error"]: + with pytest.raises(test_case["expected_error"]): + runpy.run_module( + "pkg.initializer_v2.model.__main__", run_name="__main__" + ) + else: + runpy.run_module("pkg.initializer_v2.model.__main__", run_name="__main__") + self.verify_model_files(expected_files) + + print("Test execution completed")