
Commit 7763dc6
Add support for Gemini

This adds support for the two types of Gemini endpoints:

- Standard API
- OpenAI-compatible

NOTE: The OpenAI-compatible endpoints work in both standalone and muxing scenarios. I'm still trying to get the standard Gemini endpoints working. For now, the Gemini integration can be tested in Continue as follows:

```json
{
  "title": "Gemini 2.0 Flash",
  "provider": "openai",
  "model": "gemini-2.0-flash",
  "apiBase": "http://127.0.0.1:8989/gemini/v1beta/openai",
  "apiKey": "MY_API_KEY"
}
```

Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com>
1 parent dca424c commit 7763dc6
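
For reference, the same OpenAI-compatible route can be exercised without Continue. Below is a minimal sketch using the openai Python SDK, assuming CodeGate is listening on 127.0.0.1:8989 as in the config above; MY_API_KEY is a placeholder for a real Gemini API key:

```python
# Sketch: talk to CodeGate's Gemini OpenAI-compatible route via the openai SDK.
# Assumes CodeGate is running locally; MY_API_KEY stands in for a real key.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:8989/gemini/v1beta/openai",
    api_key="MY_API_KEY",
)

response = client.chat.completions.create(
    model="gemini-2.0-flash",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)
```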

8 files changed: 438 additions, 0 deletions


src/codegate/db/models.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -134,6 +134,7 @@ class ProviderType(str, Enum):
     lm_studio = "lm_studio"
     llamacpp = "llamacpp"
     openrouter = "openrouter"
+    gemini = "gemini"
 
 
 class IntermediatePromptWithOutputUsageAlerts(BaseModel):
```
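
Because ProviderType derives from str, the value persisted in the database round-trips cleanly to the enum member. A minimal sketch with a mirrored enum (the real one lives in src/codegate/db/models.py):

```python
from enum import Enum


class ProviderType(str, Enum):
    # Mirrors the relevant member of the enum in src/codegate/db/models.py
    gemini = "gemini"


# str-valued enums parse from, and compare equal to, their stored string value
assert ProviderType("gemini") is ProviderType.gemini
assert ProviderType.gemini == "gemini"
```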

src/codegate/muxing/adapter.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -39,6 +39,9 @@ def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> st
             return urljoin(model_route.endpoint.endpoint, "/v1")
         if model_route.endpoint.provider_type == db_models.ProviderType.openrouter:
             return urljoin(model_route.endpoint.endpoint, "/api/v1")
+        if model_route.endpoint.provider_type == db_models.ProviderType.gemini:
+            # Gemini API uses /v1beta/openai as the base URL
+            return urljoin(model_route.endpoint.endpoint, "/v1beta/openai")
         return model_route.endpoint.endpoint
 
     def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict:
```
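
The URL construction above leans on urljoin's handling of absolute paths: a second argument starting with "/" replaces any existing path on the stored endpoint, so a trailing slash on the endpoint makes no difference. A quick sketch:

```python
from urllib.parse import urljoin

# An absolute path as the second argument replaces the endpoint's path,
# so stored endpoints resolve the same with or without a trailing slash.
base = "https://generativelanguage.googleapis.com"
assert urljoin(base, "/v1beta/openai") == f"{base}/v1beta/openai"
assert urljoin(base + "/", "/v1beta/openai") == f"{base}/v1beta/openai"
```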
```diff
@@ -209,6 +212,8 @@ def provider_format_funcs(self) -> Dict[str, Callable]:
             db_models.ProviderType.openrouter: self._format_openai,
             # VLLM is a dialect of OpenAI
             db_models.ProviderType.vllm: self._format_openai,
+            # Gemini provider emits OpenAI-compatible chunks
+            db_models.ProviderType.gemini: self._format_openai,
         }
 
     def _format_ollama(self, chunk: str) -> str:
```
```diff
@@ -245,6 +250,8 @@ def provider_format_funcs(self) -> Dict[str, Callable]:
             # VLLM is a dialect of OpenAI
             db_models.ProviderType.vllm: self._format_openai,
             db_models.ProviderType.anthropic: self._format_antropic,
+            # Gemini provider emits OpenAI-compatible chunks
+            db_models.ProviderType.gemini: self._format_openai,
         }
 
     def _format_ollama(self, chunk: str) -> str:
```
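
In both registries above, Gemini is wired to the existing OpenAI chunk formatter rather than a new one, since its stream chunks are OpenAI-compatible. A minimal sketch of the dispatch pattern (the formatter body here is a stand-in, not the real implementation):

```python
from typing import Callable, Dict


def format_openai(chunk: str) -> str:
    # Stand-in for the real _format_openai, which reshapes SSE chunks
    return chunk


# Provider types map to formatting callables; Gemini reuses the OpenAI one
format_funcs: Dict[str, Callable[[str], str]] = {
    "openai": format_openai,
    "vllm": format_openai,    # VLLM is a dialect of OpenAI
    "gemini": format_openai,  # Gemini emits OpenAI-compatible chunks
}

assert format_funcs["gemini"] is format_funcs["openai"]
```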

src/codegate/providers/crud/crud.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -453,6 +453,7 @@ def provider_default_endpoints(provider_type: str) -> str:
     defaults = {
         "openai": "https://api.openai.com",
         "anthropic": "https://api.anthropic.com",
+        "gemini": "https://generativelanguage.googleapis.com",
     }
 
     # If we have a default, we return it
```
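
The default endpoint is a plain dictionary lookup keyed by provider name. A minimal sketch of the pattern (the fallback behavior for unknown providers is assumed, as it is not shown in this hunk):

```python
defaults = {
    "openai": "https://api.openai.com",
    "anthropic": "https://api.anthropic.com",
    "gemini": "https://generativelanguage.googleapis.com",
}

# If we have a default, we return it; otherwise the caller decides what to do
assert defaults.get("gemini") == "https://generativelanguage.googleapis.com"
assert defaults.get("unknown-provider") is None
```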
Lines changed: 3 additions & 0 deletions (new file)

```python
"""
Gemini provider for CodeGate.
"""
```
Lines changed: 141 additions & 0 deletions (new file)

```python
from typing import Any, Dict, Optional

import structlog
from litellm import ChatCompletionRequest

from codegate.providers.litellmshim import sse_stream_generator
from codegate.providers.litellmshim.adapter import (
    BaseAdapter,
    LiteLLMAdapterInputNormalizer,
    LiteLLMAdapterOutputNormalizer,
)

logger = structlog.get_logger("codegate")


class GeminiAdapter(BaseAdapter):
    """
    Adapter for Gemini API to translate between Gemini's format and OpenAI's format.
    """

    def __init__(self) -> None:
        super().__init__(sse_stream_generator)

    def translate_completion_input_params(self, kwargs) -> Optional[ChatCompletionRequest]:
        """
        Translate Gemini API parameters to OpenAI format.

        Gemini API uses a similar format to OpenAI, but with some differences:
        - 'contents' instead of 'messages'
        - Different role names
        - Different parameter names for temperature, etc.
        """
        # Make a copy to avoid modifying the original
        translated_params = dict(kwargs)

        # Handle Gemini-specific parameters
        if "contents" in translated_params:
            # Convert Gemini 'contents' to OpenAI 'messages'
            contents = translated_params.pop("contents")
            messages = []

            for content in contents:
                role = content.get("role", "user")
                # Map Gemini roles to OpenAI roles
                if role == "model":
                    role = "assistant"

                message = {
                    "role": role,
                    "content": content.get("parts", [{"text": ""}])[0].get("text", ""),
                }
                messages.append(message)

            translated_params["messages"] = messages

        # Map other parameters
        if "temperature" in translated_params:
            # Temperature is the same in both APIs
            pass

        if "topP" in translated_params:
            translated_params["top_p"] = translated_params.pop("topP")

        if "topK" in translated_params:
            translated_params["top_k"] = translated_params.pop("topK")

        if "maxOutputTokens" in translated_params:
            translated_params["max_tokens"] = translated_params.pop("maxOutputTokens")

        # Check if we're using the OpenAI-compatible endpoint
        is_openai_compatible = False
        if (
            "_is_openai_compatible" in translated_params
            and translated_params["_is_openai_compatible"]
        ):
            is_openai_compatible = True
            # Remove the custom field to avoid sending it to the API
            translated_params.pop("_is_openai_compatible")
        elif (
            "base_url" in translated_params
            and translated_params["base_url"]
            and "v1beta/openai" in translated_params["base_url"]
        ):
            is_openai_compatible = True

        # Apply the appropriate prefix based on the endpoint
        if "model" in translated_params:
            model_in_request = translated_params["model"]
            if is_openai_compatible:
                # For OpenAI-compatible endpoint, use 'openai/' prefix
                if not model_in_request.startswith("openai/"):
                    translated_params["model"] = f"openai/{model_in_request}"
                    logger.debug(
                        "Using OpenAI-compatible endpoint, prefixed model name with 'openai/': %s",
                        translated_params["model"],
                    )
            else:
                # For native Gemini API, use 'gemini/' prefix
                if not model_in_request.startswith("gemini/"):
                    translated_params["model"] = f"gemini/{model_in_request}"
                    logger.debug(
                        "Using native Gemini API, prefixed model name with 'gemini/': %s",
                        translated_params["model"],
                    )

        return ChatCompletionRequest(**translated_params)

    def translate_completion_output_params(self, response: Any) -> Dict:
        """
        Translate OpenAI format response to Gemini format.
        """
        # For non-streaming responses, we can just return the response as is;
        # LiteLLM should handle the conversion
        return response

    def translate_completion_output_params_streaming(self, completion_stream: Any) -> Any:
        """
        Translate streaming response from OpenAI format to Gemini format.
        """
        # For streaming, we can just return the stream as is;
        # the stream generator will handle the conversion
        return completion_stream


class GeminiInputNormalizer(LiteLLMAdapterInputNormalizer):
    """
    Normalizer for Gemini API input.
    """

    def __init__(self):
        self.adapter = GeminiAdapter()
        super().__init__(self.adapter)


class GeminiOutputNormalizer(LiteLLMAdapterOutputNormalizer):
    """
    Normalizer for Gemini API output.
    """

    def __init__(self):
        super().__init__(GeminiAdapter())
```
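
To see the contents-to-messages mapping in isolation, here is a self-contained sketch of the same transformation, run on an assumed sample payload:

```python
# Assumed sample Gemini-style payload, for illustration only
contents = [
    {"role": "user", "parts": [{"text": "Write a haiku about proxies."}]},
    {"role": "model", "parts": [{"text": "Requests drift through fog..."}]},
]

messages = []
for content in contents:
    role = content.get("role", "user")
    if role == "model":  # Gemini calls the assistant role "model"
        role = "assistant"
    messages.append(
        {
            "role": role,
            # As in the adapter, only the first part's text is kept
            "content": content.get("parts", [{"text": ""}])[0].get("text", ""),
        }
    )

assert messages[1]["role"] == "assistant"
assert messages[0]["content"].startswith("Write a haiku")
```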
Lines changed: 71 additions & 0 deletions (new file)

```python
from typing import AsyncIterator, Optional, Union

import structlog
from litellm import ChatCompletionRequest, ModelResponse

from codegate.providers.litellmshim import LiteLLmShim

logger = structlog.get_logger("codegate")


class GeminiCompletion(LiteLLmShim):
    """
    GeminiCompletion used by the Gemini provider to execute completions.

    This class extends LiteLLmShim to handle Gemini-specific completion logic.
    """

    async def execute_completion(
        self,
        request: ChatCompletionRequest,
        base_url: Optional[str],
        api_key: Optional[str],
        stream: bool = False,
        is_fim_request: bool = False,
    ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
        """
        Execute the completion request with LiteLLM's API.

        Ensures the model name is prefixed with the appropriate prefix to route to Google's API:
        - 'openai/' for the OpenAI-compatible endpoint (v1beta/openai)
        - 'gemini/' for the native Gemini API
        """
        model_in_request = request["model"]

        # Check if we're using the OpenAI-compatible endpoint
        is_openai_compatible = False
        if "_is_openai_compatible" in request and request["_is_openai_compatible"]:
            is_openai_compatible = True
        elif base_url and "v1beta/openai" in base_url:
            is_openai_compatible = True

        # Apply the appropriate prefix based on the endpoint
        if is_openai_compatible:
            # For OpenAI-compatible endpoint, use 'openai/' prefix
            if not model_in_request.startswith("openai/"):
                request["model"] = f"openai/{model_in_request}"
                logger.debug(
                    "Using OpenAI-compatible endpoint, prefixed model name with 'openai/': %s",
                    request["model"],
                )
        else:
            # For native Gemini API, use 'gemini/' prefix
            if not model_in_request.startswith("gemini/"):
                request["model"] = f"gemini/{model_in_request}"
                logger.debug(
                    "Using native Gemini API, prefixed model name with 'gemini/': %s",
                    request["model"],
                )

        # Set the API key and base URL
        request["api_key"] = api_key
        request["base_url"] = base_url

        # Execute the completion
        return await super().execute_completion(
            request=request,
            api_key=api_key,
            stream=stream,
            is_fim_request=is_fim_request,
            base_url=base_url,
        )
```
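
The endpoint detection reduces to a request flag or a substring check on the base URL. A standalone sketch of that decision follows (pick_prefix is a hypothetical helper, not part of the commit):

```python
from typing import Optional


def pick_prefix(base_url: Optional[str], openai_compatible_flag: bool = False) -> str:
    # Mirrors the endpoint detection in GeminiCompletion.execute_completion
    if openai_compatible_flag or (base_url and "v1beta/openai" in base_url):
        return "openai/"
    return "gemini/"


assert pick_prefix("https://generativelanguage.googleapis.com/v1beta/openai") == "openai/"
assert pick_prefix("https://generativelanguage.googleapis.com") == "gemini/"
assert pick_prefix(None, openai_compatible_flag=True) == "openai/"
```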
