Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ data = [
# "asyncpg",
# if you want to use trinity with MySQL, please install aiomysql
# "aiomysql",
# if you want to use harbor dataset, please install harbor
# "harbor>=0.15.0",
]
agent = [
"agentscope[tuner]>=1.0.19,<2.0.0"
Expand Down
112 changes: 112 additions & 0 deletions tests/buffer/task_dir_reader_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import shutil
import unittest
from pathlib import Path

from trinity.buffer.buffer import get_buffer_reader
from trinity.buffer.reader.task_dir_reader import TaskDirReader
from trinity.common.config import TasksetConfig
from trinity.common.constants import StorageType


class TestTaskDirReader(unittest.IsolatedAsyncioTestCase):
temp_dir = Path("tmp/test_task_dir_reader")

@classmethod
def setUpClass(cls):
super().setUpClass()
if cls.temp_dir.exists():
shutil.rmtree(cls.temp_dir)
cls.temp_dir.mkdir(parents=True)
cls._write_task_dir("task-a")
cls._write_task_dir("task-b")
cls._write_task_dir(".ignored")

@classmethod
def tearDownClass(cls):
super().tearDownClass()
shutil.rmtree(cls.temp_dir, ignore_errors=True)

@classmethod
def _write_task_dir(cls, dirname: str) -> None:
task_dir = cls.temp_dir / dirname
task_dir.mkdir()
(task_dir / "payload.txt").write_text(dirname)

def _config(self, *, index: int = 0, total_epochs: int = 1) -> TasksetConfig:
config = TasksetConfig(
name="folder_tasks",
storage_type=StorageType.TASK_DIR.value,
path=str(self.temp_dir),
default_workflow_type="simple_workflow",
batch_size=1,
index=index,
total_epochs=total_epochs,
)
config.data_selector = None
return config

async def test_read_task_dirs(self):
reader = get_buffer_reader(self._config())

self.assertIsInstance(reader, TaskDirReader)
self.assertEqual(len(reader), 2)

tasks = await reader.read(batch_size=2)
self.assertEqual(len(tasks), 2)
self.assertEqual(tasks[0].raw_task["task_name"], "task-a")
self.assertEqual(tasks[0].raw_task["taskset_name"], "folder_tasks")
self.assertEqual(tasks[0].raw_task["source_type"], "task_dir")
self.assertTrue(Path(tasks[0].raw_task["task_dir"]).is_absolute())
self.assertEqual(tasks[0].index["index"], 0)
self.assertEqual(tasks[1].index["index"], 1)

async def test_index_file_limits_and_orders_task_dirs(self):
index_dir = Path("tmp/test_task_dir_reader_index")
if index_dir.exists():
shutil.rmtree(index_dir)
index_dir.mkdir(parents=True)
for dirname in ["task-a", "task-b", "task-c"]:
task_dir = index_dir / dirname
task_dir.mkdir()
(task_dir / "payload.txt").write_text(dirname)
(index_dir / TaskDirReader.INDEX_FILENAME).write_text(
"# one task folder per line\n" "task-b\n" "\n" "task-a\n"
)

try:
config = TasksetConfig(
name="indexed_folder_tasks",
storage_type=StorageType.TASK_DIR.value,
path=str(index_dir),
default_workflow_type="simple_workflow",
batch_size=1,
total_epochs=1,
)
config.data_selector = None
reader = get_buffer_reader(config)

self.assertEqual(len(reader), 2)
tasks = await reader.read(batch_size=2)
self.assertEqual(
[task.raw_task["task_name"] for task in tasks],
["task-b", "task-a"],
)
finally:
shutil.rmtree(index_dir, ignore_errors=True)

async def test_resume_offset(self):
reader = get_buffer_reader(self._config(index=1, total_epochs=2))
tasks = await reader.read(batch_size=2)

self.assertEqual(
[task.raw_task["task_name"] for task in tasks],
[
"task-b",
"task-a",
],
)
self.assertEqual(reader.state_dict()["current_index"], 3)


if __name__ == "__main__":
unittest.main()
65 changes: 65 additions & 0 deletions tests/common/harbor_workflow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import shutil
import unittest
from pathlib import Path

from trinity.common.workflows import HarborWorkflow
from trinity.common.workflows.workflow import Task


class _TestHarborWorkflow(HarborWorkflow):
def run(self):
return []


class TestHarborWorkflow(unittest.TestCase):
temp_dir = Path("tmp/test_harbor_workflow")

def setUp(self):
if self.temp_dir.exists():
shutil.rmtree(self.temp_dir)
self.temp_dir.mkdir(parents=True)

def tearDown(self):
shutil.rmtree(self.temp_dir, ignore_errors=True)

def _write_task(self, name: str = "task-a") -> Path:
task_dir = self.temp_dir / name
task_dir.mkdir()
(task_dir / "task.toml").write_text("[task]\n" f'name = "test-org/{name}"\n')
(task_dir / "instruction.md").write_text("Solve the task.")
return task_dir

def _task(self, task_dir: Path) -> Task:
return Task(
workflow=_TestHarborWorkflow,
raw_task={
"task_id": f"test:{task_dir.name}",
"task_name": task_dir.name,
"task_dir": str(task_dir),
"source_type": "task_dir",
},
)

def test_loads_harbor_task_from_task_dir(self):
task_dir = self._write_task()

workflow = _TestHarborWorkflow(task=self._task(task_dir), model=None)

self.assertEqual(workflow.harbor_task_dir, task_dir.resolve())
self.assertEqual(workflow.harbor_task_name, "task-a")
self.assertEqual(workflow.harbor_task_config.task.name, "test-org/task-a")
self.assertEqual(workflow.harbor_instruction, "Solve the task.")
self.assertTrue(workflow.harbor_task_paths.config_path.exists())
self.assertTrue(workflow.harbor_task_paths_info["has_config"])
self.assertTrue(workflow.harbor_task_paths_info["has_instruction"])

def test_requires_valid_harbor_task_config(self):
task_dir = self.temp_dir / "missing-config"
task_dir.mkdir()

with self.assertRaisesRegex(ValueError, "Failed to load Harbor task config"):
_TestHarborWorkflow(task=self._task(task_dir), model=None)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions trinity/buffer/reader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"file": "trinity.buffer.reader.file_reader.FileReader",
"queue": "trinity.buffer.reader.queue_reader.QueueReader",
"sql": "trinity.buffer.reader.sql_reader.SQLReader",
"task_dir": "trinity.buffer.reader.task_dir_reader.TaskDirReader",
},
)

Expand Down
194 changes: 194 additions & 0 deletions trinity/buffer/reader/task_dir_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Directory-backed taskset reader."""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional, Tuple

from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.schema.formatter import TaskFormatter
from trinity.common.config import StorageConfig


class _TaskDirBatchReader:
def __init__(
self,
task_dirs: list[Path],
name: str,
default_batch_size: int,
total_epochs: int = 1,
offset: int = 0,
drop_last: bool = True,
total_steps: Optional[int] = None,
):
self.task_dirs = task_dirs
self.dataset_size = len(task_dirs)
if self.dataset_size == 0:
raise ValueError(f"Task directory dataset [{name}] is empty and cannot be read.")
self.name = name
self.default_batch_size = default_batch_size
self.drop_last = drop_last
self.current_offset = offset

if total_steps is not None:
self.total_samples = default_batch_size * total_steps
else:
self.total_samples = self.dataset_size * total_epochs

def _sample_for_index(self, index: int) -> dict:
task_dir = self.task_dirs[index]
task_name = task_dir.name
return {
"task_id": f"{self.name}:{index}:{task_name}",
"task_name": task_name,
"task_dir": str(task_dir),
"taskset_name": self.name,
"source_type": "task_dir",
}

def read_batch(self, batch_size: int) -> Tuple[List[dict], List[int]]:
batch, indices = [], []
while len(batch) < batch_size:
if self.current_offset >= self.total_samples:
if not self.drop_last and len(batch) > 0:
break
raise StopIteration

index = self.current_offset % self.dataset_size
batch.append(self._sample_for_index(index))
indices.append(index)
self.current_offset += 1

return batch, indices

def select_batch(self, indices: List[int]) -> List[dict]:
batch = []
for i in indices:
if not 0 <= i < self.dataset_size:
raise IndexError(f"Task directory index {i} out of range.")
if self.current_offset >= self.total_samples:
if not self.drop_last and len(batch) > 0:
break
raise StopIteration
batch.append(self._sample_for_index(int(i)))
self.current_offset += 1
return batch


class TaskDirReader(BufferReader):
"""Read folder-style tasksets as Trinity workflow tasks.

This reader is intentionally format-agnostic. It is useful for datasets where
every task is represented by a directory, such as Harbor-style benchmark
tasks. The workflow owns task parsing; the reader only provides task paths.
"""

INDEX_FILENAME = "index.txt"

def __init__(self, config: StorageConfig):
self.config = config
self.name = config.name
self.read_batch_size = config.batch_size
self.formatter = TaskFormatter(config)
self.dataset = _TaskDirBatchReader(
self._discover_task_dirs(config),
name=config.name,
default_batch_size=self.read_batch_size,
total_epochs=config.total_epochs if not config.is_eval else 1,
offset=config.index,
drop_last=not config.is_eval,
total_steps=config.total_steps if not config.is_eval else None,
)
self._init_selector(config)

def _init_selector(self, config: StorageConfig) -> None:
if config.data_selector is not None:
from trinity.buffer.selector import SELECTORS
from trinity.buffer.selector.selector import BaseSelector

selector_cls = SELECTORS.get(config.data_selector.selector_type)
self.selector: BaseSelector = selector_cls(self.dataset, config.data_selector)
else:
self.selector = None

def _discover_task_dirs(self, config: StorageConfig) -> list[Path]:
if config.path is None:
raise ValueError("TaskDirReader requires `path` to be configured.")

root = Path(config.path).expanduser().resolve()
if not root.exists():
raise FileNotFoundError(f"Task directory dataset path does not exist: {root}")
if not root.is_dir():
raise ValueError(f"Task directory dataset path must be a directory: {root}")

index_file = root / self.INDEX_FILENAME
if index_file.exists():
return self._discover_indexed_task_dirs(root, index_file)

return sorted(
path for path in root.iterdir() if path.is_dir() and not path.name.startswith(".")
)

def _discover_indexed_task_dirs(self, root: Path, index_file: Path) -> list[Path]:
task_dirs = []
for line_number, line in enumerate(index_file.read_text().splitlines(), start=1):
task_path = line.strip()
if not task_path or task_path.startswith("#"):
continue

relative_path = Path(task_path)
if relative_path.is_absolute() or ".." in relative_path.parts:
raise ValueError(
f"Task directory index entry must stay under dataset root: "
f"{index_file}:{line_number}"
)

resolved_path = (root / relative_path).resolve()
if not resolved_path.is_dir():
raise FileNotFoundError(
f"Task directory index entry is not a directory: "
f"{index_file}:{line_number} -> {resolved_path}"
)
task_dirs.append(resolved_path)
return task_dirs

async def read(self, batch_size: Optional[int] = None, **kwargs):
try:
return self._read_sync(batch_size, **kwargs)
except StopIteration as e:
raise StopAsyncIteration from e

def _read_sync(self, batch_size: Optional[int] = None, **kwargs):
batch_size = batch_size or self.read_batch_size
if self.selector is not None:
indices = self.selector.get_indices(batch_size)
samples = self.dataset.select_batch(indices)
else:
samples, indices = self.dataset.read_batch(batch_size)

tasks = []
for sample, index in zip(samples, indices):
task = self.formatter.format(sample)
task.index["index"] = int(index)
task.index["task_name"] = sample["task_name"]
task.index["task_dir"] = sample["task_dir"]
tasks.append(task)
return tasks

def state_dict(self):
if self.selector is not None:
return self.selector.state_dict()
return {"current_index": self.dataset.current_offset}

def load_state_dict(self, state_dict):
if self.selector is not None:
self.selector.load_state_dict(state_dict)
else:
self.dataset.current_offset = state_dict["current_index"]

def feedback(self, **pipeline_metrics):
if self.selector is not None:
self.selector.feedback(**pipeline_metrics)

def __len__(self):
return self.dataset.dataset_size
Loading
Loading