Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions server/src/agent_control_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,41 @@ class AgentControlServerDatabaseConfig(BaseSettings):
"DB_DATABASE",
)
driver: str = _env_alias_field("psycopg", "AGENT_CONTROL_DB_DRIVER", "DB_DRIVER")
pool_size: int = Field(
default=5,
ge=1,
validation_alias=AliasChoices("AGENT_CONTROL_DB_POOL_SIZE", "DB_POOL_SIZE"),
)
max_overflow: int = Field(
default=10,
ge=0,
validation_alias=AliasChoices("AGENT_CONTROL_DB_MAX_OVERFLOW", "DB_MAX_OVERFLOW"),
)
pool_timeout_seconds: float = Field(
default=5.0,
gt=0,
validation_alias=AliasChoices(
"AGENT_CONTROL_DB_POOL_TIMEOUT_SECONDS",
"DB_POOL_TIMEOUT_SECONDS",
),
)
connect_timeout_seconds: int = Field(
default=5,
ge=1,
validation_alias=AliasChoices(
"AGENT_CONTROL_DB_CONNECT_TIMEOUT_SECONDS",
"DB_CONNECT_TIMEOUT_SECONDS",
),
)
# 0 disables the server-side statement timeout.
statement_timeout_seconds: float = Field(
default=50.0,
ge=0,
validation_alias=AliasChoices(
"AGENT_CONTROL_DB_STATEMENT_TIMEOUT_SECONDS",
"DB_STATEMENT_TIMEOUT_SECONDS",
),
)

def get_url(self) -> str:
"""Get database URL, preferring an explicit URL if configured."""
Expand Down
85 changes: 81 additions & 4 deletions server/src/agent_control_server/db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from collections.abc import AsyncGenerator
from typing import Any

from prometheus_client import Gauge
from sqlalchemy import event
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.orm import DeclarativeBase

from .config import db_config
from .config import AgentControlServerDatabaseConfig, db_config


class Base(DeclarativeBase):
Expand All @@ -13,11 +18,83 @@ class Base(DeclarativeBase):
# Async SQLAlchemy setup for PostgreSQL
db_url = db_config.get_url()

async_engine = create_async_engine(
db_url,
echo=False,
SQLALCHEMY_CHECKED_OUT_CONNECTIONS = Gauge(
"agent_control_server_sqlalchemy_checked_out_connections",
"Number of checked out SQLAlchemy connections.",
["pool_name"],
multiprocess_mode="livesum",
)


def _supports_queue_pool_config(url: str) -> bool:
"""Return whether SQLAlchemy QueuePool kwargs should be applied for this URL."""
return make_url(url).get_backend_name() != "sqlite"


def _build_connect_args(
url: str,
config: AgentControlServerDatabaseConfig,
) -> dict[str, Any]:
"""Build driver-level connect args bounding connection setup and statement runtime.

Pool timeouts only bound how long a request waits for a connection; these
args bound how long a connection takes to establish and how long any one
statement may hold it. Drivers without known timeout args get none.
"""
driver = make_url(url).get_driver_name()
statement_timeout_ms = int(config.statement_timeout_seconds * 1000)
if driver == "psycopg":
connect_args: dict[str, Any] = {"connect_timeout": config.connect_timeout_seconds}
if statement_timeout_ms:
connect_args["options"] = f"-c statement_timeout={statement_timeout_ms}"
return connect_args
if driver == "asyncpg":
connect_args = {"timeout": float(config.connect_timeout_seconds)}
if statement_timeout_ms:
connect_args["server_settings"] = {"statement_timeout": str(statement_timeout_ms)}
return connect_args
return {}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_build_connect_args returns {} for any driver other than psycopg/asyncpg, while pool kwargs still apply. It's documented ("Drivers without known timeout args get none") and a bare postgresql:// URL would fail create_async_engine anyway, so this is mostly theoretical — but a future valid async driver would lose timeouts without warning. A debug-level log when connect_args come back empty for a non-sqlite backend would make that visible.



def _build_async_engine_kwargs(
url: str,
config: AgentControlServerDatabaseConfig,
) -> dict[str, Any]:
"""Build async SQLAlchemy engine kwargs from database config."""
kwargs: dict[str, Any] = {"echo": False}
if not _supports_queue_pool_config(url):
return kwargs

kwargs.update(
pool_pre_ping=True,
pool_size=config.pool_size,
max_overflow=config.max_overflow,
pool_timeout=config.pool_timeout_seconds,
pool_reset_on_return="rollback",
)
connect_args = _build_connect_args(url, config)
if connect_args:
kwargs["connect_args"] = connect_args
return kwargs


def _instrument_connection_pool(engine: AsyncEngine) -> None:
"""Track checked-out connections from the async engine's underlying pool."""
# Create the labeled series eagerly so idle processes scrape as 0, not absent.
SQLALCHEMY_CHECKED_OUT_CONNECTIONS.labels("default").set(0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checkout/checkin event pair is the standard approach, but on a connection that's invalidated rather than returned, the matching checkin may not fire, leaking the counter over a long-lived process. pool_pre_ping=True reduces the frequency. If exactness matters, gauging off pool.checkedout() at scrape time avoids the drift; otherwise fine as-is.


@event.listens_for(engine.sync_engine.pool, "checkin")
def receive_checkin(dbapi_conn: Any, connection_record: Any) -> None:
SQLALCHEMY_CHECKED_OUT_CONNECTIONS.labels("default").dec()

@event.listens_for(engine.sync_engine.pool, "checkout")
def receive_checkout(dbapi_conn: Any, connection_record: Any, connection_proxy: Any) -> None:
SQLALCHEMY_CHECKED_OUT_CONNECTIONS.labels("default").inc()


async_engine = create_async_engine(db_url, **_build_async_engine_kwargs(db_url, db_config))
_instrument_connection_pool(async_engine)

AsyncSessionLocal = async_sessionmaker(
bind=async_engine,
autoflush=False,
Expand Down
68 changes: 37 additions & 31 deletions server/src/agent_control_server/endpoints/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
from agent_control_models.errors import ErrorCode, ValidationErrorItem
from fastapi import APIRouter, Depends, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from ..auth_framework import Operation, Principal, require_operation
from ..db import get_async_db
from ..db import AsyncSessionLocal
from ..errors import APIValidationError, NotFoundError
from ..logging_utils import get_logger
from ..models import Agent
Expand Down Expand Up @@ -136,6 +135,41 @@ async def _evaluation_context(request: Request) -> dict[str, object]:
return {"target_type": target_type, "target_id": target_id}


async def _load_engine_controls(
request: EvaluationRequest,
principal: Principal,
) -> list[ControlAdapter]:
"""Load and materialize controls before evaluator execution starts."""
namespace_key = principal.namespace_key

async with AsyncSessionLocal() as db:
agent_result = await db.execute(
select(Agent).where(
Agent.name == request.agent_name,
Agent.namespace_key == namespace_key,
)
)
agent = agent_result.scalar_one_or_none()
if agent is None:
raise NotFoundError(
error_code=ErrorCode.AGENT_NOT_FOUND,
detail=f"Agent '{request.agent_name}' not found",
resource="Agent",
resource_id=request.agent_name,
hint="Register the agent via initAgent before evaluating.",
)

runtime_controls = await ControlService(db).list_runtime_controls_for_agent(
request.agent_name,
namespace_key=namespace_key,
target_type=request.target_type,
target_id=request.target_id,
allow_invalid_step_name_regex=True,
)

return [ControlAdapter(c.id, c.name, c.control) for c in runtime_controls]


@router.post(
"",
response_model=EvaluationResponse,
Expand All @@ -144,7 +178,6 @@ async def _evaluation_context(request: Request) -> dict[str, object]:
)
async def evaluate(
request: EvaluationRequest,
db: AsyncSession = Depends(get_async_db),
principal: Principal = Depends(
require_operation(Operation.RUNTIME_USE, context_builder=_evaluation_context)
),
Expand All @@ -163,34 +196,7 @@ async def evaluate(
on the server; SDKs reconstruct and emit those events separately through
the observability ingestion endpoint.
"""
namespace_key = principal.namespace_key

agent_result = await db.execute(
select(Agent).where(
Agent.name == request.agent_name,
Agent.namespace_key == namespace_key,
)
)
agent = agent_result.scalar_one_or_none()
if agent is None:
raise NotFoundError(
error_code=ErrorCode.AGENT_NOT_FOUND,
detail=f"Agent '{request.agent_name}' not found",
resource="Agent",
resource_id=request.agent_name,
hint="Register the agent via initAgent before evaluating.",
)

runtime_controls = await ControlService(db).list_runtime_controls_for_agent(
request.agent_name,
namespace_key=namespace_key,
target_type=request.target_type,
target_id=request.target_id,
allow_invalid_step_name_regex=True,
)

engine_controls = [ControlAdapter(c.id, c.name, c.control) for c in runtime_controls]

engine_controls = await _load_engine_controls(request, principal)
engine = ControlEngine(engine_controls)
try:
raw_response = await engine.process(request)
Expand Down
5 changes: 4 additions & 1 deletion server/src/agent_control_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from . import __version__ as server_version
from .auth import get_api_key_from_header
from .config import observability_settings, settings
from .db import AsyncSessionLocal
from .db import AsyncSessionLocal, async_engine
from .endpoints.agents import router as agent_router
from .endpoints.auth import router as auth_router
from .endpoints.control_bindings import router as control_binding_router
Expand Down Expand Up @@ -198,6 +198,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await app.state.event_store.close()
logger.info("EventStore closed")

await async_engine.dispose()
logger.info("Database engine disposed")


app = FastAPI(
title="Agent Control Server",
Expand Down
42 changes: 42 additions & 0 deletions server/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,48 @@ def test_db_config_ignores_blank_agent_control_url_and_uses_legacy(monkeypatch)
assert config.get_url() == "sqlite:///tmp/legacy.db"


def test_db_config_reads_pool_settings_from_env(monkeypatch) -> None:
# Given: database pool settings are configured via environment variables
monkeypatch.setenv("AGENT_CONTROL_DB_POOL_SIZE", "7")
monkeypatch.setenv("AGENT_CONTROL_DB_MAX_OVERFLOW", "2")
monkeypatch.setenv("AGENT_CONTROL_DB_POOL_TIMEOUT_SECONDS", "3.5")
monkeypatch.setenv("AGENT_CONTROL_DB_CONNECT_TIMEOUT_SECONDS", "4")
monkeypatch.setenv("AGENT_CONTROL_DB_STATEMENT_TIMEOUT_SECONDS", "2.5")

# When: loading DB config from the environment
config = AgentControlServerDatabaseConfig()

# Then: the explicit pool settings are used
assert config.pool_size == 7
assert config.max_overflow == 2
assert config.pool_timeout_seconds == 3.5
assert config.connect_timeout_seconds == 4
assert config.statement_timeout_seconds == 2.5


def test_db_config_pool_defaults(monkeypatch) -> None:
# Given: no pool or timeout settings in the environment
for name in (
"POOL_SIZE",
"MAX_OVERFLOW",
"POOL_TIMEOUT_SECONDS",
"CONNECT_TIMEOUT_SECONDS",
"STATEMENT_TIMEOUT_SECONDS",
):
monkeypatch.delenv(f"AGENT_CONTROL_DB_{name}", raising=False)
monkeypatch.delenv(f"DB_{name}", raising=False)

# When: loading DB config from the environment
config = AgentControlServerDatabaseConfig()

# Then: the pool is bounded but keeps burst overflow and sane timeouts
assert config.pool_size == 5
assert config.max_overflow == 10
assert config.pool_timeout_seconds == 5.0
assert config.connect_timeout_seconds == 5
assert config.statement_timeout_seconds == 50.0


def test_settings_parses_cors_origins_string() -> None:
# Given: a comma-separated CORS origins string
settings = Settings(cors_origins="https://a.example, https://b.example")
Expand Down
Loading
Loading