diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 706cbba0ac2c..50dc8f42f6ee 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -12,7 +12,6 @@ ### Bugs Fixed -- Fixed the `AZURE_REGIONAL_AUTHORITY_NAME` environment variable not being respected in certain credentials. ([#44347](https://github.com/Azure/azure-sdk-for-python/pull/44347)) - Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44552](https://github.com/Azure/azure-sdk-for-python/pull/44552)) - Fixed an issue where an unhelpful TypeError was raised during Entra ID token requests that returned empty responses. Now, a ClientAuthenticationError is raised with the full response for better troubleshooting. ([#44258](https://github.com/Azure/azure-sdk-for-python/pull/44258)) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index c6690033697d..216f601f3869 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -3,8 +3,6 @@ # Licensed under the MIT License. # ------------------------------------ import time -import logging -import os from typing import Iterable, Union, Optional, Any from azure.core.credentials import AccessTokenInfo @@ -13,17 +11,9 @@ from .aad_client_base import AadClientBase from .aadclient_certificate import AadClientCertificate from .pipeline import build_pipeline -from .._enums import RegionalAuthority -_LOGGER = logging.getLogger(__name__) - - -class AadClient(AadClientBase): # pylint:disable=client-accepts-api-version-keyword - - # pylint:disable=missing-client-constructor-parameter-credential - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) +class AadClient(AadClientBase): def __enter__(self) -> "AadClient": self._pipeline.__enter__() @@ -38,7 +28,6 @@ def close(self) -> None: def obtain_token_by_authorization_code( self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs: Any ) -> AccessTokenInfo: - self._initialize_regional_authority() request = self._get_auth_code_request( scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) @@ -47,24 +36,20 @@ def obtain_token_by_authorization_code( def obtain_token_by_client_certificate( self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs: Any ) -> AccessTokenInfo: - self._initialize_regional_authority() request = self._get_client_certificate_request(scopes, certificate, **kwargs) return self._run_pipeline(request, **kwargs) def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs: Any) -> AccessTokenInfo: - self._initialize_regional_authority() request = self._get_client_secret_request(scopes, secret, **kwargs) return self._run_pipeline(request, **kwargs) def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs: Any) -> AccessTokenInfo: - self._initialize_regional_authority() request = self._get_jwt_assertion_request(scopes, assertion, **kwargs) return self._run_pipeline(request, **kwargs) def obtain_token_by_refresh_token( self, scopes: Iterable[str], refresh_token: str, **kwargs: Any ) -> AccessTokenInfo: - self._initialize_regional_authority() request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) return self._run_pipeline(request, **kwargs) @@ -78,37 +63,6 @@ def obtain_token_on_behalf_of( # no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL raise NotImplementedError() - def _initialize_regional_authority(self) -> None: - # This is based on MSAL's regional authority logic. - if self._regional_authority is not False: - return - - regional_authority = self._get_regional_authority_from_env() - if not regional_authority: - self._regional_authority = None - return - - if regional_authority in [RegionalAuthority.AUTO_DISCOVER_REGION, "true"]: - regional_authority = self._discover_region() - if not regional_authority: - _LOGGER.info("Failed to auto-discover region. Using the non-regional authority.") - self._regional_authority = None - return - - self._regional_authority = self._build_regional_authority_url(regional_authority) - - def _discover_region(self) -> Optional[str]: - region = os.environ.get("REGION_NAME", "").replace(" ", "").lower() - if region: - return region - try: - request = self._get_region_discovery_request() - response = self._pipeline.run(request) - return self._process_region_discovery_response(response) - except Exception as ex: # pylint: disable=broad-except - _LOGGER.debug("Failed to discover Azure region from IMDS: %s", ex) - return None - def _build_pipeline(self, **kwargs: Any) -> Pipeline: return build_pipeline(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 89947835a61d..b66cb15ab1b1 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -5,11 +5,8 @@ import abc import base64 import json -import logging -import os import time from uuid import uuid4 -from urllib.parse import urlparse from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict, cast from msal import TokenCache @@ -22,7 +19,6 @@ from .utils import get_default_authority, normalize_authority, resolve_tenant from .aadclient_certificate import AadClientCertificate from .._persistent_cache import _load_persistent_cache -from .._constants import EnvironmentVariables if TYPE_CHECKING: @@ -34,12 +30,10 @@ PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy] TransportType = Union[AsyncHttpTransport, HttpTransport] -_LOGGER = logging.getLogger(__name__) - JWT_BEARER_ASSERTION = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" -class AadClientBase(abc.ABC): # pylint: disable=too-many-instance-attributes +class AadClientBase(abc.ABC): _POST = ["POST"] def __init__( @@ -51,13 +45,10 @@ def __init__( cae_cache: Optional[TokenCache] = None, *, additionally_allowed_tenants: Optional[List[str]] = None, - **kwargs: Any, + **kwargs: Any ) -> None: self._authority = normalize_authority(authority) if authority else get_default_authority() - # False indicates uninitialized. Actual value is str or None. - self._regional_authority: Optional[Union[str, bool]] = False - self._tenant_id = tenant_id self._client_id = client_id self._additionally_allowed_tenants = additionally_allowed_tenants or [] @@ -310,7 +301,7 @@ def _get_on_behalf_of_request( scopes: Iterable[str], client_credential: Union[str, AadClientCertificate, Dict[str, Any]], user_assertion: str, - **kwargs: Any, + **kwargs: Any ) -> HttpRequest: data = { "assertion": user_assertion, @@ -365,7 +356,7 @@ def _get_refresh_token_on_behalf_of_request( scopes: Iterable[str], client_credential: Union[str, AadClientCertificate, Dict[str, Any]], refresh_token: str, - **kwargs: Any, + **kwargs: Any ) -> HttpRequest: data = { "grant_type": "refresh_token", @@ -392,43 +383,11 @@ def _get_refresh_token_on_behalf_of_request( request = self._post(data, **kwargs) return request - def _get_region_discovery_request(self) -> HttpRequest: - url = "http://169.254.169.254/metadata/instance/compute/location?format=text&api-version=2021-01-01" - request = HttpRequest("GET", url, headers={"Metadata": "true"}) - return request - - def _process_region_discovery_response(self, response: PipelineResponse) -> Optional[str]: - if response.http_response.status_code == 200: - region = response.http_response.text().strip() - if region: - return region - _LOGGER.warning("IMDS returned empty region") - return None - - def _build_regional_authority_url(self, regional_authority: str) -> Optional[str]: - central_host = urlparse(self._authority).hostname - if not central_host: - return None - - # This mirrors the regional authority logic in MSAL. - if central_host in ("login.microsoftonline.com", "login.microsoft.com", "login.windows.net", "sts.windows.net"): - regional_host = f"{regional_authority}.login.microsoft.com" - else: - regional_host = f"{regional_authority}.{central_host}" - return f"https://{regional_host}" - - def _get_regional_authority_from_env(self) -> Optional[str]: - regional_authority = os.environ.get(EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME) or os.environ.get( - "MSAL_FORCE_REGION" - ) # For parity with creds that rely on MSAL, we check this var too. - return regional_authority.lower() if regional_authority else None - def _get_token_url(self, **kwargs: Any) -> str: tenant = resolve_tenant( self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs ) - authority = cast(str, self._regional_authority) if self._regional_authority else self._authority - return "/".join((authority, tenant, "oauth2/v2.0/token")) + return "/".join((self._authority, tenant, "oauth2/v2.0/token")) def _post(self, data: Dict, **kwargs: Any) -> HttpRequest: url = self._get_token_url(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index f8f434dfe3fd..7b99f85ac912 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -2,8 +2,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import logging -import os import time from typing import Iterable, Optional, Union, Dict, Any @@ -14,12 +12,9 @@ from ..._internal import AadClientCertificate from ..._internal import AadClientBase from ..._internal.pipeline import build_async_pipeline -from ..._enums import RegionalAuthority Policy = Union[AsyncHTTPPolicy, SansIOHTTPPolicy] -_LOGGER = logging.getLogger(__name__) - # pylint:disable=invalid-overridden-method class AadClient(AadClientBase): @@ -38,7 +33,6 @@ async def close(self) -> None: async def obtain_token_by_authorization_code( self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs ) -> AccessTokenInfo: - await self._initialize_regional_authority() request = self._get_auth_code_request( scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) @@ -47,24 +41,20 @@ async def obtain_token_by_authorization_code( async def obtain_token_by_client_certificate( self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs ) -> AccessTokenInfo: - await self._initialize_regional_authority() request = self._get_client_certificate_request(scopes, certificate, **kwargs) return await self._run_pipeline(request, stream=False, **kwargs) async def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs) -> AccessTokenInfo: - await self._initialize_regional_authority() request = self._get_client_secret_request(scopes, secret, **kwargs) return await self._run_pipeline(request, **kwargs) async def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs) -> AccessTokenInfo: - await self._initialize_regional_authority() request = self._get_jwt_assertion_request(scopes, assertion, **kwargs) return await self._run_pipeline(request, stream=False, **kwargs) async def obtain_token_by_refresh_token( self, scopes: Iterable[str], refresh_token: str, **kwargs ) -> AccessTokenInfo: - await self._initialize_regional_authority() request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) return await self._run_pipeline(request, **kwargs) @@ -75,7 +65,6 @@ async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-to refresh_token: str, **kwargs ) -> AccessTokenInfo: - await self._initialize_regional_authority() request = self._get_refresh_token_on_behalf_of_request( scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs ) @@ -88,7 +77,6 @@ async def obtain_token_on_behalf_of( user_assertion: str, **kwargs ) -> AccessTokenInfo: - await self._initialize_regional_authority() request = self._get_on_behalf_of_request( scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs ) @@ -97,38 +85,6 @@ async def obtain_token_on_behalf_of( def _build_pipeline(self, **kwargs) -> AsyncPipeline: return build_async_pipeline(**kwargs) - async def _initialize_regional_authority(self) -> None: - # This is based on MSAL's regional authority logic. - if self._regional_authority is not False: - return - - regional_authority = self._get_regional_authority_from_env() - if not regional_authority: - self._regional_authority = None - return - - if regional_authority in [RegionalAuthority.AUTO_DISCOVER_REGION, "true"]: - # Attempt to discover the region from IMDS - regional_authority = await self._discover_region() - if not regional_authority: - _LOGGER.info("Failed to auto-discover region. Using the non-regional authority.") - self._regional_authority = None - return - - self._regional_authority = self._build_regional_authority_url(regional_authority) - - async def _discover_region(self) -> Optional[str]: - region = os.environ.get("REGION_NAME", "").replace(" ", "").lower() - if region: - return region - try: - request = self._get_region_discovery_request() - response = await self._pipeline.run(request) - return self._process_region_discovery_response(response) - except Exception as ex: # pylint: disable=broad-except - _LOGGER.debug("Failed to discover Azure region from IMDS: %s", ex) - return None - async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessTokenInfo: # remove tenant_id and claims kwarg that could have been passed from credential's get_token method # tenant_id is already part of `request` at this point diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 01a4c1a8f631..3f660e43dc55 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -137,105 +137,6 @@ def send(request, **_): client.obtain_token_by_refresh_token("scope", "refresh token") -def test_request_url_with_regional_authority(): - - def send(request, **_): - assert urlparse(request.url).netloc == "centralus.login.microsoft.com" - return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): - client = AadClient("tenant-id", "client id", transport=Mock(send=send)) - - client.obtain_token_by_authorization_code("scope", "code", "uri") - client.obtain_token_by_refresh_token("scope", "refresh token") - - # obtain_token_by_refresh_token is client_secret safe - client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") - - -def test_regional_authority_initialized_once(): - """The client should lazily initialize its regional authority only once.""" - - def send(request, **_): - assert urlparse(request.url).netloc == "centralus.login.microsoft.com" - return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - - with patch("azure.identity._internal.aad_client.AadClient._get_regional_authority_from_env") as mock_env: - mock_env.return_value = "centralus" - transport = Mock(send=Mock(wraps=send)) - client = AadClient("tenant-id", "client id", transport=transport) - - # The first token request should trigger initialization. - client.obtain_token_by_authorization_code("scope", "code", "uri") - # Subsequent requests shouldn't. - client.obtain_token_by_refresh_token("scope", "refresh token") - client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") - - # Env should be checked only once. - assert mock_env.call_count == 1 - - -def test_initialize_regional_authority(): - client = AadClient("tenant-id", "client-id") - # The initial state should be False (uninitialized) - assert client._regional_authority is False - - client._initialize_regional_authority() - assert client._regional_authority is None - - # Test with usage of AZURE_REGIONAL_AUTHORITY_NAME - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): - client = AadClient("tenant-id", "client-id") - client._initialize_regional_authority() - assert client._regional_authority == "https://centralus.login.microsoft.com" - - # Test with non-Microsoft authority host - with patch.dict( - "os.environ", - { - EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.authority.com", - EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus", - }, - clear=True, - ): - client = AadClient("tenant-id", "client-id") - client._initialize_regional_authority() - assert client._regional_authority == "https://centralus.custom.authority.com" - - # Test with usage of region auto-discovery env var - # Test with AZURE_REGIONAL_AUTHORITY_NAME set to "True" (auto-discovery) - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "True"}, clear=True): - with patch.dict("os.environ", {"REGION_NAME": "southcentralus"}): - client = AadClient("tenant-id", "client-id") - client._initialize_regional_authority() - assert client._regional_authority == "https://southcentralus.login.microsoft.com" - - # Test with usage of region auto-discovery env var - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): - with patch.dict("os.environ", {"REGION_NAME": "eastus"}): - client = AadClient("tenant-id", "client-id") - client._initialize_regional_authority() - assert client._regional_authority == "https://eastus.login.microsoft.com" - - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): - response = Mock( - status_code=200, - headers={"Content-Type": "text/plain"}, - content_type="text/plain", - text=lambda encoding=None: "westus2", - ) - transport = mock.Mock(send=mock.Mock(return_value=response)) - - client = AadClient("tenant-id", "client-id", transport=transport) - client._initialize_regional_authority() - assert client._regional_authority == "https://westus2.login.microsoft.com" - - with patch.dict("os.environ", {"MSAL_FORCE_REGION": "westus3"}, clear=True): - client = AadClient("tenant-id", "client-id") - client._initialize_regional_authority() - assert client._regional_authority == "https://westus3.login.microsoft.com" - - @pytest.mark.parametrize("secret", (None, "client secret")) def test_authorization_code(secret): tenant_id = "tenant-id" diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index d10357e545df..db08df523e4e 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -189,44 +189,6 @@ async def send(request, **_): await client.obtain_token_by_refresh_token("scope", "refresh token") -async def test_request_url_with_regional_authority(): - - async def send(request, **_): - assert urlparse(request.url).netloc == "centralus.login.microsoft.com" - return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): - client = AadClient("tenant-id", "client id", transport=Mock(send=send)) - - await client.obtain_token_by_authorization_code("scope", "code", "uri") - await client.obtain_token_by_refresh_token("scope", "refresh token") - - # obtain_token_by_refresh_token is client_secret safe - await client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") - - -async def test_regional_authority_initialized_once(): - """The client should lazily initialize its regional authority only once.""" - - async def send(request, **_): - assert urlparse(request.url).netloc == "centralus.login.microsoft.com" - return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - - with patch("azure.identity.aio._internal.aad_client.AadClient._get_regional_authority_from_env") as mock_env: - mock_env.return_value = "centralus" - transport = AsyncMockTransport(send=Mock(wraps=send)) - client = AadClient("tenant-id", "client id", transport=transport) - - # The first token request should trigger initialization. - await client.obtain_token_by_authorization_code("scope", "code", "uri") - # Subsequent requests shouldn't. - await client.obtain_token_by_refresh_token("scope", "refresh token") - await client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") - - # Env should be checked only once. - assert mock_env.call_count == 1 - - async def test_evicts_invalid_refresh_token(): """when Microsoft Entra ID rejects a refresh token, the client should evict that token from its cache""" @@ -362,74 +324,3 @@ async def test_multitenant_cache(): assert client_d.get_cached_access_token([scope]) is None with pytest.raises(ClientAuthenticationError, match=message): client_d.get_cached_access_token([scope], tenant_id=tenant_a) - - -async def test_initialize_regional_authority(): - client = AadClient("tenant-id", "client-id") - # The initial state should be False (uninitialized) - assert client._regional_authority is False - - async with client: - await client._initialize_regional_authority() - assert client._regional_authority is None - - # Test with usage of AZURE_REGIONAL_AUTHORITY_NAME - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): - client = AadClient("tenant-id", "client-id") - async with client: - await client._initialize_regional_authority() - assert client._regional_authority == "https://centralus.login.microsoft.com" - - # Test with non-Microsoft authority host - with patch.dict( - "os.environ", - { - EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.authority.com", - EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus", - }, - clear=True, - ): - client = AadClient("tenant-id", "client-id") - async with client: - await client._initialize_regional_authority() - assert client._regional_authority == "https://centralus.custom.authority.com" - - # Test with AZURE_REGIONAL_AUTHORITY_NAME set to "True" (auto-discovery) - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "True"}, clear=True): - with patch.dict("os.environ", {"REGION_NAME": "southcentralus"}): - client = AadClient("tenant-id", "client-id") - await client._initialize_regional_authority() - assert client._regional_authority == "https://southcentralus.login.microsoft.com" - await client.close() - - # Test with usage of region auto-discovery env var - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): - with patch.dict("os.environ", {"REGION_NAME": "eastus"}): - client = AadClient("tenant-id", "client-id") - await client._initialize_regional_authority() - assert client._regional_authority == "https://eastus.login.microsoft.com" - await client.close() - - with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): - response = Mock( - status_code=200, - headers={"Content-Type": "text/plain"}, - content_type="text/plain", - text=lambda encoding=None: "westus2", - ) - - async def send(*args, **kwargs): - return response - - transport = AsyncMockTransport(send=Mock(wraps=send)) - - client = AadClient("tenant-id", "client-id", transport=transport) - async with client: - await client._initialize_regional_authority() - assert client._regional_authority == "https://westus2.login.microsoft.com" - - with patch.dict("os.environ", {"MSAL_FORCE_REGION": "westus3"}, clear=True): - client = AadClient("tenant-id", "client-id") - async with client: - await client._initialize_regional_authority() - assert client._regional_authority == "https://westus3.login.microsoft.com"