From 9c59b8020af72ef90fdbab389816665694536f0f Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Thu, 4 Jan 2024 00:35:45 +0000 Subject: [PATCH] Move to using managed identity for auth to CosmosDB. (#3806) --- CHANGELOG.md | 1 + api_app/_version.py | 2 +- api_app/api/dependencies/airlock.py | 2 +- api_app/api/dependencies/database.py | 134 +++++++++--------- api_app/api/dependencies/shared_services.py | 2 +- .../workspace_service_templates.py | 2 +- api_app/api/dependencies/workspaces.py | 2 +- api_app/api/helpers.py | 22 +++ api_app/api/routes/airlock.py | 2 +- api_app/api/routes/api.py | 2 +- api_app/api/routes/costs.py | 2 +- api_app/api/routes/health.py | 8 +- api_app/api/routes/migrations.py | 2 +- api_app/api/routes/operations.py | 2 +- .../api/routes/shared_service_templates.py | 2 +- api_app/api/routes/shared_services.py | 2 +- api_app/api/routes/user_resource_templates.py | 2 +- .../api/routes/workspace_service_templates.py | 2 +- api_app/api/routes/workspace_templates.py | 2 +- api_app/api/routes/workspaces.py | 2 +- api_app/core/credentials.py | 32 +++-- api_app/db/events.py | 52 +++++-- api_app/db/migrations/airlock.py | 23 ++- api_app/db/migrations/resources.py | 6 +- api_app/db/migrations/shared_services.py | 6 +- api_app/db/migrations/workspaces.py | 6 +- api_app/db/repositories/airlock_requests.py | 5 +- api_app/db/repositories/base.py | 24 ++-- api_app/db/repositories/operations.py | 5 +- api_app/db/repositories/resource_templates.py | 5 +- api_app/db/repositories/resources.py | 7 +- api_app/db/repositories/resources_history.py | 5 +- api_app/db/repositories/shared_services.py | 5 +- api_app/db/repositories/user_resources.py | 5 +- api_app/db/repositories/workspace_services.py | 5 +- api_app/db/repositories/workspaces.py | 5 +- api_app/event_grid/helpers.py | 2 +- api_app/main.py | 8 +- api_app/resources/strings.py | 1 + .../airlock_request_status_update.py | 12 +- .../service_bus/deployment_status_updater.py | 16 +-- api_app/service_bus/helpers.py | 2 +- api_app/services/aad_authentication.py | 4 +- api_app/services/health_checker.py | 15 +- api_app/tests_ma/conftest.py | 22 +-- api_app/tests_ma/test_api/conftest.py | 10 -- .../test_api/dependencies/__init__.py | 0 .../test_api/dependencies/test_database.py | 12 ++ api_app/tests_ma/test_api/test_helpers.py | 17 +++ .../test_routes/test_resource_helpers.py | 17 +-- .../test_api/test_routes/test_workspaces.py | 12 +- api_app/tests_ma/test_core/__init__.py | 0 .../tests_ma/test_core/test_credentials.py | 24 ++++ api_app/tests_ma/test_db/test_events.py | 28 ++++ .../test_workspace_migration.py | 7 +- .../test_airlock_request_repository.py | 7 +- .../test_repositories/test_base_repository.py | 13 +- .../test_operation_repository.py | 21 ++- .../test_resource_history_repository.py | 7 +- .../test_resource_repository.py | 14 +- .../test_resource_templates_repository.py | 7 +- .../test_shared_service_repository.py | 14 +- ...est_shared_service_templates_repository.py | 7 +- .../test_user_resource_repository.py | 7 +- ...test_user_resource_templates_repository.py | 7 +- .../test_workpaces_repository.py | 14 +- .../test_workpaces_service_repository.py | 14 +- .../test_airlock_request_status_update.py | 30 ++-- .../test_deployment_status_update.py | 50 +++---- .../tests_ma/test_services/test_airlock.py | 5 +- .../test_services/test_health_checker.py | 85 +++++------ core/terraform/api-identity.tf | 13 ++ core/version.txt | 2 +- 73 files changed, 488 insertions(+), 431 deletions(-) create mode 100644 api_app/api/helpers.py create mode 100644 api_app/tests_ma/test_api/dependencies/__init__.py create mode 100644 api_app/tests_ma/test_api/dependencies/test_database.py create mode 100644 api_app/tests_ma/test_api/test_helpers.py create mode 100644 api_app/tests_ma/test_core/__init__.py create mode 100644 api_app/tests_ma/test_core/test_credentials.py create mode 100644 api_app/tests_ma/test_db/test_events.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c2f2e20290..f79260fc5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ FEATURES: ENHANCEMENTS: * Switch from OpenCensus to OpenTelemetry for logging ([#3762](https://github.com/microsoft/AzureTRE/pull/3762)) +* Use managed identity for API connection to CosmosDB ([#345](https://github.com/microsoft/AzureTRE/issues/345)) * Switch to Structured Firewall Logs ([#3816](https://github.com/microsoft/AzureTRE/pull/3816)) BUG FIXES: diff --git a/api_app/_version.py b/api_app/_version.py index c6eae9f8a3..1317d7554a 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.17.1" +__version__ = "0.18.0" diff --git a/api_app/api/dependencies/airlock.py b/api_app/api/dependencies/airlock.py index 4a15aa2741..c824352ee5 100644 --- a/api_app/api/dependencies/airlock.py +++ b/api_app/api/dependencies/airlock.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api.dependencies.database import get_repository +from api.helpers import get_repository from db.repositories.airlock_requests import AirlockRequestRepository from models.domain.airlock_request import AirlockRequest from db.errors import EntityDoesNotExist, UnableToAccessDatabase diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index 61dfe0f901..7bfc89ff22 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -1,80 +1,86 @@ -from typing import Callable, Type - -from azure.cosmos.aio import CosmosClient +from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient -from fastapi import Depends, FastAPI, HTTPException -from fastapi import Request, status -from core import config, credentials -from db.errors import UnableToAccessDatabase -from db.repositories.base import BaseRepository -from resources import strings + +from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME, STATE_STORE_DATABASE +from core.credentials import get_credential_async from services.logging import logger -async def connect_to_db() -> CosmosClient: - logger.debug(f"Connecting to {config.STATE_STORE_ENDPOINT}") +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + - try: - async with credentials.get_credential_async() as credential: - primary_master_key = await get_store_key(credential) +class Database(metaclass=Singleton): - if config.STATE_STORE_SSL_VERIFY: + _cosmos_client: CosmosClient = None + _database_proxy: DatabaseProxy = None + + def __init__(cls): + pass + + @classmethod + async def _connect_to_db(cls) -> CosmosClient: + logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}") + + credential = await get_credential_async() + if MANAGED_IDENTITY_CLIENT_ID: + logger.debug("Connecting with managed identity") cosmos_client = CosmosClient( - url=config.STATE_STORE_ENDPOINT, credential=primary_master_key + url=STATE_STORE_ENDPOINT, + credential=credential ) else: - # ignore TLS (setup is a pain) when using local Cosmos emulator. - cosmos_client = CosmosClient( - config.STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False - ) + logger.debug("Connecting with key") + primary_master_key = await cls._get_store_key(credential) + + if STATE_STORE_SSL_VERIFY: + logger.debug("Connecting with SSL verification") + cosmos_client = CosmosClient( + url=STATE_STORE_ENDPOINT, + credential=primary_master_key + ) + else: + logger.debug("Connecting without SSL verification") + # ignore TLS (setup is a pain) when using local Cosmos emulator. + cosmos_client = CosmosClient( + url=STATE_STORE_ENDPOINT, + credential=primary_master_key, + connection_verify=False + ) logger.debug("Connection established") return cosmos_client - except Exception: - logger.exception("Connection to state store could not be established.") - - -async def get_store_key(credential) -> str: - if config.STATE_STORE_KEY: - primary_master_key = config.STATE_STORE_KEY - else: - async with CosmosDBManagementClient( - credential, - subscription_id=config.SUBSCRIPTION_ID, - base_url=config.RESOURCE_MANAGER_ENDPOINT, - credential_scopes=config.CREDENTIAL_SCOPES - ) as cosmosdb_mng_client: - database_keys = await cosmosdb_mng_client.database_accounts.list_keys( - resource_group_name=config.RESOURCE_GROUP_NAME, - account_name=config.COSMOSDB_ACCOUNT_NAME, - ) - primary_master_key = database_keys.primary_master_key - - return primary_master_key - - -async def get_db_client(app: FastAPI) -> CosmosClient: - if not hasattr(app.state, 'cosmos_client') or not app.state.cosmos_client: - app.state.cosmos_client = await connect_to_db() - return app.state.cosmos_client + @classmethod + async def _get_store_key(cls, credential) -> str: + logger.debug("Getting store key") + if STATE_STORE_KEY: + primary_master_key = STATE_STORE_KEY + else: + async with CosmosDBManagementClient( + credential, + subscription_id=SUBSCRIPTION_ID, + base_url=RESOURCE_MANAGER_ENDPOINT, + credential_scopes=CREDENTIAL_SCOPES + ) as cosmosdb_mng_client: + database_keys = await cosmosdb_mng_client.database_accounts.list_keys( + resource_group_name=RESOURCE_GROUP_NAME, + account_name=COSMOSDB_ACCOUNT_NAME, + ) + primary_master_key = database_keys.primary_master_key -async def get_db_client_from_request(request: Request) -> CosmosClient: - return await get_db_client(request.app) + return primary_master_key + @classmethod + async def get_container_proxy(cls, container_name) -> ContainerProxy: + if cls._cosmos_client is None: + cls._cosmos_client = await cls._connect_to_db() -def get_repository( - repo_type: Type[BaseRepository], -) -> Callable[[CosmosClient], BaseRepository]: - async def _get_repo( - client: CosmosClient = Depends(get_db_client_from_request), - ) -> BaseRepository: - try: - return await repo_type.create(client) - except UnableToAccessDatabase: - logger.exception(strings.STATE_STORE_ENDPOINT_NOT_RESPONDING) - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING, - ) + if cls._database_proxy is None: + cls._database_proxy = cls._cosmos_client.get_database_client(STATE_STORE_DATABASE) - return _get_repo + return cls._database_proxy.get_container_client(container_name) diff --git a/api_app/api/dependencies/shared_services.py b/api_app/api/dependencies/shared_services.py index 970f120776..87bf4474cc 100644 --- a/api_app/api/dependencies/shared_services.py +++ b/api_app/api/dependencies/shared_services.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api.dependencies.database import get_repository +from api.helpers import get_repository from db.errors import EntityDoesNotExist from resources import strings from models.domain.shared_service import SharedService diff --git a/api_app/api/dependencies/workspace_service_templates.py b/api_app/api/dependencies/workspace_service_templates.py index 17d049d153..56bf60e869 100644 --- a/api_app/api/dependencies/workspace_service_templates.py +++ b/api_app/api/dependencies/workspace_service_templates.py @@ -1,6 +1,6 @@ from fastapi import Depends, HTTPException, Path, status -from api.dependencies.database import get_repository +from api.helpers import get_repository from db.errors import EntityDoesNotExist from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType diff --git a/api_app/api/dependencies/workspaces.py b/api_app/api/dependencies/workspaces.py index 40f566845f..aae2cc7213 100644 --- a/api_app/api/dependencies/workspaces.py +++ b/api_app/api/dependencies/workspaces.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api.dependencies.database import get_repository +from api.helpers import get_repository from db.errors import EntityDoesNotExist, ResourceIsNotDeployed from db.repositories.operations import OperationRepository from db.repositories.user_resources import UserResourceRepository diff --git a/api_app/api/helpers.py b/api_app/api/helpers.py new file mode 100644 index 0000000000..1c2d3a0529 --- /dev/null +++ b/api_app/api/helpers.py @@ -0,0 +1,22 @@ +from typing import Callable, Type + +from fastapi import HTTPException, status + +from db.errors import UnableToAccessDatabase +from db.repositories.base import BaseRepository +from resources.strings import UNABLE_TO_GET_STATE_STORE_CLIENT +from services.logging import logger + + +def get_repository(repo_type: Type[BaseRepository],) -> Callable: + async def _get_repo() -> BaseRepository: + try: + return await repo_type.create() + except UnableToAccessDatabase: + logger.exception(UNABLE_TO_GET_STATE_STORE_CLIENT) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=UNABLE_TO_GET_STATE_STORE_CLIENT, + ) + + return _get_repo diff --git a/api_app/api/routes/airlock.py b/api_app/api/routes/airlock.py index 9fa9790e97..7aefa62b41 100644 --- a/api_app/api/routes/airlock.py +++ b/api_app/api/routes/airlock.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status as status_code, Response from jsonschema.exceptions import ValidationError +from api.helpers import get_repository from db.repositories.resources_history import ResourceHistoryRepository from db.repositories.user_resources import UserResourceRepository from db.repositories.workspace_services import WorkspaceServiceRepository @@ -11,7 +12,6 @@ from db.repositories.airlock_requests import AirlockRequestRepository from db.errors import EntityDoesNotExist, UserNotAuthorizedToUseTemplate -from api.dependencies.database import get_repository from api.dependencies.workspaces import get_workspace_by_id_from_path, get_deployed_workspace_by_id_from_path from api.dependencies.airlock import get_airlock_request_by_id_from_path from models.domain.airlock_request import AirlockRequestStatus, AirlockRequestType diff --git a/api_app/api/routes/api.py b/api_app/api/routes/api.py index 5cf104f976..3c27782f72 100644 --- a/api_app/api/routes/api.py +++ b/api_app/api/routes/api.py @@ -5,7 +5,7 @@ from fastapi.openapi.docs import get_swagger_ui_html, get_swagger_ui_oauth2_redirect_html from fastapi.openapi.utils import get_openapi -from api.dependencies.database import get_repository +from api.helpers import get_repository from db.repositories.workspaces import WorkspaceRepository from api.routes import health, ping, workspaces, workspace_templates, workspace_service_templates, user_resource_templates, \ shared_services, shared_service_templates, migrations, costs, airlock, operations, metadata diff --git a/api_app/api/routes/costs.py b/api_app/api/routes/costs.py index b6684756f2..f7e89fa1fc 100644 --- a/api_app/api/routes/costs.py +++ b/api_app/api/routes/costs.py @@ -7,8 +7,8 @@ from pydantic import UUID4 from models.schemas.costs import get_cost_report_responses, get_workspace_cost_report_responses -from api.dependencies.database import get_repository from core import config +from api.helpers import get_repository from db.repositories.shared_services import SharedServiceRepository from db.repositories.user_resources import UserResourceRepository from db.repositories.workspace_services import WorkspaceServiceRepository diff --git a/api_app/api/routes/health.py b/api_app/api/routes/health.py index 301a6fd54d..2cefe21266 100644 --- a/api_app/api/routes/health.py +++ b/api_app/api/routes/health.py @@ -1,5 +1,5 @@ import asyncio -from fastapi import APIRouter +from fastapi import APIRouter, Request from core import credentials from models.schemas.status import HealthCheck, ServiceStatus, StatusEnum from resources import strings @@ -10,13 +10,13 @@ @router.get("/health", name=strings.API_GET_HEALTH_STATUS) -async def health_check() -> HealthCheck: +async def health_check(request: Request) -> HealthCheck: # The health endpoint checks the status of key components of the system. # Note that Resource Processor checks incur Azure management calls, so # calling this endpoint frequently may result in API throttling. - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_context() as credential: cosmos, sb, rp = await asyncio.gather( - create_state_store_status(credential), + create_state_store_status(), create_service_bus_status(credential), create_resource_processor_status(credential) ) diff --git a/api_app/api/routes/migrations.py b/api_app/api/routes/migrations.py index 48fe49437f..692f664583 100644 --- a/api_app/api/routes/migrations.py +++ b/api_app/api/routes/migrations.py @@ -1,11 +1,11 @@ from fastapi import APIRouter, Depends, HTTPException, status from db.migrations.airlock import AirlockMigration from db.migrations.resources import ResourceMigration +from api.helpers import get_repository from db.repositories.operations import OperationRepository from db.repositories.resources_history import ResourceHistoryRepository from services.authentication import get_current_admin_user from resources import strings -from api.dependencies.database import get_repository from db.migrations.shared_services import SharedServiceMigration from db.migrations.workspaces import WorkspaceMigration from db.repositories.resources import ResourceRepository diff --git a/api_app/api/routes/operations.py b/api_app/api/routes/operations.py index d5a707bebf..0ab67f5be2 100644 --- a/api_app/api/routes/operations.py +++ b/api_app/api/routes/operations.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends +from api.helpers import get_repository from db.repositories.operations import OperationRepository -from api.dependencies.database import get_repository from models.schemas.operation import OperationInList from resources import strings from services.authentication import get_current_tre_user_or_tre_admin diff --git a/api_app/api/routes/shared_service_templates.py b/api_app/api/routes/shared_service_templates.py index fee2369ae6..b7801c3789 100644 --- a/api_app/api/routes/shared_service_templates.py +++ b/api_app/api/routes/shared_service_templates.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import get_repository +from api.helpers import get_repository from db.errors import EntityDoesNotExist, EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType diff --git a/api_app/api/routes/shared_services.py b/api_app/api/routes/shared_services.py index f09c3f996f..6e23945bdd 100644 --- a/api_app/api/routes/shared_services.py +++ b/api_app/api/routes/shared_services.py @@ -5,7 +5,7 @@ from db.repositories.operations import OperationRepository from db.errors import DuplicateEntity, MajorVersionUpdateDenied, UserNotAuthorizedToUseTemplate, TargetTemplateVersionDoesNotExist, VersionDowngradeDenied -from api.dependencies.database import get_repository +from api.helpers import get_repository from api.dependencies.shared_services import get_shared_service_by_id_from_path, get_operation_by_id_from_path from db.repositories.resource_templates import ResourceTemplateRepository from db.repositories.resources_history import ResourceHistoryRepository diff --git a/api_app/api/routes/user_resource_templates.py b/api_app/api/routes/user_resource_templates.py index b4cc9f9b6b..f4cb24cb5f 100644 --- a/api_app/api/routes/user_resource_templates.py +++ b/api_app/api/routes/user_resource_templates.py @@ -3,10 +3,10 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import get_repository from api.dependencies.workspace_service_templates import get_workspace_service_template_by_name_from_path from api.routes.resource_helpers import get_template from db.errors import EntityVersionExist, InvalidInput +from api.helpers import get_repository from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType from models.schemas.user_resource_template import UserResourceTemplateInResponse, UserResourceTemplateInCreate diff --git a/api_app/api/routes/workspace_service_templates.py b/api_app/api/routes/workspace_service_templates.py index c04558ad81..e6df3fadba 100644 --- a/api_app/api/routes/workspace_service_templates.py +++ b/api_app/api/routes/workspace_service_templates.py @@ -2,9 +2,9 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import get_repository from api.routes.resource_helpers import get_template from db.errors import EntityVersionExist, InvalidInput +from api.helpers import get_repository from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType from models.schemas.resource_template import ResourceTemplateInResponse, ResourceTemplateInformationInList diff --git a/api_app/api/routes/workspace_templates.py b/api_app/api/routes/workspace_templates.py index 6ba864724d..c61b6f7e82 100644 --- a/api_app/api/routes/workspace_templates.py +++ b/api_app/api/routes/workspace_templates.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import get_repository +from api.helpers import get_repository from db.errors import EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType diff --git a/api_app/api/routes/workspaces.py b/api_app/api/routes/workspaces.py index a086812ff5..018a21999c 100644 --- a/api_app/api/routes/workspaces.py +++ b/api_app/api/routes/workspaces.py @@ -4,7 +4,7 @@ from jsonschema.exceptions import ValidationError -from api.dependencies.database import get_repository +from api.helpers import get_repository from api.dependencies.workspaces import get_operation_by_id_from_path, get_workspace_by_id_from_path, get_deployed_workspace_by_id_from_path, get_deployed_workspace_service_by_id_from_path, get_workspace_service_by_id_from_path, get_user_resource_by_id_from_path from db.errors import InvalidInput, MajorVersionUpdateDenied, TargetTemplateVersionDoesNotExist, UserNotAuthorizedToUseTemplate, VersionDowngradeDenied from db.repositories.operations import OperationRepository diff --git a/api_app/core/credentials.py b/api_app/core/credentials.py index 8780248a63..427f62b529 100644 --- a/api_app/core/credentials.py +++ b/api_app/core/credentials.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager -from core import config +from core.config import MANAGED_IDENTITY_CLIENT_ID, AAD_AUTHORITY_URL from azure.core.credentials import TokenCredential from urllib.parse import urlparse @@ -16,13 +16,12 @@ def get_credential() -> TokenCredential: - managed_identity = config.MANAGED_IDENTITY_CLIENT_ID - if managed_identity: + if MANAGED_IDENTITY_CLIENT_ID: return ChainedTokenCredential( - ManagedIdentityCredential(client_id=managed_identity) + ManagedIdentityCredential(client_id=MANAGED_IDENTITY_CLIENT_ID) ) else: - return DefaultAzureCredential(authority=urlparse(config.AAD_AUTHORITY_URL).netloc, + return DefaultAzureCredential(authority=urlparse(AAD_AUTHORITY_URL).netloc, exclude_shared_token_cache_credential=True, exclude_workload_identity_credential=True, exclude_developer_cli_credential=True, @@ -31,18 +30,13 @@ def get_credential() -> TokenCredential: ) -@asynccontextmanager -async def get_credential_async() -> TokenCredential: - """ - Context manager which yields the default credentials. - """ - managed_identity = config.MANAGED_IDENTITY_CLIENT_ID - credential = ( +async def get_credential_async(): + return ( ChainedTokenCredentialASync( - ManagedIdentityCredentialASync(client_id=managed_identity) + ManagedIdentityCredentialASync(client_id=MANAGED_IDENTITY_CLIENT_ID) ) - if managed_identity - else DefaultAzureCredentialASync(authority=urlparse(config.AAD_AUTHORITY_URL).netloc, + if MANAGED_IDENTITY_CLIENT_ID + else DefaultAzureCredentialASync(authority=urlparse(AAD_AUTHORITY_URL).netloc, exclude_shared_token_cache_credential=True, exclude_workload_identity_credential=True, exclude_developer_cli_credential=True, @@ -50,5 +44,13 @@ async def get_credential_async() -> TokenCredential: exclude_powershell_credential=True ) ) + + +@asynccontextmanager +async def get_credential_async_context() -> TokenCredential: + """ + Context manager which yields the default credentials. + """ + credential = await get_credential_async() yield credential await credential.close() diff --git a/api_app/db/events.py b/api_app/db/events.py index af92b2a859..462af10ed5 100644 --- a/api_app/db/events.py +++ b/api_app/db/events.py @@ -1,19 +1,49 @@ -from azure.cosmos.aio import CosmosClient +import asyncio +from azure.mgmt.cosmosdb import CosmosDBManagementClient -from api.dependencies.database import get_db_client -from db.repositories.resources import ResourceRepository -from core import config +from core.config import SUBSCRIPTION_ID, RESOURCE_GROUP_NAME, RESOURCE_LOCATION, COSMOSDB_ACCOUNT_NAME, STATE_STORE_DATABASE, STATE_STORE_RESOURCES_CONTAINER, STATE_STORE_RESOURCE_TEMPLATES_CONTAINER, STATE_STORE_RESOURCES_HISTORY_CONTAINER, STATE_STORE_OPERATIONS_CONTAINER, STATE_STORE_AIRLOCK_REQUESTS_CONTAINER +from core.credentials import get_credential from services.logging import logger -async def bootstrap_database(app) -> bool: +async def bootstrap_database() -> bool: try: - client: CosmosClient = await get_db_client(app) - if client: - await client.create_database_if_not_exists(id=config.STATE_STORE_DATABASE) - # Test access to database - await ResourceRepository.create(client) - return True + credential = get_credential() + db_mgmt_client = CosmosDBManagementClient(credential=credential, subscription_id=SUBSCRIPTION_ID) + + await asyncio.gather( + create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCES_CONTAINER, "/id"), + create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCE_TEMPLATES_CONTAINER, "/id"), + create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCES_HISTORY_CONTAINER, "/resourceId"), + create_container_if_not_exists(db_mgmt_client, STATE_STORE_OPERATIONS_CONTAINER, "/id"), + create_container_if_not_exists(db_mgmt_client, STATE_STORE_AIRLOCK_REQUESTS_CONTAINER, "/id") + ) + + return True + except Exception as e: + logger.exception("Could not bootstrap database") logger.debug(e) return False + + +async def create_container_if_not_exists(db_mgmt_client, container, partition_key): + + db_mgmt_client.sql_resources.begin_create_update_sql_container( + resource_group_name=RESOURCE_GROUP_NAME, + account_name=COSMOSDB_ACCOUNT_NAME, + database_name=STATE_STORE_DATABASE, + container_name=container, + create_update_sql_container_parameters={ + "location": RESOURCE_LOCATION, + "resource": { + "id": container, + "partition_key": { + "paths": [ + partition_key + ], + "kind": "Hash" + } + } + } + ) diff --git a/api_app/db/migrations/airlock.py b/api_app/db/migrations/airlock.py index 299441fb63..aa0ce011cd 100644 --- a/api_app/db/migrations/airlock.py +++ b/api_app/db/migrations/airlock.py @@ -1,15 +1,13 @@ -from azure.cosmos.aio import CosmosClient from resources import strings from db.repositories.airlock_requests import AirlockRequestRepository class AirlockMigration(AirlockRequestRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = AirlockMigration() - resource_repo = await super().create(client) + resource_repo = await super().create() cls._container = resource_repo._container - cls._client = resource_repo._client return cls async def add_created_by_and_rename_in_history(self) -> int: @@ -42,15 +40,14 @@ async def change_review_resources_to_dict(self) -> int: num_updated = 0 for request in await self.query('SELECT * FROM c'): # Only migrate if airlockReviewResources property present and is a list - if 'reviewUserResources' in request: - if type(request['reviewUserResources']) == list: - updated_review_resources = {} - for i, resource in enumerate(request['reviewUserResources']): - updated_review_resources['UNKNOWN' + str(i)] = resource - - request['reviewUserResources'] = updated_review_resources - await self.update_item_dict(request) - num_updated += 1 + if 'reviewUserResources' in request and isinstance(request['reviewUserResources'], list): + updated_review_resources = {} + for i, resource in enumerate(request['reviewUserResources']): + updated_review_resources['UNKNOWN' + str(i)] = resource + + request['reviewUserResources'] = updated_review_resources + await self.update_item_dict(request) + num_updated += 1 return num_updated diff --git a/api_app/db/migrations/resources.py b/api_app/db/migrations/resources.py index c0cdcf25e7..4c99bff4cf 100644 --- a/api_app/db/migrations/resources.py +++ b/api_app/db/migrations/resources.py @@ -1,5 +1,4 @@ import uuid -from azure.cosmos.aio import CosmosClient from db.repositories.operations import OperationRepository from db.repositories.resources import ResourceRepository from db.repositories.resources_history import ResourceHistoryRepository @@ -7,11 +6,10 @@ class ResourceMigration(ResourceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = ResourceMigration() - resource_repo = await super().create(client) + resource_repo = await super().create() cls._container = resource_repo._container - cls._client = resource_repo._client return cls async def add_deployment_status_field(self, operations_repository: OperationRepository) -> int: diff --git a/api_app/db/migrations/shared_services.py b/api_app/db/migrations/shared_services.py index 575ac74bb2..621991314a 100644 --- a/api_app/db/migrations/shared_services.py +++ b/api_app/db/migrations/shared_services.py @@ -1,4 +1,3 @@ -from azure.cosmos.aio import CosmosClient import semantic_version from db.repositories.shared_services import SharedServiceRepository @@ -8,11 +7,10 @@ class SharedServiceMigration(SharedServiceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = SharedServiceMigration() - resource_repo = await super().create(client) + resource_repo = await super().create() cls._container = resource_repo._container - cls._client = resource_repo._client return cls async def deleteDuplicatedSharedServices(self) -> bool: diff --git a/api_app/db/migrations/workspaces.py b/api_app/db/migrations/workspaces.py index 79209ec30a..e8a422080b 100644 --- a/api_app/db/migrations/workspaces.py +++ b/api_app/db/migrations/workspaces.py @@ -1,4 +1,3 @@ -from azure.cosmos.aio import CosmosClient import semantic_version from db.repositories.workspaces import WorkspaceRepository @@ -7,11 +6,10 @@ class WorkspaceMigration(WorkspaceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = WorkspaceMigration() - resource_repo = await super().create(client) + resource_repo = await super().create() cls._container = resource_repo._container - cls._client = resource_repo._client return cls async def moveAuthInformationToProperties(self) -> bool: diff --git a/api_app/db/repositories/airlock_requests.py b/api_app/db/repositories/airlock_requests.py index ba683723d6..f4a1926348 100644 --- a/api_app/db/repositories/airlock_requests.py +++ b/api_app/db/repositories/airlock_requests.py @@ -5,7 +5,6 @@ from typing import List, Optional from pydantic import UUID4 from azure.cosmos.exceptions import CosmosResourceNotFoundError, CosmosAccessConditionFailedError -from azure.cosmos.aio import CosmosClient from fastapi import HTTPException, status from pydantic import parse_obj_as from models.domain.authentication import User @@ -21,9 +20,9 @@ class AirlockRequestRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = AirlockRequestRepository() - await super().create(client, config.STATE_STORE_AIRLOCK_REQUESTS_CONTAINER) + await super().create(config.STATE_STORE_AIRLOCK_REQUESTS_CONTAINER) return cls @staticmethod diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py index 35395fa064..7fe5371b5c 100644 --- a/api_app/db/repositories/base.py +++ b/api_app/db/repositories/base.py @@ -1,34 +1,26 @@ from typing import Optional -from azure.cosmos.aio import CosmosClient, ContainerProxy -from azure.cosmos import PartitionKey +from azure.cosmos.aio import ContainerProxy from azure.core import MatchConditions from pydantic import BaseModel -from core import config +from api.dependencies.database import Database from db.errors import UnableToAccessDatabase class BaseRepository: @classmethod - async def create(cls, client: CosmosClient, container_name: Optional[str] = None, partition_key: str = "/id"): - partition_key_obj = PartitionKey(path=partition_key) - cls._client: CosmosClient = client - cls._container: ContainerProxy = await cls._get_container(container_name, partition_key_obj) + async def create(cls, container_name: Optional[str] = None): + try: + cls._container: ContainerProxy = await Database().get_container_proxy(container_name) + except Exception: + raise UnableToAccessDatabase + return cls @property def container(self) -> ContainerProxy: return self._container - @classmethod - async def _get_container(cls, container_name, partition_key_obj) -> ContainerProxy: - try: - database = cls._client.get_database_client(config.STATE_STORE_DATABASE) - container = await database.create_container_if_not_exists(id=container_name, partition_key=partition_key_obj) - return container - except Exception: - raise UnableToAccessDatabase - async def query(self, query: str, parameters: Optional[dict] = None): items = self.container.query_items(query=query, parameters=parameters) return [i async for i in items] diff --git a/api_app/db/repositories/operations.py b/api_app/db/repositories/operations.py index 394e6713e8..8489d231d6 100644 --- a/api_app/db/repositories/operations.py +++ b/api_app/db/repositories/operations.py @@ -2,7 +2,6 @@ import uuid from typing import List -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.repositories.resource_templates import ResourceTemplateRepository from resources import strings @@ -19,9 +18,9 @@ class OperationRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = OperationRepository() - await super().create(client, config.STATE_STORE_OPERATIONS_CONTAINER) + await super().create(config.STATE_STORE_OPERATIONS_CONTAINER) return cls @staticmethod diff --git a/api_app/db/repositories/resource_templates.py b/api_app/db/repositories/resource_templates.py index 471269d100..288b096883 100644 --- a/api_app/db/repositories/resource_templates.py +++ b/api_app/db/repositories/resource_templates.py @@ -1,7 +1,6 @@ import uuid from typing import List, Optional, Union -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from core import config @@ -16,9 +15,9 @@ class ResourceTemplateRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = ResourceTemplateRepository() - await super().create(client, config.STATE_STORE_RESOURCE_TEMPLATES_CONTAINER) + await super().create(config.STATE_STORE_RESOURCE_TEMPLATES_CONTAINER) return cls @staticmethod diff --git a/api_app/db/repositories/resources.py b/api_app/db/repositories/resources.py index 6cbc464080..a1740dd167 100644 --- a/api_app/db/repositories/resources.py +++ b/api_app/db/repositories/resources.py @@ -3,7 +3,6 @@ from datetime import datetime from typing import Optional, Tuple, List -from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosResourceNotFoundError from core import config from db.errors import VersionDowngradeDenied, EntityDoesNotExist, MajorVersionUpdateDenied, TargetTemplateVersionDoesNotExist, UserNotAuthorizedToUseTemplate @@ -25,9 +24,9 @@ class ResourceRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = ResourceRepository() - await super().create(client, config.STATE_STORE_RESOURCES_CONTAINER) + await super().create(config.STATE_STORE_RESOURCES_CONTAINER) return cls @staticmethod @@ -46,7 +45,7 @@ def _validate_resource_parameters(resource_input, resource_template): validate(instance=resource_input["properties"], schema=resource_template) async def _get_enriched_template(self, template_name: str, resource_type: ResourceType, parent_template_name: str = "") -> dict: - template_repo = await ResourceTemplateRepository.create(self._client) + template_repo = await ResourceTemplateRepository.create() template = await template_repo.get_current_template(template_name, resource_type, parent_template_name) return template_repo.enrich_template(template) diff --git a/api_app/db/repositories/resources_history.py b/api_app/db/repositories/resources_history.py index 2a6524d62d..2ccfc061e5 100644 --- a/api_app/db/repositories/resources_history.py +++ b/api_app/db/repositories/resources_history.py @@ -1,6 +1,5 @@ from typing import List import uuid -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.errors import EntityDoesNotExist @@ -12,9 +11,9 @@ class ResourceHistoryRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = ResourceHistoryRepository() - await super().create(client, config.STATE_STORE_RESOURCES_HISTORY_CONTAINER, "/resourceId") + await super().create(config.STATE_STORE_RESOURCES_HISTORY_CONTAINER) return cls @staticmethod diff --git a/api_app/db/repositories/shared_services.py b/api_app/db/repositories/shared_services.py index c7cc1e988a..af83ab3116 100644 --- a/api_app/db/repositories/shared_services.py +++ b/api_app/db/repositories/shared_services.py @@ -2,7 +2,6 @@ from typing import List, Tuple import uuid -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from models.domain.resource_template import ResourceTemplate from models.domain.authentication import User @@ -18,9 +17,9 @@ class SharedServiceRepository(ResourceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = SharedServiceRepository() - await super().create(client) + await super().create() return cls @staticmethod diff --git a/api_app/db/repositories/user_resources.py b/api_app/db/repositories/user_resources.py index 4cffe296a8..dd093e2419 100644 --- a/api_app/db/repositories/user_resources.py +++ b/api_app/db/repositories/user_resources.py @@ -1,7 +1,6 @@ import uuid from typing import List, Tuple -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.repositories.resources_history import ResourceHistoryRepository from models.domain.resource_template import ResourceTemplate @@ -18,9 +17,9 @@ class UserResourceRepository(ResourceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = UserResourceRepository() - await super().create(client) + await super().create() return cls @staticmethod diff --git a/api_app/db/repositories/workspace_services.py b/api_app/db/repositories/workspace_services.py index 48523dbcac..5f614aaa94 100644 --- a/api_app/db/repositories/workspace_services.py +++ b/api_app/db/repositories/workspace_services.py @@ -1,7 +1,6 @@ import uuid from typing import List, Tuple -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.repositories.resources_history import ResourceHistoryRepository from models.domain.resource_template import ResourceTemplate @@ -19,9 +18,9 @@ class WorkspaceServiceRepository(ResourceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = WorkspaceServiceRepository() - await super().create(client) + await super().create() return cls @staticmethod diff --git a/api_app/db/repositories/workspaces.py b/api_app/db/repositories/workspaces.py index 23af53109b..8065d48be0 100644 --- a/api_app/db/repositories/workspaces.py +++ b/api_app/db/repositories/workspaces.py @@ -1,7 +1,6 @@ import uuid from typing import List, Tuple -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.repositories.resources_history import ResourceHistoryRepository from models.domain.resource_template import ResourceTemplate @@ -28,9 +27,9 @@ class WorkspaceRepository(ResourceRepository): predefined_address_spaces = {"small": 24, "medium": 22, "large": 16} @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = WorkspaceRepository() - await super().create(client) + await super().create() return cls @staticmethod diff --git a/api_app/event_grid/helpers.py b/api_app/event_grid/helpers.py index ed96ec695c..bcad3e65e1 100644 --- a/api_app/event_grid/helpers.py +++ b/api_app/event_grid/helpers.py @@ -4,7 +4,7 @@ async def publish_event(event: EventGridEvent, topic_endpoint: str): - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_context() as credential: client = EventGridPublisherClient(topic_endpoint, credential) async with client: await client.send([event]) diff --git a/api_app/main.py b/api_app/main.py index 703b8f4d77..0bdc769141 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -24,16 +24,14 @@ @asynccontextmanager async def lifespan(app: FastAPI): - app.state.cosmos_client = None - - while not await bootstrap_database(app): + while not await bootstrap_database(): await asyncio.sleep(5) logger.warning("Database connection could not be established") - deploymentStatusUpdater = DeploymentStatusUpdater(app) + deploymentStatusUpdater = DeploymentStatusUpdater() await deploymentStatusUpdater.init_repos() - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() asyncio.create_task(deploymentStatusUpdater.receive_messages()) diff --git a/api_app/resources/strings.py b/api_app/resources/strings.py index 0c78bd7850..9c2d7ff4b4 100644 --- a/api_app/resources/strings.py +++ b/api_app/resources/strings.py @@ -82,6 +82,7 @@ OK = "OK" NOT_OK = "Not OK" COSMOS_DB = "Cosmos DB" +UNABLE_TO_GET_STATE_STORE_CLIENT = "Unable to get state store client" STATE_STORE_ENDPOINT_NOT_RESPONDING = "State Store endpoint is not responding" STATE_STORE_ENDPOINT_NOT_ACCESSIBLE = "State Store endpoint is not accessible" UNSPECIFIED_ERROR = "Unspecified error" diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index e26f573303..637e29c036 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -6,7 +6,6 @@ from fastapi import HTTPException from pydantic import ValidationError, parse_obj_as -from api.dependencies.database import get_db_client from api.dependencies.airlock import get_airlock_request_by_id_from_path from services.airlock import update_and_publish_event_airlock_request from services.logging import logger, tracer @@ -20,19 +19,18 @@ class AirlockStatusUpdater(): - def __init__(self, app): - self.app = app + def __init__(self): + pass async def init_repos(self): - db_client = await get_db_client(self.app) - self.airlock_request_repo = await AirlockRequestRepository.create(db_client) - self.workspace_repo = await WorkspaceRepository.create(db_client) + self.airlock_request_repo = await AirlockRequestRepository.create() + self.workspace_repo = await WorkspaceRepository.create() async def receive_messages(self): with tracer.start_as_current_span("airlock_receive_messages"): while True: try: - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_context() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) receiver = service_bus_client.get_queue_receiver(queue_name=config.SERVICE_BUS_STEP_RESULT_QUEUE) logger.info(f"Looking for new messages on {config.SERVICE_BUS_STEP_RESULT_QUEUE} queue...") diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index b38e138927..4bac477754 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -4,7 +4,6 @@ from pydantic import ValidationError, parse_obj_as -from api.dependencies.database import get_db_client from api.routes.resource_helpers import get_timestamp from models.domain.resource import Output from db.repositories.resources_history import ResourceHistoryRepository @@ -24,15 +23,14 @@ class DeploymentStatusUpdater(): - def __init__(self, app): - self.app = app + def __init__(self): + pass async def init_repos(self): - db_client = await get_db_client(self.app) - self.operations_repo = await OperationRepository.create(db_client) - self.resource_repo = await ResourceRepository.create(db_client) - self.resource_template_repo = await ResourceTemplateRepository.create(db_client) - self.resource_history_repo = await ResourceHistoryRepository.create(db_client) + self.operations_repo = await OperationRepository.create() + self.resource_repo = await ResourceRepository.create() + self.resource_template_repo = await ResourceTemplateRepository.create() + self.resource_history_repo = await ResourceHistoryRepository.create() def run(self, *args, **kwargs): asyncio.run(self.receive_messages()) @@ -41,7 +39,7 @@ async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): while True: try: - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_context() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) logger.info(f"Looking for new messages on {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue...") diff --git a/api_app/service_bus/helpers.py b/api_app/service_bus/helpers.py index 55ae5c1b20..77f7127999 100644 --- a/api_app/service_bus/helpers.py +++ b/api_app/service_bus/helpers.py @@ -24,7 +24,7 @@ async def _send_message(message: ServiceBusMessage, queue: str): :param queue: The Service Bus queue to send the message to. :type queue: str """ - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_context() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) async with service_bus_client: diff --git a/api_app/services/aad_authentication.py b/api_app/services/aad_authentication.py index fba97619ed..81dd486a8f 100644 --- a/api_app/services/aad_authentication.py +++ b/api_app/services/aad_authentication.py @@ -14,7 +14,6 @@ from models.domain.authentication import User, RoleAssignment from models.domain.workspace import Workspace, WorkspaceRole from resources import strings -from api.dependencies.database import get_db_client_from_request from db.repositories.workspaces import WorkspaceRepository from services.logging import logger @@ -120,8 +119,7 @@ async def _fetch_ws_app_reg_id_from_ws_id(request: Request) -> str: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=strings.AUTH_COULD_NOT_VALIDATE_CREDENTIALS) try: workspace_id = request.path_params['workspace_id'] - db_client = await get_db_client_from_request(request) - ws_repo = await WorkspaceRepository.create(db_client) + ws_repo = await WorkspaceRepository.create() workspace = await ws_repo.get_workspace_by_id(workspace_id) ws_app_reg_id = "" diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index a79ee90cd1..a4d53067b0 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -1,11 +1,12 @@ from typing import Tuple from azure.core import exceptions -from azure.cosmos.aio import CosmosClient from azure.servicebus.aio import ServiceBusClient from azure.mgmt.compute.aio import ComputeManagementClient from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.cosmos.aio import ContainerProxy from azure.servicebus.exceptions import ServiceBusConnectionError, ServiceBusAuthenticationError -from api.dependencies.database import get_store_key +from api.dependencies.database import Database +from core.config import STATE_STORE_RESOURCES_CONTAINER from core import config from models.schemas.status import StatusEnum @@ -13,16 +14,12 @@ from services.logging import logger -async def create_state_store_status(credential) -> Tuple[StatusEnum, str]: +async def create_state_store_status() -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" - debug = True if config.LOGGING_LEVEL == "DEBUG" else False try: - primary_master_key = await get_store_key(credential) - cosmos_client = CosmosClient(config.STATE_STORE_ENDPOINT, primary_master_key, connection_verify=debug) - async with cosmos_client: - list_databases_response = cosmos_client.list_databases() - [database async for database in list_databases_response] + container: ContainerProxy = await Database().get_container_proxy(STATE_STORE_RESOURCES_CONTAINER) + container.query_items("SELECT TOP 1 * FROM c") except exceptions.ServiceRequestError: status = StatusEnum.not_ok message = strings.STATE_STORE_ENDPOINT_NOT_RESPONDING diff --git a/api_app/tests_ma/conftest.py b/api_app/tests_ma/conftest.py index 1afbaf3819..0bd06e076d 100644 --- a/api_app/tests_ma/conftest.py +++ b/api_app/tests_ma/conftest.py @@ -1,6 +1,9 @@ import pytest import pytest_asyncio -from mock import patch +from mock import AsyncMock, patch +from azure.cosmos.aio import CosmosClient, DatabaseProxy + +from api.dependencies.database import Database from models.domain.request_action import RequestAction from models.domain.resource import Resource from models.domain.user_resource import UserResource @@ -572,13 +575,10 @@ def simple_pipeline_step() -> PipelineStep: ) -@pytest_asyncio.fixture() -def no_database(): - """overrides connecting to the database""" - with patch("api.dependencies.database.connect_to_db", return_value=None): - with patch("api.dependencies.database.get_db_client", return_value=None): - with patch( - "db.repositories.base.BaseRepository._get_container", return_value=None - ): - with patch("db.events.bootstrap_database", return_value=None): - yield +@pytest_asyncio.fixture(autouse=True) +async def no_database(): + with patch('api.dependencies.database.get_credential_async', return_value=AsyncMock()), \ + patch('api.dependencies.database.CosmosDBManagementClient', return_value=AsyncMock()), \ + patch('api.dependencies.database.CosmosClient', return_value=AsyncMock(spec=CosmosClient)) as cosmos_client_mock: + cosmos_client_mock.return_value.get_database_client.return_value = AsyncMock(spec=DatabaseProxy) + yield Database() diff --git a/api_app/tests_ma/test_api/conftest.py b/api_app/tests_ma/test_api/conftest.py index a247b91ea5..e781cf854e 100644 --- a/api_app/tests_ma/test_api/conftest.py +++ b/api_app/tests_ma/test_api/conftest.py @@ -14,16 +14,6 @@ def no_lifespan_events(): yield -@pytest_asyncio.fixture(autouse=True) -def no_database(): - """ overrides connecting to the database for all tests""" - with patch('api.dependencies.database.connect_to_db', return_value=None): - with patch('api.dependencies.database.get_db_client', return_value=None): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('db.events.bootstrap_database', return_value=None): - yield - - @pytest.fixture(autouse=True) def no_auth_token(): """ overrides validating and decoding tokens for all tests""" diff --git a/api_app/tests_ma/test_api/dependencies/__init__.py b/api_app/tests_ma/test_api/dependencies/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api_app/tests_ma/test_api/dependencies/test_database.py b/api_app/tests_ma/test_api/dependencies/test_database.py new file mode 100644 index 0000000000..a1faefb453 --- /dev/null +++ b/api_app/tests_ma/test_api/dependencies/test_database.py @@ -0,0 +1,12 @@ +from mock import MagicMock +import pytest + +from api.dependencies.database import Database + +pytestmark = pytest.mark.asyncio + + +async def test_get_container_proxy(): + container_name = "test_container" + container_proxy = await Database().get_container_proxy(container_name) + assert isinstance(container_proxy, MagicMock) diff --git a/api_app/tests_ma/test_api/test_helpers.py b/api_app/tests_ma/test_api/test_helpers.py new file mode 100644 index 0000000000..c315a26346 --- /dev/null +++ b/api_app/tests_ma/test_api/test_helpers.py @@ -0,0 +1,17 @@ +import pytest +from mock import patch +from fastapi import HTTPException + +from db.errors import UnableToAccessDatabase +from db.repositories.base import BaseRepository +from api.helpers import get_repository + +pytestmark = pytest.mark.asyncio + + +@patch("db.repositories.base.BaseRepository.create") +async def test_get_repository_raises_http_exception_when_unable_to_access_database(create_base_repo_mock): + create_base_repo_mock.side_effect = UnableToAccessDatabase() + with pytest.raises(HTTPException): + get_repo = get_repository(BaseRepository) + await get_repo() diff --git a/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py b/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py index 092662930e..31fb31e5f1 100644 --- a/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py +++ b/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py @@ -29,24 +29,21 @@ @pytest_asyncio.fixture async def resource_repo() -> ResourceRepository: - with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: - resource_repo_mock = await ResourceRepository.create(cosmos_client_mock) - yield resource_repo_mock + with patch('api.dependencies.database.Database.get_container_proxy', return_value=AsyncMock()): + resource_repo_mock = await ResourceRepository().create() + yield resource_repo_mock @pytest_asyncio.fixture async def operations_repo() -> OperationRepository: - with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: - operation_repo_mock = await OperationRepository.create(cosmos_client_mock) - yield operation_repo_mock + operation_repo_mock = await OperationRepository().create() + yield operation_repo_mock @pytest_asyncio.fixture async def resource_history_repo() -> ResourceHistoryRepository: - with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: - resource_history_repo_mock = await ResourceHistoryRepository.create(cosmos_client_mock) - yield resource_history_repo_mock + resource_history_repo_mock = await ResourceHistoryRepository().create() + yield resource_history_repo_mock def sample_resource(workspace_id=WORKSPACE_ID): diff --git a/api_app/tests_ma/test_api/test_routes/test_workspaces.py b/api_app/tests_ma/test_api/test_routes/test_workspaces.py index 577391fe7f..0febe3694b 100644 --- a/api_app/tests_ma/test_api/test_routes/test_workspaces.py +++ b/api_app/tests_ma/test_api/test_routes/test_workspaces.py @@ -676,14 +676,13 @@ async def test_delete_workspace_returns_400_if_associated_workspace_services_are # [DELETE] /workspaces/{workspace_id} @ patch("api.routes.resource_helpers.ResourceRepository.get_resource_dependency_list") @ patch("api.routes.workspaces.ResourceTemplateRepository.get_template_by_name_and_version") - @ patch("api.dependencies.workspaces.get_repository") + @ patch("api.helpers.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace", return_value=[]) - @ patch('azure.cosmos.CosmosClient') @ patch('api.routes.resource_helpers.send_resource_request_message', return_value=sample_resource_operation(resource_id=WORKSPACE_ID, operation_id=OPERATION_ID)) - async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspace(self, send_request_message_mock, cosmos_client_mock, __, get_workspace_mock, get_repository_mock, resource_template_repo, ___, disabled_workspace, app, client, basic_resource_template): + async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspace(self, send_request_message_mock, __, get_workspace_mock, get_repository_mock, resource_template_repo, ___, disabled_workspace, app, client, basic_resource_template): get_workspace_mock.return_value = disabled_workspace - get_repository_mock.side_effects = [await WorkspaceRepository.create(cosmos_client_mock), await WorkspaceServiceRepository.create(cosmos_client_mock)] + get_repository_mock.side_effects = [await WorkspaceRepository.create(), await WorkspaceServiceRepository.create()] resource_template_repo.return_value = basic_resource_template await client.delete(app.url_path_for(strings.API_DELETE_WORKSPACE, workspace_id=WORKSPACE_ID)) @@ -692,11 +691,10 @@ async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspa # [DELETE] /workspaces/{workspace_id} @ patch("api.routes.resource_helpers.ResourceRepository.get_resource_dependency_list") @ patch("api.routes.workspaces.ResourceTemplateRepository.get_template_by_name_and_version") - @ patch("api.dependencies.workspaces.get_repository") + @ patch("api.helpers.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace") - @ patch('azure.cosmos.CosmosClient') - async def test_delete_workspace_raises_503_if_marking_the_resource_as_deleted_in_the_db_fails(self, __, ___, get_workspace_mock, _____, resource_template_repo, ______, client, app, disabled_workspace, basic_resource_template): + async def test_delete_workspace_raises_503_if_marking_the_resource_as_deleted_in_the_db_fails(self, ___, get_workspace_mock, _____, resource_template_repo, ______, client, app, disabled_workspace, basic_resource_template): get_workspace_mock.return_value = disabled_workspace resource_template_repo.return_value = basic_resource_template response = await client.delete(app.url_path_for(strings.API_DELETE_WORKSPACE, workspace_id=WORKSPACE_ID)) diff --git a/api_app/tests_ma/test_core/__init__.py b/api_app/tests_ma/test_core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api_app/tests_ma/test_core/test_credentials.py b/api_app/tests_ma/test_core/test_credentials.py new file mode 100644 index 0000000000..d99e360be9 --- /dev/null +++ b/api_app/tests_ma/test_core/test_credentials.py @@ -0,0 +1,24 @@ +from mock import patch +import pytest + +from azure.identity.aio import ( + DefaultAzureCredential as DefaultAzureCredentialASync, + ManagedIdentityCredential as ManagedIdentityCredentialASync +) + +from core.credentials import get_credential_async + +pytestmark = pytest.mark.asyncio + + +@patch("core.credentials.MANAGED_IDENTITY_CLIENT_ID", "mocked_client_id") +async def test_get_credential_async_with_managed_identity_client_id(): + credential = await get_credential_async() + + assert isinstance(credential.credentials[0], ManagedIdentityCredentialASync) + + +async def test_get_credential_async_without_managed_identity_client_id(): + credential = await get_credential_async() + + assert isinstance(credential, DefaultAzureCredentialASync) diff --git a/api_app/tests_ma/test_db/test_events.py b/api_app/tests_ma/test_db/test_events.py new file mode 100644 index 0000000000..009259f983 --- /dev/null +++ b/api_app/tests_ma/test_db/test_events.py @@ -0,0 +1,28 @@ +from unittest.mock import AsyncMock, MagicMock, patch +from azure.core.exceptions import AzureError +import pytest +from db import events + +pytestmark = pytest.mark.asyncio + + +@patch("db.events.get_credential") +@patch("db.events.CosmosDBManagementClient") +async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_context_mock): + get_credential_async_context_mock.return_value = AsyncMock() + cosmos_db_mgmt_client_mock.return_value = MagicMock() + + result = await events.bootstrap_database() + + assert result is True + + +@patch("db.events.get_credential") +@patch("db.events.CosmosDBManagementClient") +async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_context_mock): + get_credential_async_context_mock.return_value = AsyncMock() + cosmos_db_mgmt_client_mock.side_effect = AzureError("some error") + + result = await events.bootstrap_database() + + assert result is False diff --git a/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py b/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py index c15310f6d3..3da577094b 100644 --- a/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py +++ b/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py @@ -11,10 +11,9 @@ @pytest_asyncio.fixture async def workspace_migrator(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - workspace_migrator = await WorkspaceMigration.create(cosmos_client_mock) - yield workspace_migrator + with patch('api.dependencies.database.Database.get_container_proxy', return_value=AsyncMock()): + workspace_migrator = await WorkspaceMigration.create() + yield workspace_migrator def get_sample_old_workspace(workspace_id: str = "7ab18f7e-ee8f-4202-8d46-747818ec76f4", spec_workspace_id: str = "0001") -> dict: diff --git a/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py b/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py index ffff8a3c18..4c773db327 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py @@ -46,10 +46,9 @@ @pytest_asyncio.fixture async def airlock_request_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - airlock_request_repo_mock = await AirlockRequestRepository.create(cosmos_client_mock) - yield airlock_request_repo_mock + with patch('api.dependencies.database.Database.get_container_proxy', return_value=AsyncMock()): + airlock_request_repo_mock = await AirlockRequestRepository.create() + yield airlock_request_repo_mock @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_base_repository.py b/api_app/tests_ma/test_db/test_repositories/test_base_repository.py index 3298b5049e..3d89923e24 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_base_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_base_repository.py @@ -1,6 +1,5 @@ import pytest - -from mock import patch, MagicMock +from mock import patch from db.errors import UnableToAccessDatabase from db.repositories.base import BaseRepository @@ -8,8 +7,8 @@ pytestmark = pytest.mark.asyncio -async def test_instantiating_a_repo_raises_unable_to_access_database_if_database_cant_be_accessed(): - with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: - cosmos_client_mock.get_database_client = MagicMock(side_effect=Exception) - with pytest.raises(UnableToAccessDatabase): - await BaseRepository.create(cosmos_client_mock, "container") +@patch("api.dependencies.database.Database.get_container_proxy") +async def test_instantiating_a_repo_raises_unable_to_access_database_if_database_cant_be_accessed(get_container_proxy_mock): + get_container_proxy_mock.side_effect = Exception() + with pytest.raises(UnableToAccessDatabase): + await BaseRepository.create() diff --git a/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py b/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py index 09bc96cd70..91e51fba0a 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py @@ -18,26 +18,23 @@ @pytest_asyncio.fixture async def operations_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - operations_repo = await OperationRepository.create(cosmos_client_mock) - yield operations_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + operations_repo = await OperationRepository.create() + yield operations_repo @pytest_asyncio.fixture async def resource_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_repo = await ResourceRepository.create(cosmos_client_mock) - yield resource_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + resource_repo = await ResourceRepository.create() + yield resource_repo @pytest_asyncio.fixture async def resource_template_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_template_repo = await ResourceTemplateRepository.create(cosmos_client_mock) - yield resource_template_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + resource_template_repo = await ResourceTemplateRepository.create() + yield resource_template_repo @patch('uuid.uuid4', side_effect=["random-uuid-1", "random-uuid-2", "random-uuid-3"]) diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py index aea50b9690..f4f37093c7 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py @@ -16,10 +16,9 @@ @pytest_asyncio.fixture async def resource_history_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_history_repo = await ResourceHistoryRepository.create(cosmos_client_mock) - yield resource_history_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + resource_history_repo = await ResourceHistoryRepository().create() + yield resource_history_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py index d0dec25234..a308b7888f 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py @@ -26,18 +26,16 @@ @pytest_asyncio.fixture async def resource_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_repo = await ResourceRepository.create(cosmos_client_mock) - yield resource_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + resource_repo = await ResourceRepository().create() + yield resource_repo @pytest_asyncio.fixture async def resource_history_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_history_repo = await ResourceHistoryRepository.create(cosmos_client_mock) - yield resource_history_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + resource_history_repo = await ResourceHistoryRepository().create() + yield resource_history_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py index f84429b7c8..d007326323 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py @@ -14,10 +14,9 @@ @pytest_asyncio.fixture async def resource_template_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_template_repo = await ResourceTemplateRepository.create(cosmos_client_mock) - yield resource_template_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + resource_template_repo = await ResourceTemplateRepository().create() + yield resource_template_repo def sample_resource_template_as_dict(name: str, version: str = "1.0", resource_type: ResourceType = ResourceType.Workspace) -> ResourceTemplate: diff --git a/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py b/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py index 71da756dfd..131b7c9d78 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py @@ -17,18 +17,16 @@ @pytest_asyncio.fixture async def shared_service_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - shared_service_repo = await SharedServiceRepository.create(cosmos_client_mock) - yield shared_service_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=AsyncMock()): + shared_service_repo = await SharedServiceRepository().create() + yield shared_service_repo @pytest_asyncio.fixture async def operations_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - operations_repo = await OperationRepository.create(cosmos_client_mock) - yield operations_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + operations_repo = await OperationRepository().create() + yield operations_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py index 936ebe50d7..226777f6f4 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py @@ -11,10 +11,9 @@ @pytest_asyncio.fixture async def resource_template_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_template_repo = await ResourceTemplateRepository.create(cosmos_client_mock) - yield resource_template_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + resource_template_repo = await ResourceTemplateRepository().create() + yield resource_template_repo # Because shared service templates repository uses generic ResourceTemplate repository, most test cases are already covered diff --git a/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py b/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py index 7a2adbebde..90067c3356 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py @@ -26,10 +26,9 @@ def basic_user_resource_request(): @pytest_asyncio.fixture async def user_resource_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - user_resource_repo = await UserResourceRepository.create(cosmos_client_mock) - yield user_resource_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + user_resource_repo = await UserResourceRepository().create() + yield user_resource_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py index 0cb5cb56fd..727c9a2d18 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py @@ -13,10 +13,9 @@ @pytest_asyncio.fixture async def resource_template_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_template_repo = await ResourceTemplateRepository.create(cosmos_client_mock) - yield resource_template_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): + resource_template_repo = await ResourceTemplateRepository().create() + yield resource_template_repo def sample_user_resource_template_as_dict(name: str, version: str = "1.0") -> dict: diff --git a/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py b/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py index 79b9a5176a..01727d4e24 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py @@ -20,18 +20,16 @@ def basic_workspace_request(): @pytest_asyncio.fixture async def workspace_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - workspace_repo = await WorkspaceRepository.create(cosmos_client_mock) - yield workspace_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=MagicMock()): + workspace_repo = await WorkspaceRepository().create() + yield workspace_repo @pytest_asyncio.fixture async def operations_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - operations_repo = await OperationRepository.create(cosmos_client_mock) - yield operations_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=MagicMock()): + operations_repo = await OperationRepository().create() + yield operations_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py b/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py index 10ce04e187..390c0e18d2 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py @@ -18,18 +18,16 @@ @pytest_asyncio.fixture async def workspace_service_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - workspace_repo = await WorkspaceServiceRepository.create(cosmos_client_mock) - yield workspace_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=MagicMock()): + workspace_repo = await WorkspaceServiceRepository().create() + yield workspace_repo @pytest_asyncio.fixture async def operations_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - operations_repo = await OperationRepository.create(cosmos_client_mock) - yield operations_repo + with patch('api.dependencies.database.Database.get_container_proxy', return_value=MagicMock()): + operations_repo = await OperationRepository().create() + yield operations_repo @pytest.fixture diff --git a/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py b/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py index 66829a4a47..d9165db2f5 100644 --- a/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py @@ -108,9 +108,8 @@ def __str__(self): @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('logging.exception') -@patch('fastapi.FastAPI') @patch("services.aad_authentication.AzureADAuthorization.get_workspace_role_assignment_details", return_value={"researcher_emails": ["researcher@outlook.com"], "owner_emails": ["owner@outlook.com"]}) -async def test_receiving_good_message(_, app, logging_mock, workspace_repo, airlock_request_repo, eg_client): +async def test_receiving_good_message(_, logging_mock, workspace_repo, airlock_request_repo, eg_client): eg_client().send = AsyncMock() expected_airlock_request = sample_airlock_request() @@ -118,7 +117,7 @@ async def test_receiving_good_message(_, app, logging_mock, workspace_repo, airl airlock_request_repo.return_value.update_airlock_request.return_value = sample_airlock_request(status=AirlockRequestStatus.InReview) workspace_repo.return_value.get_workspace_by_id.return_value = sample_workspace() - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(ServiceBusReceivedMessageMock(test_sb_step_result_message)) @@ -140,10 +139,9 @@ async def test_receiving_good_message(_, app, logging_mock, workspace_repo, airl @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_receiving_bad_json_logs_error(app, logging_mock, workspace_repo, airlock_request_repo, payload): +async def test_receiving_bad_json_logs_error(logging_mock, workspace_repo, airlock_request_repo, payload): service_bus_received_message_mock = ServiceBusReceivedMessageMock(payload) - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) @@ -156,12 +154,11 @@ async def test_receiving_bad_json_logs_error(app, logging_mock, workspace_repo, @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.exception') @patch('service_bus.airlock_request_status_update.ServiceBusClient') -@patch('fastapi.FastAPI') -async def test_updating_non_existent_airlock_request_error_is_logged(app, sb_client, logging_mock, airlock_request_repo, _): +async def test_updating_non_existent_airlock_request_error_is_logged(sb_client, logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message) airlock_request_repo.return_value.get_airlock_request_by_id.side_effect = EntityDoesNotExist - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) @@ -173,12 +170,11 @@ async def test_updating_non_existent_airlock_request_error_is_logged(app, sb_cli @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_when_updating_and_state_store_exception_error_is_logged(app, logging_mock, airlock_request_repo, _): +async def test_when_updating_and_state_store_exception_error_is_logged(logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message) airlock_request_repo.return_value.get_airlock_request_by_id.side_effect = Exception - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) @@ -189,13 +185,12 @@ async def test_when_updating_and_state_store_exception_error_is_logged(app, logg @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.error') -@patch('fastapi.FastAPI') -async def test_when_updating_and_current_status_differs_from_status_in_state_store_error_is_logged(app, logging_mock, airlock_request_repo, _): +async def test_when_updating_and_current_status_differs_from_status_in_state_store_error_is_logged(logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message) expected_airlock_request = sample_airlock_request(AirlockRequestStatus.Draft) airlock_request_repo.return_value.get_airlock_request_by_id.return_value = expected_airlock_request - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) @@ -208,12 +203,11 @@ async def test_when_updating_and_current_status_differs_from_status_in_state_sto @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.exception') @patch('service_bus.airlock_request_status_update.ServiceBusClient') -@patch('fastapi.FastAPI') -async def test_when_updating_and_status_update_is_illegal_error_is_logged(app, sb_client, logging_mock, airlock_request_repo, _): +async def test_when_updating_and_status_update_is_illegal_error_is_logged(sb_client, logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message_with_invalid_status) airlock_request_repo.return_value.get_airlock_request_by_id.side_effect = HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index 71cfe2d843..db80c5b1f7 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -118,11 +118,10 @@ def create_sample_operation(resource_id, request_action): @pytest.mark.parametrize("payload", test_data) @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_receiving_bad_json_logs_error(app, logging_mock, payload): +async def test_receiving_bad_json_logs_error(logging_mock, payload): service_bus_received_message_mock = ServiceBusReceivedMessageMock(payload) - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() complete_message = await status_updater.process_message(service_bus_received_message_mock) # bad message data will fail. we don't mark complete=true since we want the message in the DLQ @@ -138,15 +137,14 @@ async def test_receiving_bad_json_logs_error(app, logging_mock, payload): @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_receiving_good_message(app, logging_mock, resource_repo, operation_repo, _, __): +async def test_receiving_good_message(logging_mock, resource_repo, operation_repo, _, __): expected_workspace = create_sample_workspace_object(test_sb_message["id"]) resource_repo.return_value.get_resource_dict_by_id.return_value = expected_workspace.dict() operation = create_sample_operation(test_sb_message["id"], RequestAction.Install) operation_repo.return_value.get_operation_by_id.return_value = operation - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(ServiceBusReceivedMessageMock(test_sb_message)) @@ -161,14 +159,13 @@ async def test_receiving_good_message(app, logging_mock, resource_repo, operatio @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_when_updating_non_existent_workspace_error_is_logged(app, logging_mock, resource_repo, operation_repo, _, __): +async def test_when_updating_non_existent_workspace_error_is_logged(logging_mock, resource_repo, operation_repo, _, __): resource_repo.return_value.get_resource_dict_by_id.side_effect = EntityDoesNotExist operation = create_sample_operation(test_sb_message["id"], RequestAction.Install) operation_repo.return_value.get_operation_by_id.return_value = operation - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(ServiceBusReceivedMessageMock(test_sb_message)) @@ -182,14 +179,13 @@ async def test_when_updating_non_existent_workspace_error_is_logged(app, logging @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_when_updating_and_state_store_exception(app, logging_mock, resource_repo, operation_repo, _, __): +async def test_when_updating_and_state_store_exception(logging_mock, resource_repo, operation_repo, _, __): resource_repo.return_value.get_resource_dict_by_id.side_effect = Exception operation = create_sample_operation(test_sb_message["id"], RequestAction.Install) operation_repo.return_value.get_operation_by_id.return_value = operation - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(ServiceBusReceivedMessageMock(test_sb_message)) @@ -202,8 +198,7 @@ async def test_when_updating_and_state_store_exception(app, logging_mock, resour @patch("service_bus.deployment_status_updater.get_timestamp", return_value=FAKE_UPDATE_TIMESTAMP) @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') -@patch('fastapi.FastAPI') -async def test_state_transitions_from_deployed_to_deleted(app, resource_repo, operations_repo_mock, _, __, ___): +async def test_state_transitions_from_deployed_to_deleted(resource_repo, operations_repo_mock, _, __, ___): updated_message = test_sb_message updated_message["status"] = Status.Deleted updated_message["message"] = "Has been deleted" @@ -222,7 +217,7 @@ async def test_state_transitions_from_deployed_to_deleted(app, resource_repo, op expected_operation.status = Status.Deleted expected_operation.message = updated_message["message"] - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) @@ -234,8 +229,7 @@ async def test_state_transitions_from_deployed_to_deleted(app, resource_repo, op @patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') -@patch('fastapi.FastAPI') -async def test_outputs_are_added_to_resource_item(app, resource_repo, operations_repo, _, __): +async def test_outputs_are_added_to_resource_item(resource_repo, operations_repo, _, __): received_message = test_sb_message_with_outputs received_message["status"] = Status.Deployed service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -260,7 +254,7 @@ async def test_outputs_are_added_to_resource_item(app, resource_repo, operations operation = create_sample_operation(resource.id, RequestAction.UnInstall) operations_repo.return_value.get_operation_by_id.return_value = operation - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) @@ -272,8 +266,7 @@ async def test_outputs_are_added_to_resource_item(app, resource_repo, operations @patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') -@patch('fastapi.FastAPI') -async def test_properties_dont_change_with_no_outputs(app, resource_repo, operations_repo, _, __): +async def test_properties_dont_change_with_no_outputs(resource_repo, operations_repo, _, __): received_message = test_sb_message received_message["status"] = Status.Deployed service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -287,7 +280,7 @@ async def test_properties_dont_change_with_no_outputs(app, resource_repo, operat expected_resource = resource - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) @@ -301,8 +294,7 @@ async def test_properties_dont_change_with_no_outputs(app, resource_repo, operat @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('service_bus.helpers.ServiceBusClient') -@patch('fastapi.FastAPI') -async def test_multi_step_operation_sends_next_step(app, sb_sender_client, resource_repo, operations_repo, update_resource_for_step, _, __, multi_step_operation, user_resource_multi, basic_shared_service): +async def test_multi_step_operation_sends_next_step(sb_sender_client, resource_repo, operations_repo, update_resource_for_step, _, __, multi_step_operation, user_resource_multi, basic_shared_service): received_message = test_sb_message_multi_step_1_complete received_message["status"] = Status.Updated service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -320,7 +312,7 @@ async def test_multi_step_operation_sends_next_step(app, sb_sender_client, resou operations_repo.return_value.get_operation_by_id.return_value = multi_step_operation update_resource_for_step.return_value = user_resource_multi - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) @@ -356,8 +348,7 @@ async def test_multi_step_operation_sends_next_step(app, sb_sender_client, resou @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('service_bus.helpers.ServiceBusClient') -@patch('fastapi.FastAPI') -async def test_multi_step_operation_ends_at_last_step(app, sb_sender_client, resource_repo, operations_repo, _, __, multi_step_operation, user_resource_multi, basic_shared_service): +async def test_multi_step_operation_ends_at_last_step(sb_sender_client, resource_repo, operations_repo, _, __, multi_step_operation, user_resource_multi, basic_shared_service): received_message = test_sb_message_multi_step_3_complete received_message["status"] = Status.Updated service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -384,7 +375,7 @@ async def test_multi_step_operation_ends_at_last_step(app, sb_sender_client, res operations_repo.return_value.get_operation_by_id.return_value = in_flight_op - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) assert complete_message is True @@ -401,13 +392,12 @@ async def test_multi_step_operation_ends_at_last_step(app, sb_sender_client, res sb_sender_client().get_queue_sender().send_messages.assert_not_called() -@patch('fastapi.FastAPI') -async def test_convert_outputs_to_dict(app): +async def test_convert_outputs_to_dict(): # Test case 1: Empty list of outputs outputs_list = [] expected_result = {} - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() assert status_updater.convert_outputs_to_dict(outputs_list) == expected_result # Test case 2: List of outputs with mixed types diff --git a/api_app/tests_ma/test_services/test_airlock.py b/api_app/tests_ma/test_services/test_airlock.py index f70f3f4003..c7fccec4d9 100644 --- a/api_app/tests_ma/test_services/test_airlock.py +++ b/api_app/tests_ma/test_services/test_airlock.py @@ -31,9 +31,8 @@ @pytest_asyncio.fixture async def airlock_request_repo_mock(no_database): _ = no_database - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - airlock_request_repo_mock = await AirlockRequestRepository.create(cosmos_client_mock) - yield airlock_request_repo_mock + airlock_request_repo_mock = await AirlockRequestRepository().create() + yield airlock_request_repo_mock def sample_workspace(): diff --git a/api_app/tests_ma/test_services/test_health_checker.py b/api_app/tests_ma/test_services/test_health_checker.py index 84ff3ae989..6b1d955f89 100644 --- a/api_app/tests_ma/test_services/test_health_checker.py +++ b/api_app/tests_ma/test_services/test_health_checker.py @@ -11,84 +11,73 @@ pytestmark = pytest.mark.asyncio -@patch("core.credentials.get_credential_async") -@patch("services.health_checker.get_store_key") -@patch("services.health_checker.CosmosClient") -async def test_get_state_store_status_responding(_, get_store_key_mock, get_credential_async) -> None: - get_store_key_mock.return_value = None - status, message = await health_checker.create_state_store_status(get_credential_async) +@patch("azure.cosmos.aio.ContainerProxy.query_items", return_value=AsyncMock()) +async def test_get_state_store_status_responding(_) -> None: + status, message = await health_checker.create_state_store_status() assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async") -@patch("services.health_checker.get_store_key") -@patch("services.health_checker.CosmosClient") -async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() - get_store_key_mock.return_value = None - cosmos_client_mock.return_value = None - cosmos_client_mock.side_effect = ServiceRequestError(message="some message") - status, message = await health_checker.create_state_store_status(get_credential_async) +@patch("api.dependencies.database.Database.get_container_proxy") +async def test_get_state_store_status_not_responding(container_proxy_mock) -> None: + container_proxy_mock.return_value = None + container_proxy_mock.side_effect = ServiceRequestError(message="some message") + status, message = await health_checker.create_state_store_status() assert status == StatusEnum.not_ok assert message == strings.STATE_STORE_ENDPOINT_NOT_RESPONDING -@patch("core.credentials.get_credential_async") -@patch("services.health_checker.get_store_key") -@patch("services.health_checker.CosmosClient") -async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() - get_store_key_mock.return_value = None - cosmos_client_mock.return_value = None - cosmos_client_mock.side_effect = Exception() - status, message = await health_checker.create_state_store_status(get_credential_async) +@patch("api.dependencies.database.Database.get_container_proxy") +async def test_get_state_store_status_other_exception(container_proxy_mock) -> None: + container_proxy_mock.return_value = None + container_proxy_mock.side_effect = Exception() + status, message = await health_checker.create_state_store_status() assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() service_bus_client_mock().get_queue_receiver.__aenter__.return_value = AsyncMock() - status, message = await health_checker.create_service_bus_status(get_credential_async) + status, message = await health_checker.create_service_bus_status(get_credential_async_context) assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() service_bus_client_mock.return_value = None service_bus_client_mock.side_effect = ServiceBusConnectionError(message="some message") - status, message = await health_checker.create_service_bus_status(get_credential_async) + status, message = await health_checker.create_service_bus_status(get_credential_async_context) assert status == StatusEnum.not_ok assert message == strings.SERVICE_BUS_NOT_RESPONDING -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() service_bus_client_mock.return_value = None service_bus_client_mock.side_effect = Exception() - status, message = await health_checker.create_service_bus_status(get_credential_async) + status, message = await health_checker.create_service_bus_status(get_credential_async_context) assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ComputeManagementClient") -async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock() vm_mock = MagicMock() vm_mock.instance_id = 'mocked_id' @@ -100,16 +89,16 @@ async def test_get_resource_processor_status_healthy(resource_processor_client_m awaited_mock.set_result(instance_view_mock) resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock - status, message = await health_checker.create_resource_processor_status(get_credential_async) + status, message = await health_checker.create_resource_processor_status(get_credential_async_context) assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ComputeManagementClient", return_value=MagicMock()) -async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock() vm_mock = MagicMock() @@ -122,19 +111,19 @@ async def test_get_resource_processor_status_not_healthy(resource_processor_clie awaited_mock.set_result(instance_view_mock) resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock - status, message = await health_checker.create_resource_processor_status(get_credential_async) + status, message = await health_checker.create_resource_processor_status(get_credential_async_context) assert status == StatusEnum.not_ok assert message == strings.RESOURCE_PROCESSOR_GENERAL_ERROR_MESSAGE -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ComputeManagementClient") -async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() resource_processor_client_mock.return_value = None resource_processor_client_mock.side_effect = Exception() - status, message = await health_checker.create_resource_processor_status(get_credential_async) + status, message = await health_checker.create_resource_processor_status(get_credential_async_context) assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR diff --git a/core/terraform/api-identity.tf b/core/terraform/api-identity.tf index c68213f816..8209e37143 100644 --- a/core/terraform/api-identity.tf +++ b/core/terraform/api-identity.tf @@ -45,3 +45,16 @@ resource "azurerm_role_assignment" "cosmos_contributor" { principal_id = azurerm_user_assigned_identity.id.principal_id } +data "azurerm_cosmosdb_sql_role_definition" "cosmosdb_db_contributor" { + resource_group_name = azurerm_resource_group.core.name + account_name = azurerm_cosmosdb_account.tre_db_account.name + role_definition_id = "00000000-0000-0000-0000-000000000002" # Cosmos DB Built-in Data Contributor +} + +resource "azurerm_cosmosdb_sql_role_assignment" "tre_db_contributor" { + resource_group_name = azurerm_resource_group.core.name + account_name = azurerm_cosmosdb_account.tre_db_account.name + role_definition_id = data.azurerm_cosmosdb_sql_role_definition.cosmosdb_db_contributor.id + principal_id = azurerm_user_assigned_identity.id.principal_id + scope = azurerm_cosmosdb_account.tre_db_account.id +} diff --git a/core/version.txt b/core/version.txt index a2fecb4576..e94731c0f2 100644 --- a/core/version.txt +++ b/core/version.txt @@ -1 +1 @@ -__version__ = "0.9.2" +__version__ = "0.9.4"