diff --git a/docs/reference.md b/docs/reference.md index f34ee89ca..f89238fcf 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -648,10 +648,16 @@ vf.print_prompt_completions_sample(outputs: GenerateOutputs, n: int = 3) Pretty-print sample rollouts. ```python -vf.setup_logging(level: str = "INFO") +vf.setup_logging( + level: str = "INFO", + log_format: str | None = None, + date_format: str | None = None, + log_file: str | None = None, + log_file_level: str | None = None, +) ``` -Configure verifiers logging. Set `VF_LOG_LEVEL` env var to change default. +Configure verifiers logging. Set `VF_LOG_LEVEL` env var to change default. Optionally specify `log_file` to write logs to a file in addition to stderr. Use `log_file_level` to set a different log level for the file handler. ```python vf.log_level(level: str | int) diff --git a/pyproject.toml b/pyproject.toml index 07f956b24..7c0dfeec6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "jinja2>=3.1.6", "math-verify>=0.8.0", "mcp>=1.14.1", + "msgpack>=1.1.2", "nest-asyncio>=1.6.0", # for jupyter notebooks "openai>=1.108.1", "openai-agents>=0.0.7", diff --git a/verifiers/envs/env_group.py b/verifiers/envs/env_group.py index cbf2d4724..890a911b8 100644 --- a/verifiers/envs/env_group.py +++ b/verifiers/envs/env_group.py @@ -1,11 +1,10 @@ import time -from typing import TYPE_CHECKING, AsyncContextManager, Mapping, final +from typing import TYPE_CHECKING, Mapping, final from datasets import Dataset, concatenate_datasets -from openai import AsyncOpenAI import verifiers as vf -from verifiers.types import RolloutInput, SamplingArgs +from verifiers.types import ClientConfig, RolloutInput, SamplingArgs if TYPE_CHECKING: pass @@ -37,7 +36,6 @@ def _get_reward_func_names(self) -> list[str]: async def score_rollout( self, state: vf.State, - score_sem: AsyncContextManager, ) -> None: """ Evaluate all reward functions in-place for a single rollout. @@ -56,7 +54,7 @@ async def score_rollout( state["metrics"] = metrics return - await env.rubric.score_rollout(state, score_sem=score_sem) + await env.rubric.score_rollout(state) env_reward = state.get("reward", 0.0) env_metrics = state.get("metrics", {}).copy() if state.get("metrics") else {} @@ -71,7 +69,6 @@ async def score_rollout( async def score_group( self, states: list[vf.State], - score_sem: AsyncContextManager, ) -> None: """ Score a group of rollouts, routing to appropriate environment rubrics based on task. 
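Below is a minimal usage sketch for the expanded `vf.setup_logging` signature documented in the `docs/reference.md` hunk above; the level values and the `logs/run.log` path are illustrative placeholders, not taken from this diff.

```python
import verifiers as vf

# Console handler at INFO on stderr, plus an optional DEBUG-level file handler.
# The file path is a hypothetical example; the keyword arguments come from the
# signature shown in the docs hunk above.
vf.setup_logging(
    level="INFO",
    log_file="logs/run.log",
    log_file_level="DEBUG",
)
```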
@@ -94,7 +91,7 @@ async def score_group( return # Score all states using the environment's rubric - await env.rubric.score_group(states, score_sem=score_sem) + await env.rubric.score_group(states) # Initialize metrics dict with all reward function names aggregated_metrics: dict[str, list[float]] = { @@ -266,12 +263,12 @@ def add_example_id(example, i): async def rollout( self, input: RolloutInput, - client: AsyncOpenAI, + client_config: ClientConfig, model: str, sampling_args: SamplingArgs | None = None, ) -> vf.State: env = self.get_env_for_task(input["task"]) - return await env.rollout(input, client, model, sampling_args) + return await env.rollout(input, client_config, model, sampling_args) def get_env_for_task(self, task: str) -> vf.Environment: return self.env_map.get(task, self.envs[0]) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index bebca9e85..40d7c55f4 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -1,6 +1,7 @@ import asyncio import atexit import functools +import hashlib import inspect import json import logging @@ -14,7 +15,6 @@ from pathlib import Path from typing import ( TYPE_CHECKING, - AsyncContextManager, Awaitable, Callable, List, @@ -25,7 +25,7 @@ ) from datasets import Dataset -from openai import AsyncOpenAI, BadRequestError, OpenAI +from openai import AsyncOpenAI, BadRequestError from openai.types.chat import ChatCompletion import verifiers as vf @@ -34,6 +34,7 @@ from verifiers.types import ( ChatCompletionToolParam, ChatMessage, + ClientConfig, DatasetBuilder, GenerateMetadata, GenerateOutputs, @@ -46,6 +47,7 @@ State, ) from verifiers.utils.async_utils import maybe_semaphore +from verifiers.utils.client_utils import setup_client from verifiers.utils.error_utils import ErrorChain from verifiers.utils.eval_utils import make_dataset, save_rollout_results from verifiers.utils.message_utils import ( @@ -163,6 +165,8 @@ def __init__( self._cleanup_handlers: list[RolloutCleanup] = [] self._teardown_handlers: list[EnvironmentTeardown] = [] + self._clients: dict[str, AsyncOpenAI] = {} + self.__post_init__() def __post_init__(self): @@ -339,6 +343,16 @@ def _format_dataset_source(self, dataset: Dataset) -> Dataset: else: return self._format_completion_dataset(dataset, map_kwargs=self._map_kwargs) + def _get_client(self, client_config: ClientConfig) -> AsyncOpenAI: + config_hash = hashlib.sha256( + client_config.model_dump_json().encode() + ).hexdigest() + client = self._clients.get(config_hash) + if client is None: + client = setup_client(client_config) + self._clients[config_hash] = client + return client + def _build_dataset(self) -> Dataset | None: """Build and cache the training dataset from source if needed.""" if self._dataset is not None: @@ -429,7 +443,7 @@ def resolve_optional_args( MessageType, ]: """Resolve optional arguments, fallback to state or class defaults.""" - client = client or state["client"] + client = client or self._get_client(state["client_config"]) model = model or state["model"] assert client is not None and model is not None oai_tools = oai_tools or state["oai_tools"] @@ -627,7 +641,7 @@ async def get_model_response_with_tokens( async def init_state( self, input: RolloutInput, - client: AsyncOpenAI, + client_config: ClientConfig, model: str, sampling_args: SamplingArgs | None = None, ) -> State: @@ -644,7 +658,7 @@ async def init_state( if "task" not in state_input: state_input["task"] = self.env_id or "default" state = State(input=RolloutInput(**state_input)) # type: 
ignore[missing-typed-dict-key] - state["client"] = client + state["client_config"] = client_config state["model"] = model state["sampling_args"] = sampling_args state["is_completed"] = False @@ -674,7 +688,7 @@ async def init_state( async def rollout( self, input: RolloutInput, - client: AsyncOpenAI, + client_config: ClientConfig, model: str, sampling_args: SamplingArgs | None = None, ) -> State: @@ -731,25 +745,21 @@ async def is_completed(self, state: State, **kwargs) -> bool: async def run_rollout( self, input: RolloutInput, - client: AsyncOpenAI, + client_config: ClientConfig, model: str, - gen_sampling_args: SamplingArgs, - gen_sem: AsyncContextManager, - score_sem: AsyncContextManager | None = None, - score: bool = False, + sampling_args: SamplingArgs, + score: bool = True, ) -> State: """Generate and, optionally, score a rollout.""" - async with gen_sem: - state = await self.rollout( - input, - client, - model, - gen_sampling_args, - ) + state = await self.rollout( + input, + client_config, + model, + sampling_args, + ) if score: - assert score_sem is not None if self.score_rollouts: - await self.rubric.score_rollout(state, score_sem=score_sem) + await self.rubric.score_rollout(state) else: await self.rubric.dummy_score_rollout(state) return state @@ -758,29 +768,21 @@ async def run_rollout( async def run_group( self, group_inputs: list[RolloutInput], - client: AsyncOpenAI, + client_config: ClientConfig, model: str, - gen_sampling_args: SamplingArgs, - gen_sem: AsyncContextManager, - score_sem: AsyncContextManager, + sampling_args: SamplingArgs, score: bool = True, **kwargs, ) -> list[State]: """Generate and, optionally, score one group.""" rollout_tasks = [ - self.run_rollout( - input, - client, - model, - gen_sampling_args, - gen_sem, - ) + self.run_rollout(input, client_config, model, sampling_args, score=False) for input in group_inputs ] group_states = await asyncio.gather(*rollout_tasks) if score: if self.score_rollouts: - await self.rubric.score_group(group_states, score_sem=score_sem) + await self.rubric.score_group(group_states) else: await self.rubric.dummy_score_group(group_states) return list(group_states) @@ -789,7 +791,7 @@ def _prepare_rollout_results( self, all_states: list[State], model: str, - client: AsyncOpenAI, + client_config: ClientConfig, state_columns: list[str] | None, results_path: Path | None, gen_sampling_args: SamplingArgs, @@ -828,7 +830,7 @@ def _prepare_rollout_results( env_id=self.env_id, env_args=self.env_args, model=model, - base_url=str(client.base_url) if hasattr(client, "base_url") else "", + base_url=client_config.api_base_url, num_examples=num_unique_examples, rollouts_per_example=rollouts_per_example, sampling_args=gen_sampling_args, @@ -861,12 +863,10 @@ def _prepare_rollout_results( async def generate( self, inputs: Dataset | List[RolloutInput], - client: AsyncOpenAI, + client_config: ClientConfig, model: str, sampling_args: SamplingArgs | None = None, max_concurrent: int = -1, - max_concurrent_generation: int | None = None, - max_concurrent_scoring: int | None = None, results_path: Path | None = None, state_columns: list[str] | None = None, save_results: bool = False, @@ -882,17 +882,8 @@ async def generate( elif isinstance(inputs, list): inputs_list = inputs - # resolve concurrency knobs - gen_limit = max_concurrent_generation - score_limit = max_concurrent_scoring - if gen_limit is None: - gen_limit = max_concurrent - if score_limit is None: - score_limit = max_concurrent - # set up semaphores - gen_sem = await 
maybe_semaphore(gen_limit) - score_sem = await maybe_semaphore(score_limit) + sem = maybe_semaphore(max_concurrent) # set up sampling args gen_sampling_args = deepcopy(self.sampling_args) @@ -904,22 +895,26 @@ async def generate( # create tasks based on mode tasks: dict[asyncio.Task, int] = {} if independent_scoring: + + async def run_rollout_with_sem(*args, **kwargs) -> State: + async with sem: + return await self.run_rollout(*args, **kwargs) + for i, input_item in enumerate(inputs_list): task = asyncio.create_task( - self.run_rollout( - input_item, - client, - model, - gen_sampling_args, - gen_sem, - score_sem, - score=True, + run_rollout_with_sem( + input_item, client_config, model, gen_sampling_args, sem ) ) tasks[task] = i pbar_total = len(inputs_list) pbar_desc = f"Processing {len(inputs_list)} rollouts" else: + + async def run_group_with_sem(*args, **kwargs) -> list[State]: + async with sem: + return await self.run_group(*args, **kwargs) + input_groups: dict[int, list[RolloutInput]] = {} for input_item in inputs_list: example_id = input_item["example_id"] @@ -930,13 +925,8 @@ async def generate( for i, group in enumerate(group_list): task = asyncio.create_task( - self.run_group( - group, - client, - model, - gen_sampling_args, - gen_sem, - score_sem, + run_group_with_sem( + group, client_config, model, gen_sampling_args, sem ) ) tasks[task] = i @@ -983,7 +973,7 @@ async def generate( temp_results = self._prepare_rollout_results( all_states, model, - client, + client_config, state_columns, results_path, gen_sampling_args, @@ -1003,7 +993,7 @@ async def generate( results = self._prepare_rollout_results( all_states, model, - client, + client_config, state_columns, results_path, gen_sampling_args, @@ -1019,14 +1009,12 @@ async def generate( def generate_sync( self, inputs: Dataset | List[RolloutInput], - client: AsyncOpenAI | OpenAI, + client_config: ClientConfig, **kwargs, ) -> GenerateOutputs: - if isinstance(client, OpenAI): - client = AsyncOpenAI(api_key=client.api_key, base_url=client.base_url) coro = self.generate( inputs, - client=client, + client_config=client_config, **kwargs, ) # check if we're in existing event loop (e.g. 
Jupyter) @@ -1065,14 +1053,12 @@ def _get_eval_inputs( async def evaluate( self, - client: AsyncOpenAI, + client_config: ClientConfig, model: str, sampling_args: SamplingArgs | None = None, num_examples: int = -1, rollouts_per_example: int = 1, max_concurrent: int = -1, - max_concurrent_generation: int | None = None, - max_concurrent_scoring: int | None = None, results_path: Path | None = None, state_columns: list[str] | None = None, save_results: bool = False, @@ -1086,12 +1072,10 @@ async def evaluate( inputs = self._get_eval_inputs(num_examples, rollouts_per_example) return await self.generate( inputs, - client=client, + client_config=client_config, model=model, sampling_args=sampling_args, max_concurrent=max_concurrent, - max_concurrent_generation=max_concurrent_generation, - max_concurrent_scoring=max_concurrent_scoring, results_path=results_path, state_columns=state_columns, save_results=save_results, @@ -1102,7 +1086,7 @@ async def evaluate( def evaluate_sync( self, - client: OpenAI | AsyncOpenAI, + client_config: ClientConfig, model: str, sampling_args: SamplingArgs | None = None, num_examples: int = -1, @@ -1122,7 +1106,7 @@ def evaluate_sync( inputs = self._get_eval_inputs(num_examples, rollouts_per_example) return self.generate_sync( inputs, - client=client, + client_config=client_config, model=model, sampling_args=sampling_args, max_concurrent=max_concurrent, diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py index 9841cf6d8..e4433f7a8 100644 --- a/verifiers/envs/multiturn_env.py +++ b/verifiers/envs/multiturn_env.py @@ -2,10 +2,9 @@ from abc import abstractmethod from typing import final -from openai import AsyncOpenAI - import verifiers as vf from verifiers.types import ( + ClientConfig, Messages, ModelResponse, RolloutInput, @@ -129,11 +128,11 @@ async def add_model_response( async def rollout( self, input: RolloutInput, - client: AsyncOpenAI, + client_config: ClientConfig, model: str, sampling_args: SamplingArgs | None = None, ) -> State: - state = await self.init_state(input, client, model, sampling_args) + state = await self.init_state(input, client_config, model, sampling_args) try: state = await self.setup_state(state) except vf.Error as e: diff --git a/verifiers/rl/trainer/orchestrator.py b/verifiers/rl/trainer/orchestrator.py index dab81bbf2..221082585 100644 --- a/verifiers/rl/trainer/orchestrator.py +++ b/verifiers/rl/trainer/orchestrator.py @@ -5,14 +5,13 @@ import time from typing import Any -import httpx import numpy as np from datasets import Dataset -from openai import AsyncOpenAI from pydantic import BaseModel, Field from transformers import PreTrainedTokenizerBase from verifiers import Environment +from verifiers.types import ClientConfig class Microbatch(BaseModel): @@ -73,7 +72,12 @@ def __init__( self.client_api_key = client_api_key self.client_limit = client_limit self.client_timeout = client_timeout - self.client = None # created in worker thread + self.client_config = ClientConfig( + api_base_url=self.client_base_url, + api_key_var=self.client_api_key, + max_connections=self.client_limit, + timeout=self.client_timeout, + ) self.model_name = model_name self.sampling_args = sampling_args self.rollouts_per_example = rollouts_per_example @@ -185,14 +189,6 @@ def generation_worker(self): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) self.worker_loop = loop - self.client = AsyncOpenAI( - base_url=self.client_base_url, - api_key=self.client_api_key, - http_client=httpx.AsyncClient( - 
limits=httpx.Limits(max_connections=self.client_limit), - timeout=self.client_timeout, - ), - ) try: while not self.stop_event.is_set(): try: @@ -207,7 +203,6 @@ def generation_worker(self): self.logger.error(f"Error in generation worker: {e}") raise e finally: - loop.run_until_complete(self.client.close()) loop.close() asyncio.set_event_loop(None) @@ -216,13 +211,13 @@ async def generate_batch(self, batch_id: int) -> Batch: Generate a single batch asynchronously. """ self.is_generating = True - assert self.client is not None + assert self.client_config is not None start_time = time.time() batch_ds = self.get_dataset_slice(batch_id) repeated_ds = batch_ds.repeat(self.rollouts_per_example) env_results = await self.env.generate( repeated_ds, - client=self.client, + client_config=self.client_config, model=self.model_name, sampling_args=self.sampling_args, max_concurrent=self.max_concurrent, diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index bfe7973df..b8e87d636 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -2,7 +2,7 @@ import inspect import logging import time -from typing import Any, AsyncContextManager, cast +from typing import Any, cast import verifiers as vf from verifiers.types import ( @@ -109,7 +109,6 @@ async def _call_individual_reward_func( self, func: RewardFunc, state: State, - score_sem: AsyncContextManager, ) -> float: """ Invoke `func` with only the required arguments. @@ -152,8 +151,7 @@ async def _call(): ans = 0.0 return ans - async with score_sem: - return await _call() + return await _call() # group-level reward helpers def _get_group_reward_func_names(self) -> list[str]: @@ -173,7 +171,6 @@ async def _call_group_reward_func( self, func: GroupRewardFunc, states: list[State], - score_sem: AsyncContextManager, ) -> list[float]: """ Invoke `func` with only the required arguments. @@ -209,15 +206,14 @@ async def _call(): ans = [0.0] * len(states) return ans - async with score_sem: - return await _call() + return await _call() async def dummy_score_rollout(self, state: State): """Score a single rollout with dummy rewards.""" state["reward"] = 0.0 state["metrics"] = {} - async def score_rollout(self, state: State, score_sem: AsyncContextManager): + async def score_rollout(self, state: State): """ Evaluate all reward functions for a single rollout. """ @@ -233,7 +229,6 @@ async def score_rollout(self, state: State, score_sem: AsyncContextManager): await self._call_individual_reward_func( func=func, state=state, - score_sem=score_sem, ) ) rewards = RolloutScore( @@ -261,7 +256,7 @@ async def dummy_score_group(self, states: list[State]): for state in states: await self.dummy_score_rollout(state) - async def score_group(self, states: list[State], score_sem: AsyncContextManager): + async def score_group(self, states: list[State]): """ Score a group of rollouts together. 
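To make the `client` → `client_config` migration concrete, here is a hedged sketch of evaluating an environment under the new API. It assumes `ClientConfig` exposes the fields used in the orchestrator hunk above (`api_base_url`, `api_key_var`, `max_connections`, `timeout`) and that `evaluate_sync` keeps the keyword names shown in this diff; the endpoint, env id, and model name are placeholders.

```python
import verifiers as vf

# Example env id; "gsm8k" is only the default used by the ZMQ server CLI added below.
env = vf.load_environment(env_id="gsm8k")

# Pass a serializable config instead of an AsyncOpenAI client; the actual client
# is created lazily and cached per config hash in Environment._get_client.
client_config = vf.ClientConfig(
    api_base_url="http://localhost:8000/v1",  # hypothetical endpoint
    api_key_var="OPENAI_API_KEY",             # assumed: env var holding the key
    max_connections=100,
    timeout=600.0,
)

results = env.evaluate_sync(
    client_config=client_config,
    model="openai/gpt-4.1-mini",
    num_examples=5,
    rollouts_per_example=3,
)
```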
@@ -281,9 +276,7 @@ async def score_group(self, states: list[State], score_sem: AsyncContextManager) if is_group: # GroupRewardFunc: score all states together group_func = cast(GroupRewardFunc, func) - scores = await self._call_group_reward_func( - group_func, states, score_sem=score_sem - ) + scores = await self._call_group_reward_func(group_func, states) func_name = func.__name__ if func_name not in aggregated_metrics: aggregated_metrics[func_name] = [0.0] * num_states @@ -294,9 +287,7 @@ async def score_group(self, states: list[State], score_sem: AsyncContextManager) else: reward_func = cast(RewardFunc, func) score_tasks = [ - self._call_individual_reward_func( - reward_func, state, score_sem=score_sem - ) + self._call_individual_reward_func(reward_func, state) for state in states ] scores = await asyncio.gather(*score_tasks) diff --git a/verifiers/rubrics/rubric_group.py b/verifiers/rubrics/rubric_group.py index c3eff43a1..b0d8bbb71 100644 --- a/verifiers/rubrics/rubric_group.py +++ b/verifiers/rubrics/rubric_group.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncContextManager +from typing import Any from verifiers.rubrics.rubric import Rubric from verifiers.types import ( @@ -52,7 +52,7 @@ def add_class_object(self, name: str, obj: Any): self.logger.warning("Adding class object to the first rubric in the group.") self.rubrics[0].add_class_object(name, obj) - async def score_rollout(self, state: State, score_sem: AsyncContextManager): + async def score_rollout(self, state: State): """ Evaluate all reward functions in-place for a single rollout. """ @@ -63,7 +63,7 @@ async def score_rollout(self, state: State, score_sem: AsyncContextManager): state.get("metrics", {}).copy() if state.get("metrics") else {} ) for rubric in self.rubrics: - await rubric.score_rollout(state, score_sem=score_sem) + await rubric.score_rollout(state) rubric_reward = state.get("reward", 0.0) rubric_metrics = ( state.get("metrics", {}).copy() if state.get("metrics") else {} @@ -77,7 +77,7 @@ async def score_rollout(self, state: State, score_sem: AsyncContextManager): state["reward"] = total_reward state["metrics"] = aggregated_metrics - async def score_group(self, states: list[State], score_sem: AsyncContextManager): + async def score_group(self, states: list[State]): """ Evaluate all reward functions in-place for a group of rollouts. """ @@ -89,7 +89,7 @@ async def score_group(self, states: list[State], score_sem: AsyncContextManager) for state in states ] for rubric in self.rubrics: - await rubric.score_group(states, score_sem=score_sem) + await rubric.score_group(states) for i, state in enumerate(states): rubric_reward = state.get("reward", 0.0) rubric_metrics = ( diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py index 442c10710..2e1838b16 100644 --- a/verifiers/scripts/eval.py +++ b/verifiers/scripts/eval.py @@ -250,6 +250,13 @@ def main(): default={}, help='Extra environment as JSON object (e.g., \'{"key": "value", "num": 42}\'). 
Passed to environment constructor.', ) + parser.add_argument( + "--use-env-server", + "-W", + default=False, + action="store_true", + help="Use env servers (will spawn a multi-process env server as a sidecar)", + ) args = parser.parse_args() setup_logging("DEBUG" if args.verbose else os.getenv("VF_LOG_LEVEL", "INFO")) @@ -393,8 +400,7 @@ def resolve_eval_config(raw_env_config: dict) -> EvalConfig: num_examples=num_examples, rollouts_per_example=rollouts_per_example, max_concurrent=env_args.max_concurrent, - max_concurrent_generation=env_args.max_concurrent_generation, - max_concurrent_scoring=env_args.max_concurrent_scoring, + use_env_server=env_args.use_env_server, # logging verbose=env_args.verbose, # saving diff --git a/verifiers/types.py b/verifiers/types.py index 42060d664..3ba803bd6 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -16,7 +16,6 @@ else: from typing import TypedDict -from openai import AsyncOpenAI from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam @@ -99,14 +98,15 @@ class State(dict): INPUT_FIELDS = ["prompt", "answer", "task", "info", "example_id"] # rollout inputs input: RolloutInput - client: AsyncOpenAI + client_config: "ClientConfig" model: str sampling_args: SamplingArgs | None # created during rollout is_completed: bool is_truncated: bool stop_condition: str | None - oai_tools: list[ChatCompletionToolParam] + oai_tools: list[ChatCompletionToolParam] | None + trajectory_id: str trajectory: list[TrajectoryStep] completion: Messages | None reward: float | None @@ -234,10 +234,9 @@ class EvalConfig(BaseModel): num_examples: int rollouts_per_example: int max_concurrent: int - max_concurrent_generation: int | None = None - max_concurrent_scoring: int | None = None independent_scoring: bool = False extra_env_kwargs: dict = {} + use_env_server: bool = False # logging verbose: bool = False # saving diff --git a/verifiers/utils/async_utils.py b/verifiers/utils/async_utils.py index c15a62af3..48f6af7c9 100644 --- a/verifiers/utils/async_utils.py +++ b/verifiers/utils/async_utils.py @@ -20,7 +20,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): return False -async def maybe_semaphore( +def maybe_semaphore( limit: Optional[int] = None, ) -> AsyncContextManager: """ diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 5325f97d2..2fea0157d 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -5,6 +5,7 @@ import time from collections import Counter from contextlib import contextmanager +from multiprocessing import Process from pathlib import Path from typing import cast @@ -26,7 +27,6 @@ MultiEvalConfig, ) from verifiers.utils.async_utils import EventLoopLagMonitor -from verifiers.utils.client_utils import setup_client from verifiers.utils.error_utils import ErrorChain from verifiers.utils.logging_utils import print_prompt_completions_sample, print_time from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls @@ -244,20 +244,31 @@ def print_results( print_timing(task_results) -async def run_evaluation(config: EvalConfig) -> GenerateOutputs: +async def run_evaluation(env_idx: int, config: EvalConfig) -> GenerateOutputs: # set up AsyncOpenAI client with high limits to prevent timeouts - client = setup_client(config.client_config) logger.debug( f"Initialized AsyncOpenAI client with base_url: {config.client_config.api_base_url}" ) # load environment - vf_env = 
vf.load_environment(env_id=config.env_id, **config.env_args) - - # set extra environment kwargs - if config.extra_env_kwargs: - logger.info(f"Setting extra environment kwargs: {config.extra_env_kwargs}") - vf_env.set_kwargs(**config.extra_env_kwargs) + if config.use_env_server: + from verifiers.workers.client.zmq_env_client import ZMQEnvClient + from verifiers.workers.server.zmq_env_server import ZMQEnvServer + + address = f"tcp://127.0.0.1:{5000 + env_idx}" + env_server = Process( + target=ZMQEnvServer.run_server, + args=(config.env_id, config.env_args), + kwargs=dict(address=address), + ) + env_server.start() + env = ZMQEnvClient(address=address) + else: + env_server = None + env = vf.load_environment(env_id=config.env_id, **config.env_args) + if config.extra_env_kwargs: + logger.info(f"Setting extra environment kwargs: {config.extra_env_kwargs}") + env.set_kwargs(**config.extra_env_kwargs) # run evaluation results_path = get_eval_results_path(config) @@ -265,25 +276,36 @@ async def run_evaluation(config: EvalConfig) -> GenerateOutputs: logger.info( f"Configuration: num_examples={config.num_examples}, rollouts_per_example={config.rollouts_per_example}, max_concurrent={config.max_concurrent}" ) - results = await vf_env.evaluate( - client=client, - model=config.model, - sampling_args=config.sampling_args, - num_examples=config.num_examples, - rollouts_per_example=config.rollouts_per_example, - max_concurrent=config.max_concurrent, - max_concurrent_generation=config.max_concurrent_generation, - max_concurrent_scoring=config.max_concurrent_scoring, - results_path=results_path, - state_columns=config.state_columns, - save_results=config.save_results, - save_every=config.save_every, - independent_scoring=config.independent_scoring, - ) + try: + results = await env.evaluate( + client_config=config.client_config, + model=config.model, + sampling_args=config.sampling_args, + num_examples=config.num_examples, + rollouts_per_example=config.rollouts_per_example, + max_concurrent=config.max_concurrent, + results_path=results_path, + state_columns=config.state_columns, + save_results=config.save_results, + save_every=config.save_every, + independent_scoring=config.independent_scoring, + ) + + if config.save_results: + save_rollout_results( + results, config.save_to_hf_hub, config.hf_hub_dataset_name + ) - if config.save_results: - save_rollout_results(results, config.save_to_hf_hub, config.hf_hub_dataset_name) - return results + return results + finally: + # Terminate the env server process if it was started + if env_server is not None: + env_server.terminate() + env_server.join(timeout=5) + if env_server.is_alive(): + logger.warning("Env server did not terminate gracefully, killing it") + env_server.kill() + env_server.join() async def run_multi_evaluation(config: MultiEvalConfig) -> None: @@ -293,7 +315,10 @@ async def run_multi_evaluation(config: MultiEvalConfig) -> None: start_time = time.time() all_results = await asyncio.gather( - *[run_evaluation(eval_config) for eval_config in config.env] + *[ + run_evaluation(env_idx, eval_config) + for env_idx, eval_config in enumerate(config.env) + ] ) end_time = time.time() event_loop_lags = event_loop_lag_monitor.get_lags() diff --git a/verifiers/utils/logging_utils.py b/verifiers/utils/logging_utils.py index 993a61de9..62d902da9 100644 --- a/verifiers/utils/logging_utils.py +++ b/verifiers/utils/logging_utils.py @@ -20,30 +20,47 @@ def setup_logging( level: str = "INFO", log_format: str | None = None, date_format: str | None = None, + log_file: str | 
None = None, + log_file_level: str | None = None, ) -> None: """ Setup basic logging configuration for the verifiers package. Args: - level: The logging level to use. Defaults to "INFO". + level: The logging level to use for console output. Defaults to "INFO". log_format: Custom log format string. If None, uses default format. date_format: Custom date format string. If None, uses default format. + log_file: Optional path to a log file. If specified, logs will be written to this file. + log_file_level: The logging level for the file handler. If None, uses the same level as console. """ if log_format is None: log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" if date_format is None: date_format = "%Y-%m-%d %H:%M:%S" - handler = logging.StreamHandler(sys.stderr) - handler.setFormatter(logging.Formatter(fmt=log_format, datefmt=date_format)) + formatter = logging.Formatter(fmt=log_format, datefmt=date_format) logger = logging.getLogger(LOGGER_NAME) - # Remove any existing handlers to avoid duplicates + + # remove any existing handlers to avoid duplicates logger.handlers.clear() logger.setLevel(level.upper()) - logger.addHandler(handler) - # Prevent the logger from propagating messages to the root logger + # add console handler (stderr) + console_handler = logging.StreamHandler(sys.stderr) + console_handler.setFormatter(formatter) + console_handler.setLevel(level.upper()) + logger.addHandler(console_handler) + + # add file handler if log_file is specified + if log_file is not None: + file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8") + file_handler.setFormatter(formatter) + file_level = log_file_level.upper() if log_file_level else level.upper() + file_handler.setLevel(file_level) + logger.addHandler(file_handler) + + # prevent the logger from propagating messages to the root logger logger.propagate = False diff --git a/verifiers/workers/__init__.py b/verifiers/workers/__init__.py new file mode 100644 index 000000000..2a5279e3f --- /dev/null +++ b/verifiers/workers/__init__.py @@ -0,0 +1,31 @@ +from verifiers.workers.client import ZMQEnvClient +from verifiers.workers.server import ZMQEnvServer +from verifiers.workers.types import ( + BaseRequest, + BaseResponse, + EvaluateRequest, + EvaluateResponse, + HealthRequest, + HealthResponse, + RunGroupRequest, + RunGroupResponse, + RunRolloutRequest, + RunRolloutResponse, +) + +__all__ = [ + # types + "BaseRequest", + "BaseResponse", + "HealthRequest", + "HealthResponse", + "RunRolloutRequest", + "RunRolloutResponse", + "RunGroupRequest", + "RunGroupResponse", + "EvaluateRequest", + "EvaluateResponse", + # clients/servers + "ZMQEnvClient", + "ZMQEnvServer", +] diff --git a/verifiers/workers/client/__init__.py b/verifiers/workers/client/__init__.py new file mode 100644 index 000000000..dafdccf29 --- /dev/null +++ b/verifiers/workers/client/__init__.py @@ -0,0 +1,3 @@ +from verifiers.workers.client.zmq_env_client import ZMQEnvClient + +__all__ = ["ZMQEnvClient"] diff --git a/verifiers/workers/client/env_client.py b/verifiers/workers/client/env_client.py new file mode 100644 index 000000000..679aa308f --- /dev/null +++ b/verifiers/workers/client/env_client.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod +from pathlib import Path + +from verifiers.types import ( + ClientConfig, + GenerateOutputs, + RolloutInput, + SamplingArgs, + State, +) + + +class EnvClient(ABC): + def __init__(self, address: str): + self.address = address + + @abstractmethod + async def health(self) -> bool: ... 
+ + @abstractmethod + async def run_rollout( + self, + input: RolloutInput, + client_config: ClientConfig, + model: str, + sampling_args: SamplingArgs, + score: bool = True, + ) -> State: + """Mirrors Environment.run_rollout""" + ... + + @abstractmethod + async def run_group( + self, + group_inputs: list[RolloutInput], + client_config: ClientConfig, + model: str, + sampling_args: SamplingArgs, + score: bool = True, + ) -> list[State]: + """Mirrors Environment.run_group""" + ... + + @abstractmethod + async def evaluate( + self, + client_config: ClientConfig, + model: str, + sampling_args: SamplingArgs, + num_examples: int, + rollouts_per_example: int, + max_concurrent: int, + results_path: Path | None, + state_columns: list[str] | None, + save_results: bool, + save_every: int, + independent_scoring: bool = False, + ) -> GenerateOutputs: + """Mirrors Environment.evaluate""" + ... diff --git a/verifiers/workers/client/zmq_env_client.py b/verifiers/workers/client/zmq_env_client.py new file mode 100644 index 000000000..6e88a052f --- /dev/null +++ b/verifiers/workers/client/zmq_env_client.py @@ -0,0 +1,279 @@ +import asyncio +import logging +import uuid +from pathlib import Path +from typing import cast + +import msgpack +import zmq +import zmq.asyncio + +import verifiers as vf +from verifiers.workers.client.env_client import EnvClient +from verifiers.workers.types import ( + BaseRequest, + BaseResponseT, + EvaluateRequest, + EvaluateResponse, + HealthRequest, + HealthResponse, + RunGroupRequest, + RunGroupResponse, + RunRolloutRequest, + RunRolloutResponse, +) +from verifiers.workers.utils import msgpack_encoder + + +class ZMQEnvClient(EnvClient): + def __init__(self, address: str = "tcp://127.0.0.1:5000"): + super().__init__(address=address) + self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}") + self.address = address + + # DEALER socket for async request/response + self.ctx = zmq.asyncio.Context() + self.socket = self.ctx.socket(zmq.DEALER) + self.socket.setsockopt(zmq.SNDHWM, 10000) + self.socket.setsockopt(zmq.RCVHWM, 10000) + self.socket.setsockopt(zmq.LINGER, 0) + + # TCP keepalive for faster dead server detection + self.socket.setsockopt(zmq.TCP_KEEPALIVE, 1) + self.socket.setsockopt( + zmq.TCP_KEEPALIVE_IDLE, 10 + ) # Start probes after 10s idle + self.socket.setsockopt(zmq.TCP_KEEPALIVE_INTVL, 2) # Probe every 2s + self.socket.setsockopt( + zmq.TCP_KEEPALIVE_CNT, 3 + ) # Give up after 3 failed probes + + self.pending: dict[str, asyncio.Future] = {} + self._receiver_task: asyncio.Task | None = None + + async def health(self) -> bool: + request = HealthRequest() + response = await self._send_request(request, HealthResponse, timeout=1.0) + return response.success + + async def run_rollout( + self, + input: vf.RolloutInput, + client_config: vf.ClientConfig, + model: str, + sampling_args: vf.SamplingArgs, + score: bool = True, + ) -> vf.State: + request = RunRolloutRequest( + input=input, + client_config=client_config, + model=model, + sampling_args=sampling_args, + score=score, + ) + response = await self._send_request( + request, RunRolloutResponse, timeout=36000.0 + ) + assert response.state is not None + return response.state + + async def run_group( + self, + group_inputs: list[vf.RolloutInput], + client_config: vf.ClientConfig, + model: str, + sampling_args: vf.SamplingArgs, + score: bool = True, + ) -> list[vf.State]: + request = RunGroupRequest( + group_inputs=group_inputs, + client_config=client_config, + model=model, + sampling_args=sampling_args, + 
score=score, + ) + response = await self._send_request(request, RunGroupResponse, timeout=36000.0) + assert response.states is not None + return response.states + + async def evaluate( + self, + client_config: vf.ClientConfig, + model: str, + sampling_args: vf.SamplingArgs, + num_examples: int, + rollouts_per_example: int, + max_concurrent: int, + results_path: Path | None, + state_columns: list[str] | None, + save_results: bool, + save_every: int, + independent_scoring: bool = False, + ) -> vf.GenerateOutputs: + request = EvaluateRequest( + client_config=client_config, + model=model, + sampling_args=sampling_args, + num_examples=num_examples, + rollouts_per_example=rollouts_per_example, + max_concurrent=max_concurrent, + results_path=str(results_path) if results_path else None, + state_columns=state_columns, + save_results=save_results, + save_every=save_every, + independent_scoring=independent_scoring, + ) + response = await self._send_request(request, EvaluateResponse, timeout=36000.0) + assert response.results is not None + return response.results + + def _fail_all_pending(self, reason: str): + """Fail all pending futures with the given reason.""" + for _, future in list(self.pending.items()): + if not future.done(): + future.set_exception(RuntimeError(reason)) + self.pending.clear() + + async def _receive_loop(self): + """Continuously receive responses from environment servers.""" + while True: + try: + # Receive multipart: [request_id, payload] + msg = await self.socket.recv_multipart() + + if len(msg) < 2: + self.logger.error( + f"Invalid message format: expected 2 frames, got {len(msg)}" + ) + continue + + request_id_bytes, response_data = msg[0], msg[1] + request_id = request_id_bytes.decode() + + if request_id in self.pending: + future = self.pending.pop(request_id) + if not future.done(): + try: + response = msgpack.unpackb(response_data, raw=False) + future.set_result(response) + except Exception as unpack_error: + # Unpacking failed - fail the specific future + self.logger.error( + f"Failed to unpack response for request {request_id}: {unpack_error}" + ) + future.set_exception( + RuntimeError( + f"Failed to deserialize response: {unpack_error}" + ) + ) + else: + self.logger.warning( + f"Received response for unknown request_id: {request_id}" + ) + + except asyncio.CancelledError: + break + except zmq.ZMQError as e: + # Socket-level error - fail all pending futures and exit + self.logger.error(f"ZMQ socket error in receive loop: {e}") + self._fail_all_pending(f"ZMQ socket error: {e}") + break + except Exception as e: + self.logger.error( + f"Unexpected error in ZMQ receive loop: {e}", exc_info=True + ) + # Don't break - log and continue for non-socket errors + + async def _start(self): + self._receiver_task = asyncio.create_task(self._receive_loop()) + self.socket.connect(self.address) + self.logger.debug("ZMQ client started") + + async def _send_request( + self, + request: BaseRequest, + response_type: type[BaseResponseT], + timeout: float | None = None, + ) -> BaseResponseT: + """ + Send typed request to environment and parse typed response. 
+ + Args: + request: Pydantic request model (contains action and request_id) + response_type: Expected Pydantic response type + + Returns: + Validated response of type T + """ + # Auto-start receiver if not already running + if self._receiver_task is None: + await self._start() + + # Use request_id from Pydantic model, encode to bytes for ZMQ frame + request_id = uuid.uuid4().hex + + # Serialize using Pydantic + payload_bytes = cast( + bytes, + msgpack.packb( + request.model_dump(mode="python"), + default=msgpack_encoder, + use_bin_type=True, + ), + ) + + future: asyncio.Future[dict] = asyncio.Future() + self.pending[request_id] = future + + await self.socket.send_multipart([request_id.encode(), payload_bytes]) + + try: + raw_response = await asyncio.wait_for(future, timeout=timeout) + except asyncio.TimeoutError: + self.pending.pop(request_id, None) + raise TimeoutError( + f"Environment timeout for {request.request_type} request after {timeout}s" + ) + + # validate response with Pydantic + response = response_type.model_validate(raw_response) + + if not response.success: + raise RuntimeError(f"Server error: {response.error}") + + return response + + +async def main(): + import argparse + + parser = argparse.ArgumentParser(description="ZMQ Environment Client") + parser.add_argument( + "--address", type=str, default="tcp://127.0.0.1:5000", help="ZMQ bind address" + ) + args = parser.parse_args() + + # initialize client + client = ZMQEnvClient(address=args.address) + + is_healthy = await client.health() + assert is_healthy, "ZMQEnvServer is not healthy" + print("Checked that ZMQEnvServer is running and healthy.") + results = await client.evaluate( + client_config=vf.ClientConfig(), + model="openai/gpt-4.1-mini", + sampling_args=vf.SamplingArgs(), + num_examples=5, + rollouts_per_example=3, + max_concurrent=-1, + results_path=None, + state_columns=[], + save_results=False, + save_every=-1, + independent_scoring=True, + ) + print(results) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/verifiers/workers/server/__init__.py b/verifiers/workers/server/__init__.py new file mode 100644 index 000000000..7bb5e988e --- /dev/null +++ b/verifiers/workers/server/__init__.py @@ -0,0 +1,3 @@ +from verifiers.workers.server.zmq_env_server import ZMQEnvServer + +__all__ = ["ZMQEnvServer"] diff --git a/verifiers/workers/server/env_server.py b/verifiers/workers/server/env_server.py new file mode 100644 index 000000000..d4d3155a1 --- /dev/null +++ b/verifiers/workers/server/env_server.py @@ -0,0 +1,133 @@ +import asyncio +import logging +import signal +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import verifiers as vf +from verifiers.workers.types import ( + EvaluateRequest, + EvaluateResponse, + HealthRequest, + HealthResponse, + RunGroupRequest, + RunGroupResponse, + RunRolloutRequest, + RunRolloutResponse, +) + + +class EnvServer(ABC): + """Server that exposes an environment as a service.""" + + def __init__( + self, + # environment + env_id: str, + env_args: dict[str, Any] = {}, + extra_env_kwargs: dict[str, Any] = {}, + log_level: str | None = None, + log_file: str | None = None, + log_file_level: str | None = None, + ): + # setup logging + log_file = log_file or f"logs/{env_id}.log" + Path(log_file).parent.mkdir(parents=True, exist_ok=True) + if log_level is None: + vf.setup_logging(log_file=log_file, log_file_level=log_file_level) + else: + vf.setup_logging( + level=log_level, log_file=log_file, log_file_level=log_file_level + ) + + 
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}") + self.logger.info( + f"Initializing {self.__class__.__name__} to serve {env_id} ({env_args=}, {extra_env_kwargs=})" + ) + + self.env_id = env_id + self.env_args = env_args + self.extra_env_kwargs = extra_env_kwargs + + # load environment + self.env = vf.load_environment(env_id, **self.env_args) + if self.extra_env_kwargs: + self.env.set_kwargs(**self.extra_env_kwargs) + + @abstractmethod + async def run(self, stop_event: asyncio.Event | None = None): + pass + + @abstractmethod + async def close(self): + pass + + @classmethod + def run_server(cls, *args, **kwargs): + server = cls(*args, **kwargs) + + async def run_with_graceful_shutdown(): + # setup graceful shutdown for SIGTERM (K8s, Docker, Slurm) and SIGINT (Ctrl+C) + stop_event = asyncio.Event() + + def signal_handler(sig): + server.logger.debug( + f"Received signal {sig.name}, initiating graceful shutdown" + ) + stop_event.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, lambda s=sig: signal_handler(s)) + + try: + await server.run(stop_event=stop_event) + finally: + await server.close() + + return asyncio.run(run_with_graceful_shutdown()) + + async def _handle_health(self, _request: HealthRequest) -> HealthResponse: + return HealthResponse() + + async def _handle_run_rollout( + self, request: RunRolloutRequest + ) -> RunRolloutResponse: + state = await self.env.run_rollout( + request.input, + request.client_config, + request.model, + request.sampling_args, + request.score, + ) + return RunRolloutResponse(state=state) + + async def _handle_run_group(self, request: RunGroupRequest) -> RunGroupResponse: + states = await self.env.run_group( + request.group_inputs, + request.client_config, + request.model, + request.sampling_args, + request.score, + ) + return RunGroupResponse(states=states) + + async def _handle_evaluate(self, request: EvaluateRequest) -> EvaluateResponse: + from pathlib import Path + + results_path = Path(request.results_path) if request.results_path else None + results = await self.env.evaluate( + request.client_config, + request.model, + request.sampling_args, + request.num_examples, + request.rollouts_per_example, + request.max_concurrent, + results_path, + request.state_columns, + request.save_results, + request.save_every, + request.independent_scoring, + ) + return EvaluateResponse(results=results) diff --git a/verifiers/workers/server/zmq_env_server.py b/verifiers/workers/server/zmq_env_server.py new file mode 100644 index 000000000..03791a963 --- /dev/null +++ b/verifiers/workers/server/zmq_env_server.py @@ -0,0 +1,194 @@ +import asyncio +import json +from typing import cast + +import msgpack +import zmq +import zmq.asyncio +from openai import AsyncOpenAI + +from verifiers.workers.server.env_server import EnvServer +from verifiers.workers.types import ( + BaseResponse, + EvaluateRequest, + HealthRequest, + RunGroupRequest, + RunRolloutRequest, +) +from verifiers.workers.utils import msgpack_encoder + + +class ZMQEnvServer(EnvServer): + """Server that exposes an environment via ZMQ.""" + + def __init__(self, *args, address: str = "tcp://127.0.0.1:5000", **kwargs): + super().__init__(*args, **kwargs) + self.address = address + + self.ctx = zmq.asyncio.Context() + self.socket = self.ctx.socket(zmq.ROUTER) + self.socket.setsockopt(zmq.SNDHWM, 10000) + self.socket.setsockopt(zmq.RCVHWM, 10000) + self.socket.setsockopt(zmq.LINGER, 0) + self.socket.bind(self.address) + + 
self.clients: dict[str, AsyncOpenAI] = {} + + async def run(self, stop_event: asyncio.Event | None = None): + self.logger.debug(f"{self.__class__.__name__} started on {self.address}") + + # Create a task to wait for stop signal + stop_task = asyncio.create_task(stop_event.wait()) if stop_event else None + + try: + while True: + # exit gracefully on stop signal + if stop_event and stop_event.is_set(): + self.logger.debug("Stop event received, shutting down gracefully") + break + + try: + # receive with timeout to periodically check stop_event + frames = await asyncio.wait_for( + self.socket.recv_multipart(), + timeout=1.0 if stop_event else None, + ) + + if len(frames) != 3: + self.logger.warning( + f"Invalid message: expected 3 frames, got {len(frames)}" + ) + continue + + client_id, request_id, payload_bytes = frames + + # Process in background with concurrency limit + asyncio.create_task( + self._process_request(client_id, request_id, payload_bytes) + ) + + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + except Exception as e: + self.logger.error(f"Error in server loop: {e}", exc_info=True) + finally: + if stop_task and not stop_task.done(): + stop_task.cancel() + + async def close(self): + self.socket.close() + self.ctx.term() + self.logger.debug("Environment server shut down") + + async def _process_request( + self, + client_id: bytes, + request_id_bytes: bytes, + payload_bytes: bytes, + ): + # Default request_id from ZMQ frame, may be overwritten by Pydantic model + request_id = request_id_bytes.decode() + response: BaseResponse + + try: + # deserialize request + raw = msgpack.unpackb(payload_bytes, raw=False) + request_type = raw.get("request_type") + request_id = raw.get("request_id", request_id) + self.logger.debug(f"Got {request_type} request (request_id={request_id})") + + # validate and route to handler + if request_type == "health": + request = HealthRequest.model_validate(raw) + response = await self._handle_health(request) + elif request_type == "run_rollout": + request = RunRolloutRequest.model_validate(raw) + response = await self._handle_run_rollout(request) + elif request_type == "run_group": + request = RunGroupRequest.model_validate(raw) + response = await self._handle_run_group(request) + elif request_type == "evaluate": + request = EvaluateRequest.model_validate(raw) + response = await self._handle_evaluate(request) + else: + self.logger.warning(f"Got unknown request type: {request_type}") + response = BaseResponse( + success=False, error=f"Unknown request type: {request_type}" + ) + + except Exception as e: + self.logger.error(f"Error processing request: {e}", exc_info=True) + response = BaseResponse( + success=False, + error=str(e), + ) + + # serialize response using Pydantic + response_bytes = cast( + bytes, + msgpack.packb( + response.model_dump(mode="python"), + default=msgpack_encoder, + use_bin_type=True, + ), + ) + + # send response: [client_id, request_id, response] + await self.socket.send_multipart( + [client_id, request_id.encode(), response_bytes] + ) + + self.logger.debug( + f"Sent {response.__class__.__name__} (request_id={request_id}, {len(response_bytes)} bytes)" + ) + + +async def main(): + import argparse + + parser = argparse.ArgumentParser(description="ZMQ Environment Server") + parser.add_argument( + "env_id", + type=str, + default="gsm8k", + help="Environment module name(s) (comma-separated) or path to TOML config.", + ) + parser.add_argument( + "--env-args", + "-a", + type=json.loads, + default={}, + 
help='Environment module arguments as JSON object (e.g., \'{"key": "value", "num": 42}\')', + ) + parser.add_argument( + "--address", + default="tcp://127.0.0.1:5000", + help="ZMQ bind address", + ) + parser.add_argument( + "--verbose", + "-v", + default=False, + action="store_true", + help="Logging level", + ) + args = parser.parse_args() + + # initialize server + server = ZMQEnvServer( + env_id=args.env_id, + env_args=args.env_args, + address=args.address, + log_level="DEBUG" if args.verbose else "INFO", + ) + + try: + await server.run() + finally: + await server.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/verifiers/workers/types.py b/verifiers/workers/types.py new file mode 100644 index 000000000..bcd9daf17 --- /dev/null +++ b/verifiers/workers/types.py @@ -0,0 +1,104 @@ +from typing import Literal, TypeVar + +from pydantic import BaseModel, ConfigDict, field_validator + +from verifiers.types import ( + ClientConfig, + GenerateOutputs, + RolloutInput, + SamplingArgs, + State, +) + + +class BaseRequest(BaseModel): + request_type: str + + +class BaseResponse(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + success: bool = True + error: str | None = None # TODO: type errors later + + +class HealthRequest(BaseRequest): + request_type: Literal["health"] = "health" # type: ignore[override] + + +class HealthResponse(BaseResponse): ... + + +class RunRolloutRequest(BaseRequest): + request_type: Literal["run_rollout"] = "run_rollout" # type: ignore[override] + + input: RolloutInput + client_config: ClientConfig + model: str + sampling_args: SamplingArgs + score: bool = True + + +class RunRolloutResponse(BaseResponse): + state: State | None = None + + @field_validator("state", mode="before") + @classmethod + def convert_state(cls, v: dict | None) -> State | None: + if v is None: + return None + return State(**v) + + +class RunGroupRequest(BaseRequest): + request_type: Literal["run_group"] = "run_group" # type: ignore[override] + + group_inputs: list[RolloutInput] + client_config: ClientConfig + model: str + sampling_args: SamplingArgs + score: bool = True + + +class RunGroupResponse(BaseResponse): + states: list[State] | None = None + + @field_validator("states", mode="before") + @classmethod + def convert_states(cls, v: list[dict] | None) -> list[State] | None: + if v is None: + return [] + return [State(**s) for s in v] + + +class EvaluateRequest(BaseRequest): + request_type: Literal["evaluate"] = "evaluate" # type: ignore[override] + + client_config: ClientConfig + model: str + sampling_args: SamplingArgs + num_examples: int = -1 + rollouts_per_example: int = 1 + max_concurrent: int = -1 + results_path: str | None = None + state_columns: list[str] | None = None + save_results: bool = False + save_every: int = -1 + independent_scoring: bool = False + + +class EvaluateResponse(BaseResponse): + results: GenerateOutputs | None = None + + @field_validator("results", mode="before") + @classmethod + def convert_results_state(cls, v: dict | None) -> dict | None: + if v is None: + return None + if isinstance(v, dict) and "state" in v: + v["state"] = [State(**s) for s in v["state"]] + return v + + +BaseRequestT = TypeVar("BaseRequestT", bound=BaseRequest) +BaseResponseT = TypeVar("BaseResponseT", bound=BaseResponse) diff --git a/verifiers/workers/utils.py b/verifiers/workers/utils.py new file mode 100644 index 000000000..c2daa924c --- /dev/null +++ b/verifiers/workers/utils.py @@ -0,0 +1,40 @@ +import socket +from datetime import date, datetime 
+from enum import Enum +from pathlib import Path +from uuid import UUID + +import numpy as np + + +def msgpack_encoder(obj): + """ + Custom encoder for non-standard types. + + IMPORTANT: msgpack traverses lists/dicts in optimized C code. This function + is ONLY called for types msgpack doesn't recognize. This avoids the massive + performance penalty of recursing through millions of tokens in Python. + + Handles: Path, UUID, Enum, datetime, Pydantic models, numpy scalars. + Does NOT handle: lists, dicts, basic types (msgpack does this natively in C). + """ + if isinstance(obj, (Path, UUID)): + return str(obj) + elif isinstance(obj, Enum): + return obj.value + elif isinstance(obj, (datetime, date)): + return obj.isoformat() + elif isinstance(obj, (np.integer, np.floating)): + return obj.item() + elif hasattr(obj, "model_dump"): + return obj.model_dump() + else: + # raise on unknown types to make issues visible + raise TypeError(f"Object of type {type(obj)} is not msgpack serializable") + + +def get_free_port() -> int: + """Get a free port on the system.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + return s.getsockname()[1]
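Finally, a hedged end-to-end sketch of the sidecar pattern that `--use-env-server` enables, mirroring the `run_evaluation` change above: spawn `ZMQEnvServer.run_server` in a separate process and drive it with `ZMQEnvClient`. The address, env id, and model name are placeholders, and the startup sleep stands in for a real readiness check.

```python
import asyncio
from multiprocessing import Process

import verifiers as vf
from verifiers.workers import ZMQEnvClient, ZMQEnvServer


async def main() -> None:
    address = "tcp://127.0.0.1:5000"  # run_evaluation uses 5000 + env_idx

    # Spawn the env server as a sidecar, as run_evaluation does with use_env_server.
    server = Process(
        target=ZMQEnvServer.run_server,
        args=("gsm8k", {}),  # env_id and env_args are placeholders
        kwargs=dict(address=address),
    )
    server.start()
    await asyncio.sleep(2.0)  # crude wait for the server to bind and load the env

    client = ZMQEnvClient(address=address)
    assert await client.health(), "ZMQEnvServer is not healthy"

    results = await client.evaluate(
        client_config=vf.ClientConfig(),
        model="openai/gpt-4.1-mini",  # placeholder model name
        sampling_args=vf.SamplingArgs(),
        num_examples=5,
        rollouts_per_example=3,
        max_concurrent=-1,
        results_path=None,
        state_columns=[],
        save_results=False,
        save_every=-1,
        independent_scoring=True,
    )
    print(results)

    # Terminate the sidecar, following the cleanup in run_evaluation.
    server.terminate()
    server.join(timeout=5)


if __name__ == "__main__":
    asyncio.run(main())
```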