diff --git a/src/art/model.py b/src/art/model.py
index 182207458..d0388939a 100644
--- a/src/art/model.py
+++ b/src/art/model.py
@@ -25,7 +25,10 @@
from .preprocessing.moe_routing import attach_moe_routing_metadata_to_choice
from .trajectories import Trajectory, TrajectoryGroup
from .types import TrainSFTConfig
-from .utils.trajectory_logging import write_trajectory_groups_parquet
+from .utils.trajectory_logging import (
+ calculate_step_std_dev,
+ write_trajectory_groups_parquet,
+)
if TYPE_CHECKING:
from wandb.sdk.wandb_run import Run
@@ -991,11 +994,6 @@ async def log(
group_key = f"group_{metric}"
averages[group_key] = sum(values) / len(values)
- # Calculate average standard deviation of rewards within groups
- from .utils.old_benchmarking.calculate_step_metrics import (
- calculate_step_std_dev,
- )
-
averages[reward_std_dev_key] = calculate_step_std_dev(trajectory_groups)
# Merge in any additional metrics passed directly
diff --git a/src/art/utils/benchmarking/aggregate_trajectories.py b/src/art/utils/benchmarking/aggregate_trajectories.py
index 1b9dbbf03..7abc69f33 100644
--- a/src/art/utils/benchmarking/aggregate_trajectories.py
+++ b/src/art/utils/benchmarking/aggregate_trajectories.py
@@ -21,8 +21,7 @@ async def load_aggregated_trajectories(
Load trajectories and aggregate metrics at the step level.
This function builds on top of load_trajectories to provide step-level
- aggregation similar to load_benchmarked_models, but returns a DataFrame
- instead of custom objects.
+ aggregation as a DataFrame.
Parameters
----------
diff --git a/src/art/utils/old_benchmarking/calculate_step_metrics.py b/src/art/utils/old_benchmarking/calculate_step_metrics.py
deleted file mode 100644
index be8b07f38..000000000
--- a/src/art/utils/old_benchmarking/calculate_step_metrics.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import numpy as np
-
-from art.trajectories import TrajectoryGroup
-
-
-# calculate the average standard deviation of rewards within groups
-def calculate_step_std_dev(trajectory_groups: list[TrajectoryGroup]) -> float:
- std_devs = []
- for group in trajectory_groups:
- group_rewards = []
-
- for trajectory in group.trajectories:
- if isinstance(trajectory, BaseException):
- continue
- group_rewards.append(trajectory.reward)
-
- if len(group_rewards) > 1:
- std_devs.append(np.std(group_rewards))
-
- if len(std_devs) == 0:
- return 0
-
- return sum(std_devs) / len(std_devs)
diff --git a/src/art/utils/old_benchmarking/display_image_grid.py b/src/art/utils/old_benchmarking/display_image_grid.py
deleted file mode 100644
index 541173a3d..000000000
--- a/src/art/utils/old_benchmarking/display_image_grid.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from IPython.display import HTML, display
-
-
-def display_image_grid(image_paths: list[str], images_per_row: int = 2):
- html = f"""
-
- """
- for path in image_paths:
- html += f"

"
- html += "
"
- display(HTML(html))
diff --git a/src/art/utils/old_benchmarking/generate_comparison_table.py b/src/art/utils/old_benchmarking/generate_comparison_table.py
deleted file mode 100644
index 6eb143fc6..000000000
--- a/src/art/utils/old_benchmarking/generate_comparison_table.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import pandas as pd
-
-from .load_benchmarked_models import load_benchmarked_models
-from .types import BenchmarkedModelKey
-
-
-def generate_comparison_table(
- project: str,
- benchmark_keys: list[BenchmarkedModelKey],
- metrics: list[str] = ["reward"],
- api_path: str = "./.art",
-) -> pd.DataFrame:
- benchmarked_models = load_benchmarked_models(
- project, benchmark_keys, metrics, api_path
- )
-
- rows: list[dict[str, str]] = []
-
- for benchmarked_model in benchmarked_models:
- for step in benchmarked_model.steps:
- row = {
- "Model": benchmarked_model.model_key.model,
- "Split": benchmarked_model.model_key.split,
- "Step": f"{step.index:04d}",
- }
- for metric in metrics:
- row[metric] = str(step.metrics.get(metric, "N/A"))
- rows.append(row)
-
- return pd.DataFrame(rows, columns=pd.Index(["Model", "Split", "Step"] + metrics))
diff --git a/src/art/utils/old_benchmarking/generate_line_graphs.py b/src/art/utils/old_benchmarking/generate_line_graphs.py
deleted file mode 100644
index 182c354b1..000000000
--- a/src/art/utils/old_benchmarking/generate_line_graphs.py
+++ /dev/null
@@ -1,107 +0,0 @@
-from datetime import datetime
-import os
-from typing import Literal
-
-try:
- import matplotlib.pyplot as plt
-except ImportError:
- raise ImportError(
- "Plotting dependencies are not installed. Please install them with: "
- "pip install openpipe-art[plotting]"
- )
-
-from ..output_dirs import get_default_art_path
-from .load_benchmarked_models import load_benchmarked_models
-from .types import BenchmarkedModelKey
-
-
-# returns an array of paths to image files, one for each metric
-def generate_line_graphs(
- project: str,
- line_graph_keys: list[BenchmarkedModelKey],
- comparison_keys: list[BenchmarkedModelKey],
- metrics: list[str] = ["reward"],
- x_axis_metric: Literal["step", "time"] = "step",
- api_path: str = "./.art",
-) -> list[str]:
- benchmarks_dir = f"{get_default_art_path()}/{project}/benchmarks"
- os.makedirs(benchmarks_dir, exist_ok=True)
-
- line_graph_models = load_benchmarked_models(
- project, line_graph_keys, metrics, api_path
- )
-
- if x_axis_metric == "time":
-
- def has_all_recorded(model):
- for step in model.steps:
- if step.recorded_at is None:
- print(
- f"WARNING: Model {model.model_key} is missing a recorded_at time for step {step.index}, removing from line graph models"
- )
- return False
- return True
-
- line_graph_models = [
- model for model in line_graph_models if has_all_recorded(model)
- ]
-
- comparison_models = load_benchmarked_models(
- project, comparison_keys, metrics, api_path
- )
- image_paths = []
-
- for metric in metrics:
- plt.figure() # Create a new figure for each metric
- last_x_global: float | None = None
- for model in line_graph_models:
- if x_axis_metric == "time":
- from matplotlib import dates as mdates
-
- x_values_float = [
- float(mdates.date2num(step.recorded_at or datetime.min))
- for step in model.steps
- ]
- else:
- x_values_float = [float(step.index) for step in model.steps]
-
- values = [step.metrics.get(metric, float("nan")) for step in model.steps]
- label = f"{model.model_key.model} {model.model_key.split}"
- plt.plot(x_values_float, values, label=label)
- if x_values_float:
- last_x_global = x_values_float[-1]
-
- # Add a dot only at the last point
- if x_values_float and values:
- plt.scatter(x_values_float[-1], values[-1], s=10)
-
- for model in comparison_models:
- last_step = model.steps[-1]
- # draw horizontal black dashed line at the last step's value
- plt.axhline(y=last_step.metrics[metric], color="black", linestyle="--")
- plt.text(
- last_x_global if last_x_global is not None else 0.0,
- last_step.metrics[metric],
- f"{model.model_key.model} {model.model_key.split}",
- ha="right",
- va="bottom",
- fontsize=8,
- color="black",
- )
-
- plt.title(metric)
- plt.xlabel(x_axis_metric)
- plt.ylabel(metric)
- plt.legend()
- plt.grid(axis="y", color="lightgray", linestyle="--", linewidth=0.25)
-
- # 2025-04-17_22:09:57.865_reward_line_graph.png
- current_time = datetime.now().strftime("%Y-%m-%d_%H:%M:%S.%f")[:-3]
- metric_graph_path = os.path.join(
- benchmarks_dir, f"{current_time}_{metric}_line_graph.png"
- )
- plt.savefig(metric_graph_path)
- plt.close()
- image_paths.append(metric_graph_path)
-
- return image_paths
diff --git a/src/art/utils/old_benchmarking/load_benchmarked_models.py b/src/art/utils/old_benchmarking/load_benchmarked_models.py
deleted file mode 100644
index 215841305..000000000
--- a/src/art/utils/old_benchmarking/load_benchmarked_models.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import copy
-import json
-import os
-
-from art.utils.old_benchmarking.calculate_step_metrics import calculate_step_std_dev
-from art.utils.old_benchmarking.types import (
- BenchmarkedModel,
- BenchmarkedModelKey,
- BenchmarkedModelStep,
-)
-from art.utils.output_dirs import (
- get_output_dir_from_model_properties,
- get_trajectories_split_dir,
-)
-from art.utils.trajectory_migration import deserialize_trajectory_groups
-
-
-def load_benchmarked_models(
- project: str,
- benchmark_keys: list[BenchmarkedModelKey],
- metrics: list[str] = ["reward"],
- api_path: str = "./.art",
-) -> list[BenchmarkedModel]:
- benchmark_keys_copy = copy.deepcopy(benchmark_keys)
-
- benchmarked_models = []
-
- for benchmark_key in benchmark_keys_copy:
- benchmarked_model = BenchmarkedModel(benchmark_key)
- model_output_dir = get_output_dir_from_model_properties(
- project=project, name=benchmark_key.model, art_path=api_path
- )
- split_dir = get_trajectories_split_dir(model_output_dir, benchmark_key.split)
-
- history_logs = []
-
- with open(os.path.join(model_output_dir, "history.jsonl"), "r") as f:
- for line in f:
- # only include logs that have a recorded_at value
- log = json.loads(line)
- if "recorded_at" in log:
- history_logs.append(log)
-
- # get last file name in split_dir
- max_step_index = -1
-
- try:
- max_step_index = int(os.listdir(split_dir)[-1].split(".")[0])
- except Exception as e:
- print(f"Error getting max iteration index for {benchmark_key}")
- raise e
-
- if benchmark_key.step_indices is None:
- # load all iterations
- benchmark_key.step_indices = list(range(max_step_index + 1))
-
- # allow users to count backward from max_step_index using negative indices
- benchmark_key.step_indices = [
- index - 1 + max_step_index if index < 0 else index
- for index in benchmark_key.step_indices
- ]
-
- for index in benchmark_key.step_indices:
- step = BenchmarkedModelStep(index)
-
- # find the most recent log that has a step value equal to index
- for log in reversed(history_logs):
- if log["step"] == index:
- step.recorded_at = log["recorded_at"]
- break
-
- # Try both .jsonl and .yaml extensions
- jsonl_path = os.path.join(split_dir, f"{index:04d}.jsonl")
- yaml_path = os.path.join(split_dir, f"{index:04d}.yaml")
-
- if os.path.exists(jsonl_path):
- file_path = jsonl_path
- elif os.path.exists(yaml_path):
- file_path = yaml_path
- else:
- raise FileNotFoundError(
- f"No trajectory file found for step {index} in {split_dir}"
- )
-
- with open(file_path, "r") as f:
- trajectory_groups = deserialize_trajectory_groups(f.read())
-
- # add "reward" to trajectory metrics to ensure it is treated like a metric
- for trajectory_group in trajectory_groups:
- for trajectory in trajectory_group.trajectories:
- if "reward" not in trajectory.metrics:
- trajectory.metrics["reward"] = trajectory.reward
-
- for metric in metrics:
- group_averages = []
- for trajectory_group in trajectory_groups:
- trajectories_with_metric = [
- trajectory
- for trajectory in trajectory_group.trajectories
- if metric in trajectory.metrics
- ]
- if len(trajectories_with_metric) == 0:
- continue
- average = sum(
- trajectory.metrics[metric]
- for trajectory in trajectories_with_metric
- ) / len(trajectories_with_metric)
- group_averages.append(average)
- if len(group_averages) == 0:
- continue
- step.metrics[metric] = sum(group_averages) / len(group_averages)
-
- step.metrics["reward_std_dev"] = calculate_step_std_dev(trajectory_groups)
-
- benchmarked_model.steps.append(step)
-
- benchmarked_models.append(benchmarked_model)
-
- return benchmarked_models
diff --git a/src/art/utils/old_benchmarking/types.py b/src/art/utils/old_benchmarking/types.py
deleted file mode 100644
index acbfb0b9a..000000000
--- a/src/art/utils/old_benchmarking/types.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from datetime import datetime
-
-
-class BenchmarkedModelKey:
- model: str
- split: str
- step_indices: list[int] | None = None
-
- def __init__(self, model: str, split: str, step_indices: list[int] | None = None):
- self.model = model
- self.split = split
- self.step_indices = step_indices
-
- def __str__(self):
- steps_str = ""
- if self.step_indices is not None:
- if len(self.step_indices) == 1:
- steps_str = f"{self.step_indices[0]}"
- else:
- steps_str = f"{self.step_indices[0]}-{self.step_indices[-1]}"
- return f"{self.model} {self.split} {steps_str}"
-
-
-class BenchmarkedModelStep:
- index: int
- recorded_at: datetime | None = None
- metrics: dict[str, float] = {}
-
- def __init__(self, index: int, metrics: dict[str, float] | None = None):
- self.index = index
- self.metrics = metrics if metrics is not None else {}
-
- def __str__(self):
- return f"{self.index} {self.metrics}"
-
-
-class BenchmarkedModel:
- model_key: BenchmarkedModelKey
- steps: list[BenchmarkedModelStep] = []
-
- def __init__(
- self,
- model_key: BenchmarkedModelKey,
- steps: list[BenchmarkedModelStep] | None = None,
- ):
- self.model_key = model_key
- self.steps = steps if steps is not None else []
-
- def __str__(self):
- steps_str = "\n".join([str(step) for step in self.steps])
- return f"{self.model_key}\n{steps_str}"
diff --git a/src/art/utils/trajectory_logging.py b/src/art/utils/trajectory_logging.py
index 481c7c8c1..fadc9274c 100644
--- a/src/art/utils/trajectory_logging.py
+++ b/src/art/utils/trajectory_logging.py
@@ -18,6 +18,25 @@
from art.trajectories import Trajectory, TrajectoryGroup
+def calculate_step_std_dev(trajectory_groups: list[TrajectoryGroup]) -> float:
+ """Calculate the average population std-dev of rewards within groups."""
+ std_devs: list[float] = []
+ for group in trajectory_groups:
+ group_rewards = [trajectory.reward for trajectory in group.trajectories]
+
+ if len(group_rewards) > 1:
+ mean_reward = sum(group_rewards) / len(group_rewards)
+ variance = sum(
+ (reward - mean_reward) ** 2 for reward in group_rewards
+ ) / len(group_rewards)
+ std_devs.append(variance**0.5)
+
+ if not std_devs:
+ return 0.0
+
+ return sum(std_devs) / len(std_devs)
+
+
def _flatten_message(msg: dict) -> dict:
"""Convert a message or Choice to flat parquet format."""
if "finish_reason" in msg:
diff --git a/tests/unit/test_frontend_logging.py b/tests/unit/test_frontend_logging.py
index 9e72f82d3..e663e7555 100644
--- a/tests/unit/test_frontend_logging.py
+++ b/tests/unit/test_frontend_logging.py
@@ -393,6 +393,7 @@ async def test_standard_metrics_present(self, tmp_path: Path):
# Check reward average is correct
assert entry["val/reward"] == 0.7 # (0.8 + 0.6) / 2
+ assert entry["val/reward_std_dev"] == pytest.approx(0.1)
@pytest.mark.asyncio
async def test_group_metric_aggregation(self, tmp_path: Path):