diff --git a/astronomer/providers/http/hooks/http.py b/astronomer/providers/http/hooks/http.py index d03094f4b..dd6fb7d96 100644 --- a/astronomer/providers/http/hooks/http.py +++ b/astronomer/providers/http/hooks/http.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import asyncio -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Callable import aiohttp from aiohttp import ClientResponseError @@ -20,6 +22,9 @@ class HttpHookAsync(BaseHook): API url i.e https://www.google.com/ and optional authentication credentials. Default headers can also be specified in the Extra field in json format. :param auth_type: The auth type for the service + :param keep_response: Keep the aiohttp response returned by run method without releasing it. + Use it with caution. Without properly releasing response, it might cause "Unclosed connection" error. + See https://github.com/astronomer/astronomer-providers/issues/909 :type auth_type: AuthBase of python aiohttp lib """ @@ -35,6 +40,8 @@ def __init__( auth_type: Any = aiohttp.BasicAuth, retry_limit: int = 3, retry_delay: float = 1.0, + *, + keep_response: bool = False, ) -> None: self.http_conn_id = http_conn_id self.method = method.upper() @@ -45,14 +52,15 @@ def __init__( raise ValueError("Retry limit must be greater than equal to 1") self.retry_limit = retry_limit self.retry_delay = retry_delay + self.keep_response = keep_response async def run( self, - endpoint: Optional[str] = None, - data: Optional[Union[Dict[str, Any], str]] = None, - headers: Optional[Dict[str, Any]] = None, - extra_options: Optional[Dict[str, Any]] = None, - ) -> "ClientResponse": + endpoint: str | None = None, + data: dict[str, Any] | str | None = None, + headers: dict[str, Any] | None = None, + extra_options: dict[str, Any] | None = None, + ) -> ClientResponse: r""" Performs an asynchronous HTTP request call @@ -78,10 +86,10 @@ async def run( # schema defaults to HTTP schema = conn.schema if conn.schema else "http" host = conn.host if conn.host else "" - self.base_url = schema + "://" + host + self.base_url = f"{schema}://{host}" if conn.port: - self.base_url = self.base_url + ":" + str(conn.port) + self.base_url = f"{self.base_url}:{conn.port}" if conn.login: auth = self.auth_type(conn.login, conn.password) if conn.extra: @@ -93,7 +101,7 @@ async def run( _headers.update(headers) if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"): - url = self.base_url + "/" + endpoint + url = f"{self.base_url}/{endpoint}" else: url = (self.base_url or "") + (endpoint or "") @@ -109,29 +117,34 @@ async def run( attempt_num = 1 while True: - async with request_func( + response = await request_func( url, json=data if self.method in ("POST", "PATCH") else None, params=data if self.method == "GET" else None, headers=headers, auth=auth, **extra_options, - ) as response: - try: - response.raise_for_status() - return response - except ClientResponseError as e: - self.log.warning( - "[Try %d of %d] Request to %s failed.", - attempt_num, - self.retry_limit, - url, - ) - if not self._retryable_error_async(e) or attempt_num == self.retry_limit: - self.log.exception("HTTP error with status: %s", e.status) - # In this case, the user probably made a mistake. - # Don't retry. - raise AirflowException(str(e.status) + ":" + e.message) + ) + try: + response.raise_for_status() + if not self.keep_response: + response.release() + return response + except ClientResponseError as e: + self.log.warning( + "[Try %d of %d] Request to %s failed.", + attempt_num, + self.retry_limit, + url, + ) + if not self._retryable_error_async(e) or attempt_num == self.retry_limit: + self.log.exception("HTTP error with status: %s", e.status) + response.release() + # In this case, the user probably made a mistake. + # Don't retry. + raise AirflowException(f"{e.status}:{e.message}") + + response.release() attempt_num += 1 await asyncio.sleep(self.retry_delay) diff --git a/astronomer/providers/snowflake/hooks/snowflake_sql_api.py b/astronomer/providers/snowflake/hooks/snowflake_sql_api.py index 10b7e44e9..9ab17a112 100644 --- a/astronomer/providers/snowflake/hooks/snowflake_sql_api.py +++ b/astronomer/providers/snowflake/hooks/snowflake_sql_api.py @@ -141,7 +141,7 @@ def execute_query( response.raise_for_status() except requests.exceptions.HTTPError as e: # pragma: no cover raise AirflowException( - f"Response: {e.response.content}, " f"Status Code: {e.response.status_code}" + f"Response: {e.response.content!r}, " f"Status Code: {e.response.status_code}" ) # pragma: no cover json_response = response.json() self.log.info("Snowflake SQL POST API response: %s", json_response) @@ -204,7 +204,7 @@ def check_query_output(self, query_ids: list[str]) -> None: self.log.info(response.json()) except requests.exceptions.HTTPError as e: raise AirflowException( - f"Response: {e.response.content}, Status Code: {e.response.status_code}" + f"Response: {e.response.content!r}, Status Code: {e.response.status_code}" ) @staticmethod diff --git a/tests/http/hooks/test_http.py b/tests/http/hooks/test_http.py index 1f7f86b54..dae4d4364 100644 --- a/tests/http/hooks/test_http.py +++ b/tests/http/hooks/test_http.py @@ -2,6 +2,7 @@ from unittest import mock import pytest +from aiohttp.client_exceptions import ClientConnectionError from airflow.exceptions import AirflowException from airflow.models import Connection @@ -54,7 +55,8 @@ def get_airflow_connection(unused_conn_id=None): return Connection( conn_id="http_default", conn_type="http", - host="test:8080/", + host="test", + port=8080, extra='{"bearer": "test"}', ) @@ -75,6 +77,45 @@ async def test_post_request(self, aioresponse): resp = await hook.run("v1/test") assert resp.status == 200 + @pytest.mark.asyncio + async def test_post_request_and_get_json_without_keep_response(self, aioresponse): + hook = HttpHookAsync() + payload = '{"status":{"status": 200}}' + + aioresponse.post( + "http://test:8080/v1/test", + status=200, + payload=payload, + reason="OK", + ) + + with mock.patch( + "airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_airflow_connection + ): + resp = await hook.run("v1/test") + with pytest.raises(ClientConnectionError, match="Connection closed"): + await resp.json() + + @pytest.mark.asyncio + async def test_post_request_and_get_json_with_keep_response(self, aioresponse): + hook = HttpHookAsync(keep_response=True) + payload = '{"status":{"status": 200}}' + + aioresponse.post( + "http://test:8080/v1/test", + status=200, + payload=payload, + reason="OK", + ) + + with mock.patch( + "airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_airflow_connection + ): + resp = await hook.run("v1/test") + resp_payload = await resp.json() + assert resp.status == 200 + assert resp_payload == payload + @pytest.mark.asyncio async def test_post_request_with_error_code(self, aioresponse): hook = HttpHookAsync()