Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,30 @@
# ---------------------------------------------------------
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from typing import TYPE_CHECKING, Optional, Any
from typing import TYPE_CHECKING, Any, Optional

from azure.ai.agentserver.agentframework.agent_framework import AgentFrameworkCBAgent
from azure.ai.agentserver.agentframework.tool_client import ToolClient
from azure.ai.agentserver.agentframework._version import VERSION
from azure.ai.agentserver.agentframework._agent_framework import AgentFrameworkCBAgent
from azure.ai.agentserver.agentframework._foundry_tools import FoundryToolsChatMiddleware
from azure.ai.agentserver.core.application._package_metadata import PackageMetadata, set_current_app

if TYPE_CHECKING: # pragma: no cover
from azure.core.credentials_async import AsyncTokenCredential


def from_agent_framework(agent,
credentials: Optional["AsyncTokenCredential"] = None,
**kwargs: Any) -> "AgentFrameworkCBAgent":
def from_agent_framework(
agent,
credentials: Optional["AsyncTokenCredential"] = None,
**kwargs: Any,
) -> "AgentFrameworkCBAgent":

return AgentFrameworkCBAgent(agent, credentials=credentials, **kwargs)


__all__ = ["from_agent_framework", "ToolClient"]
__all__ = [
"from_agent_framework",
"FoundryToolsChatMiddleware",
]
__version__ = VERSION

set_current_app(PackageMetadata.from_dist("azure-ai-agentserver-agentframework"))
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import os
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List
import inspect

from agent_framework import AgentProtocol, AIFunction
from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module
Expand All @@ -21,14 +20,14 @@
Response as OpenAIResponse,
ResponseStreamEvent,
)
from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent

from .models.agent_framework_input_converters import AgentFrameworkInputConverter
from .models.agent_framework_output_non_streaming_converter import (
AgentFrameworkOutputNonStreamingConverter,
)
from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter
from .models.constants import Constants
from .tool_client import ToolClient

if TYPE_CHECKING:
from azure.core.credentials_async import AsyncTokenCredential
Expand All @@ -39,12 +38,12 @@
class AgentFactory(Protocol):
"""Protocol for agent factory functions.

An agent factory is a callable that takes a ToolClient and returns
An agent factory is a callable that takes a list of tools and returns
an AgentProtocol, either synchronously or asynchronously.
"""

def __call__(self, tools: List[AIFunction]) -> Union[AgentProtocol, Awaitable[AgentProtocol]]:
"""Create an AgentProtocol using the provided ToolClient.
"""Create an AgentProtocol using the provided tools.

:param tools: The list of AIFunction tools available to the agent.
:type tools: List[AIFunction]
Expand All @@ -71,7 +70,7 @@ class AgentFrameworkCBAgent(FoundryCBAgent):
- Supports both streaming and non-streaming responses based on the `stream` flag.
"""

def __init__(self, agent: Union[AgentProtocol, AgentFactory],
def __init__(self, agent: AgentProtocol,
credentials: "Optional[AsyncTokenCredential]" = None,
**kwargs: Any):
"""Initialize the AgentFrameworkCBAgent with an AgentProtocol or a factory function.
Expand All @@ -83,14 +82,7 @@ def __init__(self, agent: Union[AgentProtocol, AgentFactory],
:type credentials: Optional[AsyncTokenCredential]
"""
super().__init__(credentials=credentials, **kwargs) # pylint: disable=unexpected-keyword-arg
self._agent_or_factory: Union[AgentProtocol, AgentFactory] = agent
self._resolved_agent: "Optional[AgentProtocol]" = None
# If agent is already instantiated, use it directly
if isinstance(agent, AgentProtocol):
self._resolved_agent = agent
logger.info(f"Initialized AgentFrameworkCBAgent with agent: {type(agent).__name__}")
else:
logger.info("Initialized AgentFrameworkCBAgent with agent factory")
self._agent: AgentProtocol = agent

@property
def agent(self) -> "Optional[AgentProtocol]":
Expand All @@ -99,7 +91,7 @@ def agent(self) -> "Optional[AgentProtocol]":
:return: The resolved AgentProtocol if available, None otherwise.
:rtype: Optional[AgentProtocol]
"""
return self._resolved_agent
return self._agent

def _resolve_stream_timeout(self, request_body: CreateResponse) -> float:
"""Resolve idle timeout for streaming updates.
Expand All @@ -121,51 +113,6 @@ def _resolve_stream_timeout(self, request_body: CreateResponse) -> float:
env_val = os.getenv(Constants.AGENTS_ADAPTER_STREAM_TIMEOUT_S)
return float(env_val) if env_val is not None else float(Constants.DEFAULT_STREAM_TIMEOUT_S)

async def _resolve_agent(self, context: AgentRunContext):
"""Resolve the agent if it's a factory function (for single-use/first-time resolution).
Creates a ToolClient and calls the factory function with it.
This is used for the initial resolution.

:param context: The agent run context containing tools and user information.
:type context: AgentRunContext
"""
if callable(self._agent_or_factory):
logger.debug("Resolving agent from factory function")

# Create ToolClient with credentials
tool_client = self.get_tool_client(tools=context.get_tools(), user_info=context.get_user_info()) # pylint: disable=no-member
tool_client_wrapper = ToolClient(tool_client)
tools = await tool_client_wrapper.list_tools()

result = self._agent_or_factory(tools)
if inspect.iscoroutine(result):
self._resolved_agent = await result
else:
self._resolved_agent = result

logger.debug("Agent resolved successfully")
else:
# Should not reach here, but just in case
self._resolved_agent = self._agent_or_factory

async def _resolve_agent_for_request(self, context: AgentRunContext):

logger.debug("Resolving fresh agent from factory function for request")

# Create ToolClient with credentials
tool_client = self.get_tool_client(tools=context.get_tools(), user_info=context.get_user_info()) # pylint: disable=no-member
tool_client_wrapper = ToolClient(tool_client)
tools = await tool_client_wrapper.list_tools()

result = self._agent_or_factory(tools)
if inspect.iscoroutine(result):
agent = await result
else:
agent = result

logger.debug("Fresh agent resolved successfully for request")
return agent, tool_client_wrapper

def init_tracing(self):
try:
exporter = os.environ.get(AdapterConstants.OTEL_EXPORTER_ENDPOINT)
Expand Down Expand Up @@ -209,18 +156,7 @@ async def agent_run( # pylint: disable=too-many-statements
OpenAIResponse,
AsyncGenerator[ResponseStreamEvent, Any],
]:
# Resolve agent - always resolve if it's a factory function to get fresh agent each time
# For factories, get a new agent instance per request to avoid concurrency issues
tool_client = None
try:
if callable(self._agent_or_factory):
agent, tool_client = await self._resolve_agent_for_request(context)
elif self._resolved_agent is None:
await self._resolve_agent(context)
agent = self._resolved_agent
else:
agent = self._resolved_agent

logger.info(f"Starting agent_run with stream={context.stream}")
request_input = context.request.get("input")

Expand All @@ -236,27 +172,56 @@ async def agent_run( # pylint: disable=too-many-statements
async def stream_updates():
try:
update_count = 0
updates = agent.run_stream(message)
async for event in streaming_converter.convert(updates):
update_count += 1
yield event

logger.info("Streaming completed with %d updates", update_count)
try:
updates = self.agent.run_stream(message)
async for event in streaming_converter.convert(updates):
update_count += 1
yield event

logger.info("Streaming completed with %d updates", update_count)
except OAuthConsentRequiredError as e:
logger.info("OAuth consent required during streaming updates")
if update_count == 0:
async for event in self.respond_with_oauth_consent_astream(context, e):
yield event
else:
# If we've already emitted events, we cannot safely restart a new
# OAuth-consent stream (it would reset sequence numbers).
yield ResponseErrorEvent(
sequence_number=streaming_converter.next_sequence(),
code="server_error",
message=f"OAuth consent required: {e.consent_url}",
param="agent_run",
)
yield ResponseFailedEvent(
sequence_number=streaming_converter.next_sequence(),
response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True)

# Emit well-formed error events instead of terminating the stream.
yield ResponseErrorEvent(
sequence_number=streaming_converter.next_sequence(),
code="server_error",
message=str(e),
param="agent_run",
)
yield ResponseFailedEvent(
sequence_number=streaming_converter.next_sequence(),
response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access
)
finally:
# Close tool_client if it was created for this request
if tool_client is not None:
try:
await tool_client.close()
logger.debug("Closed tool_client after streaming completed")
except Exception as ex: # pylint: disable=broad-exception-caught
logger.warning(f"Error closing tool_client in stream: {ex}")
# No request-scoped resources to clean up here today.
# Keep this block as a hook for future request-scoped cleanup.
pass

return stream_updates()

# Non-streaming path
logger.info("Running agent in non-streaming mode")
non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context)
result = await agent.run(message)
result = await self.agent.run(message)
logger.debug(f"Agent run completed, result type: {type(result)}")
transformed_result = non_streaming_converter.transform_output_for_response(result)
logger.info("Agent run and transformation completed successfully")
Expand All @@ -272,10 +237,4 @@ async def oauth_consent_stream(error=e):
return oauth_consent_stream()
return await self.respond_with_oauth_consent(context, e)
finally:
# Close tool_client if it was created for this request (non-streaming only, streaming handles in generator)
if not context.stream and tool_client is not None:
try:
await tool_client.close()
logger.debug("Closed tool_client after request processing")
except Exception as ex: # pylint: disable=broad-exception-caught
logger.warning(f"Error closing tool_client: {ex}")
pass
Loading