# Patch summary (commentary only; patch tools skip text that precedes the
# first "diff --git" header).
#
# sentry_sdk/integrations/langchain.py
#   * Adds _get_ai_system(all_params): returns all_params["_type"] unchanged
#     when it is a non-empty str, otherwise None (key missing, None, "", or
#     a non-string value).
#   * on_llm_start and on_chat_model_start now set SPANDATA.GEN_AI_SYSTEM to
#     that value whenever it is truthy. The removed code only ever set it to
#     "anthropic" or "openai", chosen by substring match on `_type`, and left
#     the attribute unset for every other value.
#   * Consequence visible in the new test table: the attribute now carries
#     the full `_type` string (e.g. "openai-chat", "anthropic-chat") rather
#     than the shortened "openai" / "anthropic".
#   * The removed `"anthropic" in ai_type` test raises TypeError when `_type`
#     is present but None; the helper returns None for that input instead.
#
# tests/integrations/langchain/test_langchain.py
#   * Adds test_langchain_ai_system_detection, parametrized over `_type`
#     strings plus "" and None. It calls on_llm_start under
#     start_transaction(), then on_llm_end, and checks GEN_AI_SYSTEM on the
#     first "gen_ai.pipeline" span: equal to the expected value for non-empty
#     strings, absent from the span data otherwise.
#   * NOTE(review): only on_llm_start is exercised; on_chat_model_start
#     receives the same change but has no case in this test.
#
# NOTE(review): this patch was recovered from a copy whose line breaks and
# indentation had been collapsed. Leading whitespace below is reconstructed
# from Python syntax. Context lines can be matched loosely with
# `git apply --ignore-whitespace`, but the indentation of the added lines
# inside on_llm_start / on_chat_model_start and inside the test's `with`
# block must be confirmed against the base files.
diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py
index d19d9bbdd5..98357a32ae 100644
--- a/sentry_sdk/integrations/langchain.py
+++ b/sentry_sdk/integrations/langchain.py
@@ -108,6 +108,15 @@
     OllamaEmbeddings = None
 
 
+def _get_ai_system(all_params: "Dict[str, Any]") -> "Optional[str]":
+    ai_type = all_params.get("_type")  # raw LangChain `_type`, reported as-is
+
+    if not ai_type or not isinstance(ai_type, str):  # absent, empty or non-str
+        return None
+
+    return ai_type
+
+
 DATA_FIELDS = {
     "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
     "function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
@@ -381,11 +390,9 @@ def on_llm_start(
                     model,
                 )
 
-            ai_type = all_params.get("_type", "")
-            if "anthropic" in ai_type:
-                span.set_data(SPANDATA.GEN_AI_SYSTEM, "anthropic")
-            elif "openai" in ai_type:
-                span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
+            ai_system = _get_ai_system(all_params)
+            if ai_system:
+                span.set_data(SPANDATA.GEN_AI_SYSTEM, ai_system)
 
             for key, attribute in DATA_FIELDS.items():
                 if key in all_params and all_params[key] is not None:
@@ -449,11 +456,9 @@ def on_chat_model_start(
             if model:
                 span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
 
-            ai_type = all_params.get("_type", "")
-            if "anthropic" in ai_type:
-                span.set_data(SPANDATA.GEN_AI_SYSTEM, "anthropic")
-            elif "openai" in ai_type:
-                span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
+            ai_system = _get_ai_system(all_params)
+            if ai_system:
+                span.set_data(SPANDATA.GEN_AI_SYSTEM, ai_system)
 
             agent_name = _get_current_agent()
             if agent_name:
diff --git a/tests/integrations/langchain/test_langchain.py b/tests/integrations/langchain/test_langchain.py
index 132da0a9a0..a440a3b0ae 100644
--- a/tests/integrations/langchain/test_langchain.py
+++ b/tests/integrations/langchain/test_langchain.py
@@ -2000,6 +2000,94 @@ def test_transform_google_file_data(self):
         }
 
 
+@pytest.mark.parametrize(
+    "ai_type,expected_system",
+    [
+        # Real LangChain _type values (from _llm_type properties)
+        # OpenAI
+        ("openai-chat", "openai-chat"),
+        ("openai", "openai"),
+        # Azure OpenAI
+        ("azure-openai-chat", "azure-openai-chat"),
+        ("azure", "azure"),
+        # Anthropic
+        ("anthropic-chat", "anthropic-chat"),
+        # Google
+        ("vertexai", "vertexai"),
+        ("chat-google-generative-ai", "chat-google-generative-ai"),
+        ("google_gemini", "google_gemini"),
+        # AWS Bedrock
+        ("amazon_bedrock_chat", "amazon_bedrock_chat"),
+        ("amazon_bedrock", "amazon_bedrock"),
+        # Cohere
+        ("cohere-chat", "cohere-chat"),
+        # Ollama
+        ("chat-ollama", "chat-ollama"),
+        ("ollama-llm", "ollama-llm"),
+        # Mistral
+        ("mistralai-chat", "mistralai-chat"),
+        # Fireworks
+        ("fireworks-chat", "fireworks-chat"),
+        ("fireworks", "fireworks"),
+        # HuggingFace
+        ("huggingface-chat-wrapper", "huggingface-chat-wrapper"),
+        # Groq
+        ("groq-chat", "groq-chat"),
+        # NVIDIA
+        ("chat-nvidia-ai-playground", "chat-nvidia-ai-playground"),
+        # xAI
+        ("xai-chat", "xai-chat"),
+        # DeepSeek
+        ("chat-deepseek", "chat-deepseek"),
+        # Edge cases
+        ("", None),
+        (None, None),
+    ],
+)
+def test_langchain_ai_system_detection(
+    sentry_init, capture_events, ai_type, expected_system
+):
+    sentry_init(
+        integrations=[LangchainIntegration()],
+        traces_sample_rate=1.0,
+    )
+    events = capture_events()
+
+    callback = SentryLangchainCallback(max_span_map_size=100, include_prompts=True)
+
+    run_id = "test-ai-system-uuid"
+    serialized = {"_type": ai_type} if ai_type is not None else {}
+    prompts = ["Test prompt"]
+
+    with start_transaction():
+        callback.on_llm_start(
+            serialized=serialized,
+            prompts=prompts,
+            run_id=run_id,
+            invocation_params={"_type": ai_type, "model": "test-model"},
+        )
+
+        generation = Mock(text="Test response", message=None)
+        response = Mock(generations=[[generation]])
+        callback.on_llm_end(response=response, run_id=run_id)
+
+    assert len(events) > 0
+    tx = events[0]
+    assert tx["type"] == "transaction"
+
+    llm_spans = [
+        span for span in tx.get("spans", []) if span.get("op") == "gen_ai.pipeline"
+    ]
+    assert len(llm_spans) > 0
+
+    llm_span = llm_spans[0]
+
+    if expected_system is not None:
+        assert llm_span["data"][SPANDATA.GEN_AI_SYSTEM] == expected_system
+    else:
+        assert SPANDATA.GEN_AI_SYSTEM not in llm_span.get("data", {})
+
+
 class TestTransformLangchainMessageContent:
     """Tests for _transform_langchain_message_content function."""
 