Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/art/langgraph/llm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@
def add_thread(thread_id, base_url, api_key, model):
log_path = f".art/langgraph/{thread_id}"
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logger = FileLogger(log_path)
CURRENT_CONFIG.set(
{
"logger": FileLogger(log_path),
"logger": logger,
"base_url": base_url,
"api_key": api_key,
"model": model,
}
)
return log_path
return logger


def create_messages_from_logs(log_path: str, trajectory: Trajectory):
logs = FileLogger(log_path).load_logs()
def create_messages_from_logs(logger: FileLogger, trajectory: Trajectory):
logs = logger.load_logs()
conversations = []
tools = []

Expand Down Expand Up @@ -95,14 +96,14 @@ def create_messages_from_logs(log_path: str, trajectory: Trajectory):
def wrap_rollout(model, fn):
async def wrapper(*args, **kwargs):
thread_id = str(uuid.uuid4())
log_path = add_thread(
logger = add_thread(
thread_id,
model.inference_base_url,
model.inference_api_key,
model.inference_model_name,
)
result = await fn(*args, **kwargs)
return create_messages_from_logs(log_path, result)
return create_messages_from_logs(logger, result)

return wrapper

Expand Down
22 changes: 5 additions & 17 deletions src/art/langgraph/logging.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,18 @@
import os
import pickle
from typing import Any


class FileLogger:
def __init__(self, filepath):
self.text_path = filepath
self.pickle_path = filepath + ".pkl"
self._logs: list[tuple[str, Any]] = []

def log(self, name, entry):
# Log as readable text
with open(self.text_path, "a") as f:
f.write(f"{name}: {entry}\n")

# Append to pickle log
with open(self.pickle_path, "ab") as pf:
pickle.dump((name, entry), pf)
self._logs.append((name, entry))

def load_logs(self):
"""Load all logs from the pickle file."""
if not os.path.exists(self.pickle_path):
return []
logs = []
with open(self.pickle_path, "rb") as pf:
try:
while True:
logs.append(pickle.load(pf))
except EOFError:
pass
return logs
"""Load all structured logs captured by this logger."""
return list(self._logs)
55 changes: 55 additions & 0 deletions tests/unit/test_langgraph_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from pathlib import Path

import pytest

pytest.importorskip("langchain_openai")
from langchain_core.messages import AIMessage, HumanMessage # noqa: E402

from art import Trajectory # noqa: E402
from art.langgraph.llm_wrapper import create_messages_from_logs # noqa: E402
from art.langgraph.logging import FileLogger


class NonSerializable:
pass


def test_file_logger_keeps_structured_logs_in_memory(tmp_path: Path):
log_path = tmp_path / "rollout"
logger = FileLogger(str(log_path))
entry = {"input": NonSerializable(), "output": NonSerializable()}

logger.log("completion-id", entry)

assert logger.load_logs() == [("completion-id", entry)]
assert not log_path.with_suffix(".pkl").exists()
assert log_path.read_text().startswith("completion-id: ")


def test_file_logger_load_logs_returns_copy(tmp_path: Path):
logger = FileLogger(str(tmp_path / "rollout"))
logger.log("completion-id", {"output": "ok"})

logs = logger.load_logs()
logs.append(("other-id", {"output": "mutated"}))

assert logger.load_logs() == [("completion-id", {"output": "ok"})]


def test_create_messages_from_logs_reads_in_memory_entries(tmp_path: Path):
logger = FileLogger(str(tmp_path / "rollout"))
logger.log(
"completion-id",
{
"input": [HumanMessage(content="hello")],
"output": AIMessage(content="hi"),
"tools": None,
},
)

trajectory = create_messages_from_logs(logger, Trajectory())

assert trajectory.messages() == [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
]
Loading