Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions agentops/instrumentation/providers/openai/provider_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""OpenAI-compatible provider detection for AgentOps instrumentation.

When users use the OpenAI SDK with a custom base_url pointing to an
OpenAI-compatible provider (e.g., MiniMax, Groq, Together AI), this module
detects the actual provider from the client's base_url so that telemetry
spans are attributed to the correct system.
"""

import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)

# Mapping of base_url host patterns to provider names.
# Each entry maps a substring found in the base_url host to the provider name
# used in the gen_ai.system span attribute.
_PROVIDER_HOST_MAP = {
"api.minimax.io": "MiniMax",
"api.minimax.chat": "MiniMax",
"api.groq.com": "Groq",
"api.together.xyz": "Together AI",
"api.together.ai": "Together AI",
"api.fireworks.ai": "Fireworks AI",
"api.deepseek.com": "DeepSeek",
"api.mistral.ai": "Mistral AI",
"api.perplexity.ai": "Perplexity AI",
"generativelanguage.googleapis.com": "Google AI",
"api.x.ai": "xAI",
"api.sambanova.ai": "SambaNova",
"api.cerebras.ai": "Cerebras",
}

_DEFAULT_PROVIDER = "OpenAI"


def detect_provider_from_instance(instance: Any) -> str:
"""Detect the LLM provider from an OpenAI SDK resource instance.

Inspects the client's base_url to determine if the OpenAI SDK is being
used with an OpenAI-compatible provider (e.g., MiniMax, Groq).

Args:
instance: The OpenAI SDK resource instance (e.g., Completions,
AsyncCompletions). Expected to have ``_client.base_url``.

Returns:
The detected provider name (e.g., "MiniMax", "OpenAI").
"""
base_url = _extract_base_url(instance)
if not base_url:
return _DEFAULT_PROVIDER

return _match_provider(base_url)


def _extract_base_url(instance: Any) -> Optional[str]:
"""Extract the base_url string from an OpenAI SDK resource instance."""
try:
client = getattr(instance, "_client", None)
if client is None:
return None
base_url = getattr(client, "base_url", None)
if base_url is None:
return None
# base_url may be a URL object or a string
return str(base_url)
except Exception:
logger.debug("[PROVIDER DETECTION] Failed to extract base_url from instance")
return None


def _match_provider(base_url: str) -> str:
"""Match a base_url string against known provider hosts.

Args:
base_url: The base URL string (e.g., "https://api.minimax.io/v1/").

Returns:
The matched provider name, or "OpenAI" if no match is found.
"""
base_url_lower = base_url.lower()
for host_pattern, provider_name in _PROVIDER_HOST_MAP.items():
if host_pattern in base_url_lower:
return provider_name
return _DEFAULT_PROVIDER
23 changes: 23 additions & 0 deletions agentops/instrumentation/providers/openai/stream_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from agentops.instrumentation.common.wrappers import _with_tracer_wrapper
from agentops.instrumentation.providers.openai.utils import is_metrics_enabled
from agentops.instrumentation.providers.openai.wrappers.chat import handle_chat_attributes, _create_tool_span
from agentops.instrumentation.providers.openai.provider_detection import detect_provider_from_instance
from agentops.semconv import SpanAttributes, LLMRequestTypeValues, MessageAttributes


Expand Down Expand Up @@ -477,6 +478,11 @@ def chat_completion_stream_wrapper(tracer, wrapped, instance, args, kwargs):
# Extract and set request attributes
request_attributes = handle_chat_attributes(kwargs=kwargs)

# Detect actual provider from client base_url (e.g., MiniMax, Groq)
provider = detect_provider_from_instance(instance)
if provider != "OpenAI":
request_attributes[SpanAttributes.LLM_SYSTEM] = provider

for key, value in request_attributes.items():
span.set_attribute(key, value)

Expand Down Expand Up @@ -546,6 +552,11 @@ async def async_chat_completion_stream_wrapper(tracer, wrapped, instance, args,
# Extract and set request attributes
request_attributes = handle_chat_attributes(kwargs=kwargs)

# Detect actual provider from client base_url (e.g., MiniMax, Groq)
provider = detect_provider_from_instance(instance)
if provider != "OpenAI":
request_attributes[SpanAttributes.LLM_SYSTEM] = provider

for key, value in request_attributes.items():
span.set_attribute(key, value)

Expand Down Expand Up @@ -852,6 +863,12 @@ def responses_stream_wrapper(tracer, wrapped, instance, args, kwargs):
from agentops.instrumentation.providers.openai.wrappers.responses import handle_responses_attributes

request_attributes = handle_responses_attributes(kwargs=kwargs)

# Detect actual provider from client base_url (e.g., MiniMax, Groq)
provider = detect_provider_from_instance(instance)
if provider != "OpenAI":
request_attributes[SpanAttributes.LLM_SYSTEM] = provider

for key, value in request_attributes.items():
span.set_attribute(key, value)

Expand Down Expand Up @@ -909,6 +926,12 @@ async def async_responses_stream_wrapper(tracer, wrapped, instance, args, kwargs
from agentops.instrumentation.providers.openai.wrappers.responses import handle_responses_attributes

request_attributes = handle_responses_attributes(kwargs=kwargs)

# Detect actual provider from client base_url (e.g., MiniMax, Groq)
provider = detect_provider_from_instance(instance)
if provider != "OpenAI":
request_attributes[SpanAttributes.LLM_SYSTEM] = provider

for key, value in request_attributes.items():
span.set_attribute(key, value)

Expand Down
162 changes: 162 additions & 0 deletions tests/unit/instrumentation/openai_core/test_provider_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""Tests for OpenAI-compatible provider detection.

Verifies that the provider detection utility correctly identifies
LLM providers from the OpenAI SDK client's base_url.
"""

import pytest

from agentops.instrumentation.providers.openai.provider_detection import (
detect_provider_from_instance,
_extract_base_url,
_match_provider,
_PROVIDER_HOST_MAP,
)


class MockClient:
"""Mock OpenAI client with configurable base_url."""

def __init__(self, base_url=None):
self.base_url = base_url


class MockResource:
"""Mock OpenAI SDK resource (e.g., Completions) with a _client attribute."""

def __init__(self, client=None):
self._client = client


class TestMatchProvider:
"""Tests for _match_provider function."""

def test_minimax_io(self):
assert _match_provider("https://api.minimax.io/v1/") == "MiniMax"

def test_minimax_chat(self):
assert _match_provider("https://api.minimax.chat/v1") == "MiniMax"

def test_groq(self):
assert _match_provider("https://api.groq.com/openai/v1") == "Groq"

def test_together_xyz(self):
assert _match_provider("https://api.together.xyz/v1") == "Together AI"

def test_together_ai(self):
assert _match_provider("https://api.together.ai/v1") == "Together AI"

def test_fireworks(self):
assert _match_provider("https://api.fireworks.ai/inference/v1") == "Fireworks AI"

def test_deepseek(self):
assert _match_provider("https://api.deepseek.com/v1") == "DeepSeek"

def test_mistral(self):
assert _match_provider("https://api.mistral.ai/v1") == "Mistral AI"

def test_perplexity(self):
assert _match_provider("https://api.perplexity.ai/") == "Perplexity AI"

def test_xai(self):
assert _match_provider("https://api.x.ai/v1") == "xAI"

def test_sambanova(self):
assert _match_provider("https://api.sambanova.ai/v1") == "SambaNova"

def test_cerebras(self):
assert _match_provider("https://api.cerebras.ai/v1") == "Cerebras"

def test_openai_default(self):
assert _match_provider("https://api.openai.com/v1") == "OpenAI"

def test_unknown_url(self):
assert _match_provider("https://my-custom-llm.example.com/v1") == "OpenAI"

def test_case_insensitive(self):
assert _match_provider("https://API.MINIMAX.IO/v1") == "MiniMax"

def test_empty_url(self):
assert _match_provider("") == "OpenAI"


class TestExtractBaseUrl:
"""Tests for _extract_base_url function."""

def test_with_string_base_url(self):
client = MockClient(base_url="https://api.minimax.io/v1/")
resource = MockResource(client=client)
assert _extract_base_url(resource) == "https://api.minimax.io/v1/"

def test_with_url_object(self):
"""Test with URL-like object that has __str__."""

class URLObject:
def __str__(self):
return "https://api.minimax.io/v1/"

client = MockClient(base_url=URLObject())
resource = MockResource(client=client)
assert _extract_base_url(resource) == "https://api.minimax.io/v1/"

def test_no_client(self):
resource = MockResource(client=None)
assert _extract_base_url(resource) is None

def test_no_base_url(self):
client = MockClient(base_url=None)
resource = MockResource(client=client)
assert _extract_base_url(resource) is None

def test_no_client_attribute(self):
"""Test with an object that has no _client attribute."""

class NoClient:
pass

assert _extract_base_url(NoClient()) is None


class TestDetectProviderFromInstance:
"""Tests for detect_provider_from_instance function."""

def test_minimax_provider(self):
client = MockClient(base_url="https://api.minimax.io/v1/")
resource = MockResource(client=client)
assert detect_provider_from_instance(resource) == "MiniMax"

def test_groq_provider(self):
client = MockClient(base_url="https://api.groq.com/openai/v1")
resource = MockResource(client=client)
assert detect_provider_from_instance(resource) == "Groq"

def test_openai_provider(self):
client = MockClient(base_url="https://api.openai.com/v1")
resource = MockResource(client=client)
assert detect_provider_from_instance(resource) == "OpenAI"

def test_none_instance(self):
assert detect_provider_from_instance(None) == "OpenAI"

def test_no_client_attribute(self):
assert detect_provider_from_instance(object()) == "OpenAI"

def test_no_base_url(self):
client = MockClient(base_url=None)
resource = MockResource(client=client)
assert detect_provider_from_instance(resource) == "OpenAI"

def test_deepseek_provider(self):
client = MockClient(base_url="https://api.deepseek.com/v1")
resource = MockResource(client=client)
assert detect_provider_from_instance(resource) == "DeepSeek"

def test_all_registered_providers(self):
"""Verify all providers in the host map are detectable."""
for host, expected_name in _PROVIDER_HOST_MAP.items():
client = MockClient(base_url=f"https://{host}/v1")
resource = MockResource(client=client)
result = detect_provider_from_instance(resource)
assert result == expected_name, (
f"Expected '{expected_name}' for host '{host}', got '{result}'"
)
Loading