Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 71 additions & 8 deletions lib/crewai/src/crewai/llms/providers/azure/completion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

from collections.abc import Callable
import json
import logging
import os
import time
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel
Expand Down Expand Up @@ -35,7 +37,9 @@
StreamingChatCompletionsUpdate,
)
from azure.core.credentials import (
AccessToken,
AzureKeyCredential,
TokenCredential,
)
from azure.core.exceptions import (
HttpResponseError,
Expand All @@ -50,6 +54,41 @@
) from None


class _TokenProviderCredential(TokenCredential):
    """Adapt an ``azure_ad_token_provider`` callable to the ``TokenCredential`` interface.

    This allows users to pass a token provider function (like the one returned by
    azure.identity.get_bearer_token_provider) to the Azure AI Inference client,
    which expects an object exposing ``get_token``.
    """

    # Lifetime (seconds) assumed for bare-string tokens, which carry no expiry
    # information of their own.
    # NOTE(review): a client-side policy that caches tokens until ``expires_on``
    # may keep reusing a string token for this long even if the real token
    # expires sooner -- confirm against azure-core's bearer-token policy.
    _DEFAULT_TOKEN_LIFETIME_SECONDS = 3600

    def __init__(self, provider: Callable[..., Any]):
        """Initialize with a token provider callable.

        Args:
            provider: A zero-argument callable that returns an access token,
                either as an ``AccessToken`` or as a bare string. This is
                typically the result of azure.identity.get_bearer_token_provider().

        Raises:
            TypeError: If ``provider`` is not callable.
        """
        # Fail at construction time rather than on the first request.
        if not callable(provider):
            raise TypeError(
                "azure_ad_token_provider must be a callable that returns a token, "
                f"got {type(provider).__name__}."
            )
        self._provider = provider

    def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
        """Get an access token from the provider.

        Args:
            *scopes: The scopes for the token (ignored, as the provider handles this).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            The provider's ``AccessToken`` unchanged, or a new ``AccessToken``
            wrapping the provider's string result with a default expiry.

        Raises:
            ValueError: If the provider returns ``None`` or a blank string.
        """
        raw = self._provider()

        if isinstance(raw, AccessToken):
            return raw

        # Without this guard ``str(None)`` would be sent as the literal token "None".
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            raise ValueError(
                "azure_ad_token_provider returned an empty token; "
                "expected a non-empty string or an AccessToken."
            )

        # Bare token: wrap it with the default expiry, measured from now.
        expires_on = int(time.time()) + self._DEFAULT_TOKEN_LIFETIME_SECONDS
        return AccessToken(str(raw), expires_on)


class AzureCompletion(BaseLLM):
"""Azure AI Inference native completion implementation.

Expand All @@ -73,6 +112,8 @@ def __init__(
stop: list[str] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[Any, Any] | None = None,
azure_ad_token_provider: Callable[..., Any] | None = None,
credential: TokenCredential | None = None,
**kwargs: Any,
):
"""Initialize Azure AI Inference chat completion client.
Expand All @@ -92,6 +133,13 @@ def __init__(
stop: Stop sequences
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure).
azure_ad_token_provider: A callable that returns an Azure AD token.
This is typically the result of azure.identity.get_bearer_token_provider().
Use this for Azure AD token-based authentication instead of API keys.
credential: An Azure TokenCredential instance for authentication.
This can be any credential from azure.identity (e.g., DefaultAzureCredential,
ManagedIdentityCredential). Takes precedence over azure_ad_token_provider
and api_key.
**kwargs: Additional parameters
"""
if interceptor is not None:
Expand All @@ -107,6 +155,7 @@ def __init__(
self.api_key = api_key or os.getenv("AZURE_API_KEY")
self.endpoint = (
endpoint
or kwargs.get("base_url")
or os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
Expand All @@ -115,31 +164,45 @@ def __init__(
self.timeout = timeout
self.max_retries = max_retries

if not self.api_key:
raise ValueError(
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
)
if not self.endpoint:
raise ValueError(
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)

# Determine the credential to use (priority: credential > azure_ad_token_provider > api_key)
chosen_credential: TokenCredential | AzureKeyCredential | None = None

if credential is not None:
chosen_credential = credential
elif azure_ad_token_provider is not None:
chosen_credential = _TokenProviderCredential(azure_ad_token_provider)
elif self.api_key:
chosen_credential = AzureKeyCredential(self.api_key)

if chosen_credential is None:
raise ValueError(
"Azure authentication is required. Provide one of: "
"api_key (or set AZURE_API_KEY environment variable), "
"azure_ad_token_provider (callable from azure.identity.get_bearer_token_provider), "
"or credential (TokenCredential instance from azure.identity)."
)

# Validate and potentially fix Azure OpenAI endpoint URL
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)

# Build client kwargs
client_kwargs = {
client_kwargs: dict[str, Any] = {
"endpoint": self.endpoint,
"credential": AzureKeyCredential(self.api_key),
"credential": chosen_credential,
}

# Add api_version if specified (primarily for Azure OpenAI endpoints)
if self.api_version:
client_kwargs["api_version"] = self.api_version

self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
self.client = ChatCompletionsClient(**client_kwargs)

self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
self.async_client = AsyncChatCompletionsClient(**client_kwargs)

self.top_p = top_p
self.frequency_penalty = frequency_penalty
Expand Down
212 changes: 208 additions & 4 deletions lib/crewai/tests/llms/azure/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,13 @@ def test_azure_raises_error_when_endpoint_missing():
AzureCompletion(model="gpt-4", api_key="test-key")


def test_azure_raises_error_when_api_key_missing():
"""Test that AzureCompletion raises ValueError when API key is missing"""
def test_azure_raises_error_when_no_auth_provided():
"""Test that AzureCompletion raises ValueError when no authentication is provided"""
from crewai.llms.providers.azure.completion import AzureCompletion

# Clear environment variables
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(ValueError, match="Azure API key is required"):
with pytest.raises(ValueError, match="Azure authentication is required"):
AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")


Expand Down Expand Up @@ -1112,4 +1112,208 @@ def test_azure_completion_params_preparation_with_drop_params():
messages = [{"role": "user", "content": "Hello"}]
params = llm._prepare_completion_params(messages)

assert params.get('stop') == None
assert params.get('stop') == None


def test_azure_ad_token_provider_authentication():
    """AzureCompletion accepts an ``azure_ad_token_provider`` in place of an API key.

    With the environment emptied, the provider callable is the only auth source.
    Construction must succeed, the endpoint must be expanded to the deployment
    URL, and no API key may be recorded on the instance.
    """
    from crewai.llms.providers.azure.completion import AzureCompletion

    def fake_provider():
        # Provider yielding a bare string token.
        return "mock-azure-ad-token"

    init_kwargs = {
        "model": "gpt-4",
        "endpoint": "https://test.openai.azure.com",
        "azure_ad_token_provider": fake_provider,
    }
    expected_endpoint = "https://test.openai.azure.com/openai/deployments/gpt-4"

    # An empty environment guarantees AZURE_API_KEY cannot supply a key.
    with patch.dict(os.environ, {}, clear=True):
        llm = AzureCompletion(**init_kwargs)

        assert llm.api_key is None
        assert llm.endpoint == expected_endpoint


def test_azure_ad_token_provider_with_access_token():
    """
    Test that azure_ad_token_provider works when it returns an AccessToken object.

    The wrapper must hand back the provider's AccessToken untouched, so both the
    token string and the provider-supplied expiry are preserved.
    """
    from crewai.llms.providers.azure.completion import _TokenProviderCredential
    from azure.core.credentials import AccessToken

    # Mock token provider that returns an AccessToken object
    mock_access_token = AccessToken("mock-token-string", 1234567890)

    def mock_token_provider():
        return mock_access_token

    # Test the _TokenProviderCredential wrapper
    credential = _TokenProviderCredential(mock_token_provider)
    token = credential.get_token("https://cognitiveservices.azure.com/.default")

    # The very same object is passed through, not a re-wrapped copy.
    assert token is mock_access_token
    assert token.token == "mock-token-string"
    assert token.expires_on == 1234567890


def test_azure_ad_token_provider_with_string_token():
    """
    Test that azure_ad_token_provider works when it returns a plain string token.

    A bare string carries no expiry, so the wrapper assigns one an hour ahead of
    the moment get_token() is called; the test brackets that moment with clock
    readings taken before and after the call.
    """
    import time

    from crewai.llms.providers.azure.completion import _TokenProviderCredential

    one_hour = 3600

    # Mock token provider that returns a plain string
    def mock_token_provider():
        return "plain-string-token"

    credential = _TokenProviderCredential(mock_token_provider)

    earliest = int(time.time())
    token = credential.get_token("https://cognitiveservices.azure.com/.default")
    latest = int(time.time())

    assert token.token == "plain-string-token"
    # Default expiry must land one hour after the call, not merely be positive.
    assert earliest + one_hour <= token.expires_on <= latest + one_hour


def test_azure_credential_authentication():
    """AzureCompletion can be built from a ``TokenCredential`` instance alone.

    The environment is emptied so no API key can leak in. The instance must
    come up with the deployment endpoint and without an API key.
    """
    from crewai.llms.providers.azure.completion import AzureCompletion
    from azure.core.credentials import AccessToken, TokenCredential

    class _StubCredential(TokenCredential):
        """Minimal credential that always yields the same fixed token."""

        def get_token(self, *scopes, **kwargs):
            return AccessToken("mock-credential-token", 1234567890)

    stub_credential = _StubCredential()
    expected_endpoint = "https://test.openai.azure.com/openai/deployments/gpt-4"

    # No environment variables: the credential is the only auth source.
    with patch.dict(os.environ, {}, clear=True):
        llm = AzureCompletion(
            credential=stub_credential,
            endpoint="https://test.openai.azure.com",
            model="gpt-4",
        )

        assert llm.api_key is None
        assert llm.endpoint == expected_endpoint


def test_azure_credential_takes_precedence_over_api_key():
    """
    Test that credential parameter takes precedence over api_key when both are provided.

    The Azure client constructors are patched so the credential actually handed
    to them can be inspected; asserting only on the endpoint would pass
    regardless of which auth source won.
    """
    from crewai.llms.providers.azure.completion import AzureCompletion
    from azure.core.credentials import AccessToken, TokenCredential

    class MockTokenCredential(TokenCredential):
        def get_token(self, *scopes, **kwargs):
            return AccessToken("credential-token", 1234567890)

    mock_credential = MockTokenCredential()
    module_path = "crewai.llms.providers.azure.completion"

    # Provide both credential and api_key
    with patch.dict(os.environ, {}, clear=True):
        with patch(f"{module_path}.ChatCompletionsClient") as sync_client_cls:
            with patch(f"{module_path}.AsyncChatCompletionsClient") as async_client_cls:
                completion = AzureCompletion(
                    model="gpt-4",
                    endpoint="https://test.openai.azure.com",
                    api_key="should-not-be-used",
                    credential=mock_credential,
                )

    # The completion should be created successfully with the credential
    assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"

    # Both clients must receive the explicit credential, not a key built from api_key.
    assert sync_client_cls.call_args.kwargs["credential"] is mock_credential
    assert async_client_cls.call_args.kwargs["credential"] is mock_credential


def test_azure_ad_token_provider_takes_precedence_over_api_key():
    """
    Test that azure_ad_token_provider takes precedence over api_key when both are provided.

    The Azure client constructors are patched so the credential actually handed
    to them can be inspected; asserting only on the endpoint would pass
    regardless of which auth source won.
    """
    from crewai.llms.providers.azure.completion import (
        AzureCompletion,
        _TokenProviderCredential,
    )

    def mock_token_provider():
        return "token-provider-token"

    module_path = "crewai.llms.providers.azure.completion"

    # Provide both azure_ad_token_provider and api_key
    with patch.dict(os.environ, {}, clear=True):
        with patch(f"{module_path}.ChatCompletionsClient") as sync_client_cls:
            with patch(f"{module_path}.AsyncChatCompletionsClient") as async_client_cls:
                completion = AzureCompletion(
                    model="gpt-4",
                    endpoint="https://test.openai.azure.com",
                    api_key="should-not-be-used",
                    azure_ad_token_provider=mock_token_provider,
                )

    # The completion should be created successfully with the token provider
    assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"

    # Both clients must receive the provider wrapped as a credential, and that
    # wrapper must yield the provider's token rather than the API key.
    for client_cls in (sync_client_cls, async_client_cls):
        used_credential = client_cls.call_args.kwargs["credential"]
        assert isinstance(used_credential, _TokenProviderCredential)
        issued = used_credential.get_token("https://cognitiveservices.azure.com/.default")
        assert issued.token == "token-provider-token"


def test_azure_credential_takes_precedence_over_token_provider():
    """
    Test that credential takes precedence over azure_ad_token_provider when both are provided.

    The Azure client constructors are patched so the credential actually handed
    to them can be inspected; asserting only on the endpoint would pass
    regardless of which auth source won.
    """
    from crewai.llms.providers.azure.completion import AzureCompletion
    from azure.core.credentials import AccessToken, TokenCredential

    class MockTokenCredential(TokenCredential):
        def get_token(self, *scopes, **kwargs):
            return AccessToken("credential-token", 1234567890)

    mock_credential = MockTokenCredential()

    def mock_token_provider():
        return "token-provider-token"

    module_path = "crewai.llms.providers.azure.completion"

    # Provide both credential and azure_ad_token_provider
    with patch.dict(os.environ, {}, clear=True):
        with patch(f"{module_path}.ChatCompletionsClient") as sync_client_cls:
            with patch(f"{module_path}.AsyncChatCompletionsClient") as async_client_cls:
                completion = AzureCompletion(
                    model="gpt-4",
                    endpoint="https://test.openai.azure.com",
                    credential=mock_credential,
                    azure_ad_token_provider=mock_token_provider,
                )

    # The completion should be created successfully with the credential
    assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"

    # Both clients must receive the explicit credential, not the wrapped provider.
    assert sync_client_cls.call_args.kwargs["credential"] is mock_credential
    assert async_client_cls.call_args.kwargs["credential"] is mock_credential


def test_azure_ad_token_provider_via_llm_factory():
    """The ``LLM`` factory forwards ``azure_ad_token_provider`` to the Azure provider.

    Only ``AZURE_ENDPOINT`` is present in the environment, so the provider
    callable is the sole auth source. The factory must return an
    ``AzureCompletion`` with no API key recorded.
    """
    def fake_provider():
        return "mock-token"

    # Replace the whole environment with just the endpoint variable.
    env_only_endpoint = {"AZURE_ENDPOINT": "https://test.openai.azure.com"}

    with patch.dict(os.environ, env_only_endpoint, clear=True):
        built = LLM(model="azure/gpt-4", azure_ad_token_provider=fake_provider)

        from crewai.llms.providers.azure.completion import AzureCompletion

        assert isinstance(built, AzureCompletion)
        assert built.api_key is None


def test_azure_base_url_parameter_for_endpoint():
    """``base_url`` is accepted as a stand-in for ``endpoint``.

    This is useful for users migrating from other providers. With an empty
    environment, the value passed as ``base_url`` must end up as the
    deployment endpoint on the instance.
    """
    from crewai.llms.providers.azure.completion import AzureCompletion

    init_kwargs = {
        "model": "gpt-4",
        "base_url": "https://test.openai.azure.com",
        "api_key": "test-key",
    }
    expected_endpoint = "https://test.openai.azure.com/openai/deployments/gpt-4"

    # Empty environment: no AZURE_* endpoint variable can supply the URL.
    with patch.dict(os.environ, {}, clear=True):
        llm = AzureCompletion(**init_kwargs)

        assert llm.endpoint == expected_endpoint
Loading