diff --git a/server/src/agent_control_server/config.py b/server/src/agent_control_server/config.py index 82ce14a4..fe481881 100644 --- a/server/src/agent_control_server/config.py +++ b/server/src/agent_control_server/config.py @@ -125,6 +125,41 @@ class AgentControlServerDatabaseConfig(BaseSettings): "DB_DATABASE", ) driver: str = _env_alias_field("psycopg", "AGENT_CONTROL_DB_DRIVER", "DB_DRIVER") + pool_size: int = Field( + default=5, + ge=1, + validation_alias=AliasChoices("AGENT_CONTROL_DB_POOL_SIZE", "DB_POOL_SIZE"), + ) + max_overflow: int = Field( + default=10, + ge=0, + validation_alias=AliasChoices("AGENT_CONTROL_DB_MAX_OVERFLOW", "DB_MAX_OVERFLOW"), + ) + pool_timeout_seconds: float = Field( + default=5.0, + gt=0, + validation_alias=AliasChoices( + "AGENT_CONTROL_DB_POOL_TIMEOUT_SECONDS", + "DB_POOL_TIMEOUT_SECONDS", + ), + ) + connect_timeout_seconds: int = Field( + default=5, + ge=1, + validation_alias=AliasChoices( + "AGENT_CONTROL_DB_CONNECT_TIMEOUT_SECONDS", + "DB_CONNECT_TIMEOUT_SECONDS", + ), + ) + # 0 disables the server-side statement timeout. + statement_timeout_seconds: float = Field( + default=50.0, + ge=0, + validation_alias=AliasChoices( + "AGENT_CONTROL_DB_STATEMENT_TIMEOUT_SECONDS", + "DB_STATEMENT_TIMEOUT_SECONDS", + ), + ) def get_url(self) -> str: """Get database URL, preferring an explicit URL if configured.""" diff --git a/server/src/agent_control_server/db.py b/server/src/agent_control_server/db.py index 4ba28846..e65f76db 100644 --- a/server/src/agent_control_server/db.py +++ b/server/src/agent_control_server/db.py @@ -1,9 +1,16 @@ +import logging from collections.abc import AsyncGenerator +from typing import Any +from prometheus_client import Gauge +from sqlalchemy.engine.url import make_url from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.orm import DeclarativeBase -from .config import db_config +from .config import AgentControlServerDatabaseConfig, db_config + +logger = logging.getLogger(__name__) class Base(DeclarativeBase): @@ -13,11 +20,91 @@ class Base(DeclarativeBase): # Async SQLAlchemy setup for PostgreSQL db_url = db_config.get_url() -async_engine = create_async_engine( - db_url, - echo=False, +SQLALCHEMY_CHECKED_OUT_CONNECTIONS = Gauge( + "agent_control_server_sqlalchemy_checked_out_connections", + "Number of checked out SQLAlchemy connections.", + ["pool_name"], + multiprocess_mode="livesum", ) + +def _supports_queue_pool_config(url: str) -> bool: + """Return whether SQLAlchemy QueuePool kwargs should be applied for this URL.""" + return make_url(url).get_backend_name() != "sqlite" + + +def _build_connect_args( + url: str, + config: AgentControlServerDatabaseConfig, +) -> dict[str, Any]: + """Build driver-level connect args bounding connection setup and statement runtime. + + Pool timeouts only bound how long a request waits for a connection; these + args bound how long a connection takes to establish and how long any one + statement may hold it. Drivers without known timeout args get none. + """ + driver = make_url(url).get_driver_name() + statement_timeout_ms = int(config.statement_timeout_seconds * 1000) + if driver == "psycopg": + connect_args: dict[str, Any] = {"connect_timeout": config.connect_timeout_seconds} + if statement_timeout_ms: + connect_args["options"] = f"-c statement_timeout={statement_timeout_ms}" + return connect_args + if driver == "asyncpg": + connect_args = {"timeout": float(config.connect_timeout_seconds)} + if statement_timeout_ms: + connect_args["server_settings"] = {"statement_timeout": str(statement_timeout_ms)} + return connect_args + return {} + + +def _build_async_engine_kwargs( + url: str, + config: AgentControlServerDatabaseConfig, +) -> dict[str, Any]: + """Build async SQLAlchemy engine kwargs from database config.""" + kwargs: dict[str, Any] = {"echo": False} + if not _supports_queue_pool_config(url): + return kwargs + + kwargs.update( + pool_pre_ping=True, + pool_size=config.pool_size, + max_overflow=config.max_overflow, + pool_timeout=config.pool_timeout_seconds, + pool_reset_on_return="rollback", + ) + connect_args = _build_connect_args(url, config) + if connect_args: + kwargs["connect_args"] = connect_args + else: + parsed_url = make_url(url) + logger.debug( + "No driver-level database timeout connect args configured for backend=%s driver=%s", + parsed_url.get_backend_name(), + parsed_url.get_driver_name(), + ) + return kwargs + + +def _checked_out_connection_count(engine: AsyncEngine) -> float: + """Return the current checked-out connection count when the pool exposes it.""" + checkedout = getattr(engine.sync_engine.pool, "checkedout", None) + if not callable(checkedout): + return 0.0 + return float(checkedout()) + + +def _instrument_connection_pool(engine: AsyncEngine) -> None: + """Report checked-out connections from the async engine's underlying pool.""" + SQLALCHEMY_CHECKED_OUT_CONNECTIONS.labels("default").set_function( + lambda: _checked_out_connection_count(engine) + ) + + +async_engine = create_async_engine(db_url, **_build_async_engine_kwargs(db_url, db_config)) +_instrument_connection_pool(async_engine) + AsyncSessionLocal = async_sessionmaker( bind=async_engine, autoflush=False, diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index bc66381f..a31d757d 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -13,10 +13,9 @@ from agent_control_models.errors import ErrorCode, ValidationErrorItem from fastapi import APIRouter, Depends, Request from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from ..auth_framework import Operation, Principal, require_operation -from ..db import get_async_db +from ..db import AsyncSessionLocal from ..errors import APIValidationError, NotFoundError from ..logging_utils import get_logger from ..models import Agent @@ -136,6 +135,41 @@ async def _evaluation_context(request: Request) -> dict[str, object]: return {"target_type": target_type, "target_id": target_id} +async def _load_engine_controls( + request: EvaluationRequest, + principal: Principal, +) -> list[ControlAdapter]: + """Load and materialize controls before evaluator execution starts.""" + namespace_key = principal.namespace_key + + async with AsyncSessionLocal() as db: + agent_result = await db.execute( + select(Agent).where( + Agent.name == request.agent_name, + Agent.namespace_key == namespace_key, + ) + ) + agent = agent_result.scalar_one_or_none() + if agent is None: + raise NotFoundError( + error_code=ErrorCode.AGENT_NOT_FOUND, + detail=f"Agent '{request.agent_name}' not found", + resource="Agent", + resource_id=request.agent_name, + hint="Register the agent via initAgent before evaluating.", + ) + + runtime_controls = await ControlService(db).list_runtime_controls_for_agent( + request.agent_name, + namespace_key=namespace_key, + target_type=request.target_type, + target_id=request.target_id, + allow_invalid_step_name_regex=True, + ) + + return [ControlAdapter(c.id, c.name, c.control) for c in runtime_controls] + + @router.post( "", response_model=EvaluationResponse, @@ -144,7 +178,6 @@ async def _evaluation_context(request: Request) -> dict[str, object]: ) async def evaluate( request: EvaluationRequest, - db: AsyncSession = Depends(get_async_db), principal: Principal = Depends( require_operation(Operation.RUNTIME_USE, context_builder=_evaluation_context) ), @@ -163,34 +196,7 @@ async def evaluate( on the server; SDKs reconstruct and emit those events separately through the observability ingestion endpoint. """ - namespace_key = principal.namespace_key - - agent_result = await db.execute( - select(Agent).where( - Agent.name == request.agent_name, - Agent.namespace_key == namespace_key, - ) - ) - agent = agent_result.scalar_one_or_none() - if agent is None: - raise NotFoundError( - error_code=ErrorCode.AGENT_NOT_FOUND, - detail=f"Agent '{request.agent_name}' not found", - resource="Agent", - resource_id=request.agent_name, - hint="Register the agent via initAgent before evaluating.", - ) - - runtime_controls = await ControlService(db).list_runtime_controls_for_agent( - request.agent_name, - namespace_key=namespace_key, - target_type=request.target_type, - target_id=request.target_id, - allow_invalid_step_name_regex=True, - ) - - engine_controls = [ControlAdapter(c.id, c.name, c.control) for c in runtime_controls] - + engine_controls = await _load_engine_controls(request, principal) engine = ControlEngine(engine_controls) try: raw_response = await engine.process(request) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 4aea5a6e..16152824 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -19,7 +19,7 @@ from . import __version__ as server_version from .auth import get_api_key_from_header from .config import observability_settings, settings -from .db import AsyncSessionLocal +from .db import AsyncSessionLocal, async_engine from .endpoints.agents import router as agent_router from .endpoints.auth import router as auth_router from .endpoints.control_bindings import router as control_binding_router @@ -198,6 +198,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: await app.state.event_store.close() logger.info("EventStore closed") + await async_engine.dispose() + logger.info("Database engine disposed") + app = FastAPI( title="Agent Control Server", diff --git a/server/tests/test_config.py b/server/tests/test_config.py index 51f48be1..2a9bd472 100644 --- a/server/tests/test_config.py +++ b/server/tests/test_config.py @@ -89,6 +89,48 @@ def test_db_config_ignores_blank_agent_control_url_and_uses_legacy(monkeypatch) assert config.get_url() == "sqlite:///tmp/legacy.db" +def test_db_config_reads_pool_settings_from_env(monkeypatch) -> None: + # Given: database pool settings are configured via environment variables + monkeypatch.setenv("AGENT_CONTROL_DB_POOL_SIZE", "7") + monkeypatch.setenv("AGENT_CONTROL_DB_MAX_OVERFLOW", "2") + monkeypatch.setenv("AGENT_CONTROL_DB_POOL_TIMEOUT_SECONDS", "3.5") + monkeypatch.setenv("AGENT_CONTROL_DB_CONNECT_TIMEOUT_SECONDS", "4") + monkeypatch.setenv("AGENT_CONTROL_DB_STATEMENT_TIMEOUT_SECONDS", "2.5") + + # When: loading DB config from the environment + config = AgentControlServerDatabaseConfig() + + # Then: the explicit pool settings are used + assert config.pool_size == 7 + assert config.max_overflow == 2 + assert config.pool_timeout_seconds == 3.5 + assert config.connect_timeout_seconds == 4 + assert config.statement_timeout_seconds == 2.5 + + +def test_db_config_pool_defaults(monkeypatch) -> None: + # Given: no pool or timeout settings in the environment + for name in ( + "POOL_SIZE", + "MAX_OVERFLOW", + "POOL_TIMEOUT_SECONDS", + "CONNECT_TIMEOUT_SECONDS", + "STATEMENT_TIMEOUT_SECONDS", + ): + monkeypatch.delenv(f"AGENT_CONTROL_DB_{name}", raising=False) + monkeypatch.delenv(f"DB_{name}", raising=False) + + # When: loading DB config from the environment + config = AgentControlServerDatabaseConfig() + + # Then: the pool is bounded but keeps burst overflow and sane timeouts + assert config.pool_size == 5 + assert config.max_overflow == 10 + assert config.pool_timeout_seconds == 5.0 + assert config.connect_timeout_seconds == 5 + assert config.statement_timeout_seconds == 50.0 + + def test_settings_parses_cors_origins_string() -> None: # Given: a comma-separated CORS origins string settings = Settings(cors_origins="https://a.example, https://b.example") diff --git a/server/tests/test_db.py b/server/tests/test_db.py new file mode 100644 index 00000000..719d0f29 --- /dev/null +++ b/server/tests/test_db.py @@ -0,0 +1,197 @@ +"""Tests for server database engine configuration.""" + +import logging +from typing import cast + +import pytest +from prometheus_client import REGISTRY +from sqlalchemy.ext.asyncio.engine import AsyncEngine + +from agent_control_server.config import AgentControlServerDatabaseConfig +from agent_control_server.db import ( + _build_async_engine_kwargs, + _checked_out_connection_count, + _instrument_connection_pool, + async_engine, +) + + +class _PoolWithCheckedout: + def __init__(self, value: int) -> None: + self.value = value + + def checkedout(self) -> int: + return self.value + + +class _PoolWithoutCheckedout: + pass + + +class _SyncEngine: + def __init__(self, pool: object) -> None: + self.pool = pool + + +class _Engine: + def __init__(self, pool: object) -> None: + self.sync_engine = _SyncEngine(pool) + + +def _engine_with_pool(pool: object) -> AsyncEngine: + return cast(AsyncEngine, _Engine(pool)) + + +def _checked_out_connections_sample() -> float | None: + return REGISTRY.get_sample_value( + "agent_control_server_sqlalchemy_checked_out_connections", + {"pool_name": "default"}, + ) + + +def test_build_async_engine_kwargs_applies_postgres_pool_config() -> None: + # Given: custom PostgreSQL connection pool and timeout settings + config = AgentControlServerDatabaseConfig( + pool_size=7, + max_overflow=2, + pool_timeout_seconds=3.5, + connect_timeout_seconds=4, + statement_timeout_seconds=2.5, + ) + + # When: building async engine kwargs for Postgres + kwargs = _build_async_engine_kwargs( + "postgresql+psycopg://user:password@localhost:5432/agent_control", + config, + ) + + # Then: the engine is configured with a bounded, health-checked pool and timeouts + assert kwargs == { + "echo": False, + "pool_pre_ping": True, + "pool_size": 7, + "max_overflow": 2, + "pool_timeout": 3.5, + "pool_reset_on_return": "rollback", + "connect_args": { + "connect_timeout": 4, + "options": "-c statement_timeout=2500", + }, + } + + +def test_build_async_engine_kwargs_uses_asyncpg_connect_args() -> None: + # Given: timeout settings with an asyncpg driver URL + config = AgentControlServerDatabaseConfig( + connect_timeout_seconds=4, + statement_timeout_seconds=2.5, + ) + + # When: building async engine kwargs for asyncpg + kwargs = _build_async_engine_kwargs( + "postgresql+asyncpg://user:password@localhost:5432/agent_control", + config, + ) + + # Then: the timeouts are expressed as asyncpg connect args + assert kwargs["connect_args"] == { + "timeout": 4.0, + "server_settings": {"statement_timeout": "2500"}, + } + + +def test_build_async_engine_kwargs_can_disable_statement_timeout() -> None: + # Given: the statement timeout is disabled + config = AgentControlServerDatabaseConfig( + connect_timeout_seconds=5, + statement_timeout_seconds=0, + ) + + # When: building async engine kwargs for Postgres + kwargs = _build_async_engine_kwargs( + "postgresql+psycopg://user:password@localhost:5432/agent_control", + config, + ) + + # Then: no statement timeout option is passed to the driver + assert kwargs["connect_args"] == {"connect_timeout": 5} + + +def test_build_async_engine_kwargs_skips_pool_config_for_sqlite() -> None: + # Given: custom pool settings with a SQLite URL + config = AgentControlServerDatabaseConfig( + pool_size=7, + max_overflow=2, + pool_timeout_seconds=3.5, + ) + + # When: building async engine kwargs for SQLite + kwargs = _build_async_engine_kwargs("sqlite+aiosqlite:///tmp/agent-control.db", config) + + # Then: SQLite keeps SQLAlchemy's default local-dev pool behavior + assert kwargs == {"echo": False} + + +def test_build_async_engine_kwargs_logs_when_driver_timeout_args_unknown( + caplog: pytest.LogCaptureFixture, +) -> None: + # Given: a non-sqlite backend whose driver has no known timeout connect args + config = AgentControlServerDatabaseConfig() + + # When: building async engine kwargs + with caplog.at_level(logging.DEBUG, logger="agent_control_server.db"): + kwargs = _build_async_engine_kwargs( + "postgresql+someasync://user:password@localhost:5432/agent_control", + config, + ) + + # Then: pool bounds still apply, but the missing driver-level timeouts are visible + assert kwargs["pool_pre_ping"] is True + assert "connect_args" not in kwargs + assert ( + "No driver-level database timeout connect args configured for " + "backend=postgresql driver=someasync" + ) in caplog.text + + +def test_checked_out_connections_gauge_reports_zero_when_idle() -> None: + # Given: the database module is imported and the pool is instrumented + + # When: reading the gauge while no connection is checked out + value = _checked_out_connections_sample() + + # Then: the series exists and reports zero instead of being absent + assert value == 0.0 + + +def test_checked_out_connections_gauge_reads_pool_state_at_collection() -> None: + # Given: a pool whose checked-out connection count changes over time + pool = _PoolWithCheckedout(3) + + try: + _instrument_connection_pool(_engine_with_pool(pool)) + + # When: the Prometheus registry collects the gauge + value = _checked_out_connections_sample() + + # Then: the gauge reports the pool's current checked-out count + assert value == 3.0 + + # When: the pool state changes without any metric event + pool.value = 1 + + # Then: the next collection reflects the new pool state directly + assert _checked_out_connections_sample() == 1.0 + finally: + _instrument_connection_pool(async_engine) + + +def test_checked_out_connection_count_defaults_to_zero_without_pool_support() -> None: + # Given: a pool implementation without SQLAlchemy QueuePool's checkedout method + engine = _engine_with_pool(_PoolWithoutCheckedout()) + + # When: reading the checked-out connection count + value = _checked_out_connection_count(engine) + + # Then: unsupported pools report zero instead of failing metrics collection + assert value == 0.0 diff --git a/server/tests/test_evaluation_error_handling.py b/server/tests/test_evaluation_error_handling.py index 1df795da..9e577dff 100644 --- a/server/tests/test_evaluation_error_handling.py +++ b/server/tests/test_evaluation_error_handling.py @@ -5,16 +5,17 @@ from agent_control_models import ( ControlMatch, EvaluationRequest, + EvaluationResponse, EvaluatorResult, Step, ) -from fastapi.testclient import TestClient - +from agent_control_server.db import async_engine from agent_control_server.endpoints.evaluation import ( SAFE_EVALUATOR_ERROR, SAFE_EVALUATOR_TIMEOUT_ERROR, _sanitize_control_match, ) +from fastapi.testclient import TestClient from .utils import create_and_assign_policy @@ -327,6 +328,31 @@ async def raise_value_error(*_args, **_kwargs): assert body["errors"][0]["message"] == "Invalid evaluation request or control configuration." +def test_evaluation_releases_db_connection_before_engine_processing( + client: TestClient, + monkeypatch, +) -> None: + """Evaluation should not hold a DB connection while evaluator work runs.""" + agent_name, _ = create_and_assign_policy(client) + checked_out_counts: list[int] = [] + + import agent_control_engine.core as core_module + + async def process_with_pool_assertion(*_args, **_kwargs): + pool = async_engine.sync_engine.pool + checked_out = pool.checkedout() if hasattr(pool, "checkedout") else 0 + checked_out_counts.append(checked_out) + return EvaluationResponse(is_safe=True, confidence=1.0) + + monkeypatch.setattr(core_module.ControlEngine, "process", process_with_pool_assertion) + + payload = Step(type="llm", name="test-step", input="test content", output=None) + req = EvaluationRequest(agent_name=agent_name, step=payload, stage="pre") + resp = client.post("/api/v1/evaluation", json=req.model_dump(mode="json")) + + assert resp.status_code == 200 + assert checked_out_counts == [0] + def test_evaluation_ignores_merge_headers_and_remains_pure(client: TestClient) -> None: """/evaluation should return only semantic results regardless of merge headers.""" diff --git a/server/tests/test_init_agent.py b/server/tests/test_init_agent.py index 2dfe9eaa..fc973118 100644 --- a/server/tests/test_init_agent.py +++ b/server/tests/test_init_agent.py @@ -16,6 +16,29 @@ engine = create_engine(db_config.get_url(), echo=False) +def _collect_route_paths(routes: list[Any], prefix: str = "") -> set[str]: + paths: set[str] = set() + for route in routes: + path = getattr(route, "path", None) + if isinstance(path, str): + paths.add(f"{prefix}{path}") + continue + + include_context = getattr(route, "include_context", None) + included_router = getattr(include_context, "included_router", None) + if included_router is None: + continue + + include_prefix = getattr(include_context, "prefix", "") + paths.update( + _collect_route_paths( + list(getattr(included_router, "routes", [])), + f"{prefix}{include_prefix}", + ) + ) + return paths + + def make_agent_payload( agent_name: str | None = None, name: str = "testagent0001", @@ -48,7 +71,7 @@ def make_agent_payload( def test_init_agent_route_exists(app: FastAPI) -> None: # Given: an application router - paths = {getattr(route, "path", None) for route in app.router.routes} + paths = _collect_route_paths(list(app.router.routes)) # When: inspecting registered paths # (computation done above to gather all paths) # Then: initAgent and agent retrieval endpoints are present