diff --git a/stripe/_http_client.py b/stripe/_http_client.py index 4fb246523..1455afa5e 100644 --- a/stripe/_http_client.py +++ b/stripe/_http_client.py @@ -1393,6 +1393,13 @@ def close(self): async def close_async(self): await self._client_async.aclose() + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close_async() + return False + class AIOHTTPClient(HTTPClient): name = "aiohttp" @@ -1527,6 +1534,13 @@ async def close_async(self): if self._internally_managed_session: await self._session.close() + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close_async() + return False + class NoImportFoundAsyncClient(HTTPClient): def __init__(self, **kwargs): diff --git a/tests/test_http_client.py b/tests/test_http_client.py index be0f8400f..01654c096 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -2044,3 +2044,104 @@ async def test_httpx_request_async_https(self): method, abs_url, headers, data ) assert code >= 200 and code < 400 + + +class TestAsyncContextManager: + """Tests for async context manager support in async HTTP clients.""" + + @pytest.mark.anyio + async def test_httpx_async_context_manager(self, mocked_request_lib): + """Test that HTTPXClient supports async context manager.""" + client = _http_client.HTTPXClient() + + # Mock the async client's aclose method + original_aclose = client._client_async.aclose + aclose_called = False + + async def mock_aclose(): + nonlocal aclose_called + aclose_called = True + await original_aclose() + + client._client_async.aclose = mock_aclose + + # Test that the context manager works + async with client as ctx: + assert ctx is client + + # Verify that close_async was called + assert aclose_called + + @pytest.mark.anyio + async def test_aiohttp_async_context_manager(self, mocked_request_lib): + """Test that AIOHTTPClient supports async context manager.""" + client = _http_client.AIOHTTPClient() + + # Mock the session's close method + close_called = False + + async def mock_close(): + nonlocal close_called + close_called = True + + # Replace the session's close method + client._cached_session = mocked_request_lib.ClientSession() + client._cached_session.close = mock_close + + # Test that the context manager works + async with client as ctx: + assert ctx is client + + # Verify that close_async was called + assert close_called + + @pytest.mark.anyio + async def test_httpx_async_context_manager_with_exception(self, mocked_request_lib): + """Test that HTTPXClient context manager handles exceptions correctly.""" + client = _http_client.HTTPXClient() + + # Mock the async client's aclose method + original_aclose = client._client_async.aclose + aclose_called = False + + async def mock_aclose(): + nonlocal aclose_called + aclose_called = True + await original_aclose() + + client._client_async.aclose = mock_aclose + + # Test that the context manager works even with an exception + with pytest.raises(ValueError): + async with client as ctx: + assert ctx is client + raise ValueError("Test exception") + + # Verify that close_async was called even with exception + assert aclose_called + + @pytest.mark.anyio + async def test_aiohttp_async_context_manager_with_exception(self, mocked_request_lib): + """Test that AIOHTTPClient context manager handles exceptions correctly.""" + client = _http_client.AIOHTTPClient() + + # Mock the session's close method + close_called = False + + async def mock_close(): + nonlocal close_called + close_called = True + + # Replace the session's close method + client._cached_session = mocked_request_lib.ClientSession() + client._cached_session.close = mock_close + + # Test that the context manager works even with an exception + with pytest.raises(ValueError): + async with client as ctx: + assert ctx is client + raise ValueError("Test exception") + + # Verify that close_async was called even with exception + assert close_called +