class TestTaskDirReader(unittest.IsolatedAsyncioTestCase):
    """Exercise ``TaskDirReader`` against throw-away directory tasksets."""

    temp_dir = Path("tmp/test_task_dir_reader")

    @classmethod
    def setUpClass(cls):
        """Recreate the shared fixture root with two visible task folders and one dot-folder."""
        super().setUpClass()
        fixture_root = cls.temp_dir
        if fixture_root.exists():
            shutil.rmtree(fixture_root)
        fixture_root.mkdir(parents=True)
        for folder in ("task-a", "task-b", ".ignored"):
            cls._write_task_dir(folder)

    @classmethod
    def tearDownClass(cls):
        """Remove the shared fixture root, tolerating a directory that is already gone."""
        super().tearDownClass()
        shutil.rmtree(cls.temp_dir, ignore_errors=True)

    @classmethod
    def _write_task_dir(cls, dirname: str) -> None:
        """Create ``temp_dir/dirname`` holding a ``payload.txt`` whose content is the folder name."""
        folder = cls.temp_dir / dirname
        folder.mkdir()
        payload = folder / "payload.txt"
        payload.write_text(dirname)

    def _config(self, *, index: int = 0, total_epochs: int = 1) -> TasksetConfig:
        """Build a selector-free taskset config that points at the shared fixture root."""
        settings = {
            "name": "folder_tasks",
            "storage_type": StorageType.TASK_DIR.value,
            "path": str(self.temp_dir),
            "default_workflow_type": "simple_workflow",
            "batch_size": 1,
            "index": index,
            "total_epochs": total_epochs,
        }
        taskset_config = TasksetConfig(**settings)
        # The tests drive the sequential read path, so the selector is switched off.
        taskset_config.data_selector = None
        return taskset_config

    async def test_read_task_dirs(self):
        """The dot-folder is skipped and the remaining tasks come back in sorted order."""
        reader = get_buffer_reader(self._config())

        self.assertIsInstance(reader, TaskDirReader)
        self.assertEqual(len(reader), 2)

        tasks = await reader.read(batch_size=2)
        self.assertEqual(len(tasks), 2)
        first, second = tasks

        expected_fields = (
            ("task_name", "task-a"),
            ("taskset_name", "folder_tasks"),
            ("source_type", "task_dir"),
        )
        for key, expected in expected_fields:
            self.assertEqual(first.raw_task[key], expected)
        self.assertTrue(Path(first.raw_task["task_dir"]).is_absolute())
        self.assertEqual(first.index["index"], 0)
        self.assertEqual(second.index["index"], 1)

    async def test_index_file_limits_and_orders_task_dirs(self):
        """An index file restricts the taskset to its entries and fixes their order."""
        index_dir = Path("tmp/test_task_dir_reader_index")
        if index_dir.exists():
            shutil.rmtree(index_dir)
        index_dir.mkdir(parents=True)
        for dirname in ("task-a", "task-b", "task-c"):
            folder = index_dir / dirname
            folder.mkdir()
            (folder / "payload.txt").write_text(dirname)
        # The index contains a comment line, a blank line, and only two of the three folders.
        index_text = (
            "# one task folder per line\n"
            "task-b\n"
            "\n"
            "task-a\n"
        )
        (index_dir / TaskDirReader.INDEX_FILENAME).write_text(index_text)

        try:
            taskset_config = TasksetConfig(
                name="indexed_folder_tasks",
                storage_type=StorageType.TASK_DIR.value,
                path=str(index_dir),
                default_workflow_type="simple_workflow",
                batch_size=1,
                total_epochs=1,
            )
            taskset_config.data_selector = None
            reader = get_buffer_reader(taskset_config)

            self.assertEqual(len(reader), 2)
            tasks = await reader.read(batch_size=2)
            observed_names = [task.raw_task["task_name"] for task in tasks]
            self.assertEqual(observed_names, ["task-b", "task-a"])
        finally:
            shutil.rmtree(index_dir, ignore_errors=True)

    async def test_resume_offset(self):
        """Starting at offset 1 with two epochs wraps from ``task-b`` back around to ``task-a``."""
        resumed_config = self._config(index=1, total_epochs=2)
        reader = get_buffer_reader(resumed_config)
        tasks = await reader.read(batch_size=2)

        observed_names = [task.raw_task["task_name"] for task in tasks]
        self.assertEqual(observed_names, ["task-b", "task-a"])
        self.assertEqual(reader.state_dict()["current_index"], 3)
class _TestHarborWorkflow(HarborWorkflow):
    """Minimal concrete subclass so the base-class initializer can be exercised."""

    def run(self):
        """Return no experiences; the tests only inspect what ``__init__`` loaded."""
        return []


class TestHarborWorkflow(unittest.TestCase):
    """Check that ``HarborWorkflow`` loads a task directory during construction."""

    temp_dir = Path("tmp/test_harbor_workflow")

    def setUp(self):
        """Start every test from an empty scratch directory."""
        scratch = self.temp_dir
        if scratch.exists():
            shutil.rmtree(scratch)
        scratch.mkdir(parents=True)

    def tearDown(self):
        """Remove the scratch directory, ignoring errors."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _write_task(self, name: str = "task-a") -> Path:
        """Create a task folder containing ``task.toml`` and ``instruction.md``; return its path."""
        folder = self.temp_dir / name
        folder.mkdir()
        toml_text = "[task]\n" f'name = "test-org/{name}"\n'
        (folder / "task.toml").write_text(toml_text)
        (folder / "instruction.md").write_text("Solve the task.")
        return folder

    def _task(self, task_dir: Path) -> Task:
        """Wrap ``task_dir`` in a ``Task`` whose ``raw_task`` carries task-dir style fields."""
        folder_name = task_dir.name
        raw_task = {
            "task_id": f"test:{folder_name}",
            "task_name": folder_name,
            "task_dir": str(task_dir),
            "source_type": "task_dir",
        }
        return Task(workflow=_TestHarborWorkflow, raw_task=raw_task)

    def test_loads_harbor_task_from_task_dir(self):
        """A well-formed task folder populates every ``harbor_*`` attribute."""
        folder = self._write_task()

        workflow = _TestHarborWorkflow(task=self._task(folder), model=None)

        self.assertEqual(workflow.harbor_task_dir, folder.resolve())
        self.assertEqual(workflow.harbor_task_name, "task-a")
        self.assertEqual(workflow.harbor_task_config.task.name, "test-org/task-a")
        self.assertEqual(workflow.harbor_instruction, "Solve the task.")
        self.assertTrue(workflow.harbor_task_paths.config_path.exists())
        paths_info = workflow.harbor_task_paths_info
        self.assertTrue(paths_info["has_config"])
        self.assertTrue(paths_info["has_instruction"])

    def test_requires_valid_harbor_task_config(self):
        """A folder without a task config is rejected with ``ValueError``."""
        folder = self.temp_dir / "missing-config"
        folder.mkdir()
        broken_task = self._task(folder)

        with self.assertRaisesRegex(ValueError, "Failed to load Harbor task config"):
            _TestHarborWorkflow(task=broken_task, model=None)
class _TaskDirBatchReader:
    """Cyclic, offset-tracking batch source over a fixed list of task directories.

    The reader walks ``task_dirs`` in order, wrapping around with a modulo, until
    ``total_samples`` samples have been handed out. ``total_samples`` is
    ``default_batch_size * total_steps`` when ``total_steps`` is given, otherwise
    ``len(task_dirs) * total_epochs``.
    """

    def __init__(
        self,
        task_dirs: list[Path],
        name: str,
        default_batch_size: int,
        total_epochs: int = 1,
        offset: int = 0,
        drop_last: bool = True,
        total_steps: Optional[int] = None,
    ):
        """Store the directory list and derive the total sample budget.

        Raises:
            ValueError: If ``task_dirs`` is empty.
        """
        size = len(task_dirs)
        if size == 0:
            raise ValueError(f"Task directory dataset [{name}] is empty and cannot be read.")

        self.task_dirs = task_dirs
        self.dataset_size = size
        self.name = name
        self.default_batch_size = default_batch_size
        self.drop_last = drop_last
        self.current_offset = offset
        # An explicit step budget takes precedence over the epoch count.
        self.total_samples = (
            size * total_epochs if total_steps is None else default_batch_size * total_steps
        )

    def _sample_for_index(self, index: int) -> dict:
        """Build the raw sample dict describing the task directory at ``index``."""
        folder = self.task_dirs[index]
        folder_name = folder.name
        sample = {
            "task_id": f"{self.name}:{index}:{folder_name}",
            "task_name": folder_name,
            "task_dir": str(folder),
            "taskset_name": self.name,
            "source_type": "task_dir",
        }
        return sample

    def read_batch(self, batch_size: int) -> Tuple[List[dict], List[int]]:
        """Return the next ``batch_size`` samples and their dataset indices.

        Raises:
            StopIteration: When the sample budget is exhausted and either nothing
                was collected or ``drop_last`` is set (the partial batch is dropped).
        """
        samples, positions = [], []
        while len(samples) < batch_size:
            if self.current_offset >= self.total_samples:
                # Budget exhausted: a partial batch is returned only when drop_last is off.
                if self.drop_last or len(samples) == 0:
                    raise StopIteration
                break

            position = self.current_offset % self.dataset_size
            samples.append(self._sample_for_index(position))
            positions.append(position)
            self.current_offset += 1

        return samples, positions

    def select_batch(self, indices: List[int]) -> List[dict]:
        """Return the samples at the given dataset ``indices``, advancing the offset per sample.

        Raises:
            IndexError: If an index lies outside ``[0, dataset_size)``.
            StopIteration: Under the same exhaustion rule as ``read_batch``.
        """
        chosen = []
        for requested in indices:
            if not (0 <= requested < self.dataset_size):
                raise IndexError(f"Task directory index {requested} out of range.")
            if self.current_offset >= self.total_samples:
                if self.drop_last or len(chosen) == 0:
                    raise StopIteration
                break
            chosen.append(self._sample_for_index(int(requested)))
            self.current_offset += 1
        return chosen
class TaskDirReader(BufferReader):
    """Read folder-style tasksets as Trinity workflow tasks.

    This reader is intentionally format-agnostic. It is useful for datasets where
    every task is represented by a directory, such as Harbor-style benchmark
    tasks. The workflow owns task parsing; the reader only provides task paths.

    Task discovery works in one of two modes:

    * If the dataset root contains a regular file named ``INDEX_FILENAME``, only
      the directories listed in it are used, in the order they are listed.
    * Otherwise every direct sub-directory whose name does not start with ``.``
      is used, in sorted path order.
    """

    INDEX_FILENAME = "index.txt"

    def __init__(self, config: StorageConfig):
        """Discover the task directories and set up sequential or selector-based reading.

        Args:
            config: Storage configuration. The reader uses ``name``, ``batch_size``,
                ``path``, ``index`` (start offset), ``total_epochs``,
                ``total_steps``, ``is_eval`` and ``data_selector``.
        """
        self.config = config
        self.name = config.name
        self.read_batch_size = config.batch_size
        self.formatter = TaskFormatter(config)
        # Evaluation reads exactly one epoch, ignores any step budget, and keeps
        # the final partial batch instead of dropping it.
        self.dataset = _TaskDirBatchReader(
            self._discover_task_dirs(config),
            name=config.name,
            default_batch_size=self.read_batch_size,
            total_epochs=config.total_epochs if not config.is_eval else 1,
            offset=config.index,
            drop_last=not config.is_eval,
            total_steps=config.total_steps if not config.is_eval else None,
        )
        self._init_selector(config)

    def _init_selector(self, config: StorageConfig) -> None:
        """Instantiate the configured data selector, or set ``self.selector`` to ``None``."""
        if config.data_selector is not None:
            # Imported lazily so the selector package is only loaded when one is configured.
            from trinity.buffer.selector import SELECTORS
            from trinity.buffer.selector.selector import BaseSelector

            selector_cls = SELECTORS.get(config.data_selector.selector_type)
            self.selector: BaseSelector = selector_cls(self.dataset, config.data_selector)
        else:
            self.selector = None

    def _discover_task_dirs(self, config: StorageConfig) -> list[Path]:
        """Return the task directories found under ``config.path``.

        Raises:
            ValueError: If ``config.path`` is unset or is not a directory.
            FileNotFoundError: If ``config.path`` does not exist.
        """
        if config.path is None:
            raise ValueError("TaskDirReader requires `path` to be configured.")

        root = Path(config.path).expanduser().resolve()
        if not root.exists():
            raise FileNotFoundError(f"Task directory dataset path does not exist: {root}")
        if not root.is_dir():
            raise ValueError(f"Task directory dataset path must be a directory: {root}")

        index_file = root / self.INDEX_FILENAME
        # Only a regular file counts as an index. A task *directory* that happens
        # to be named like the index file must not be opened as text; it is
        # treated as an ordinary task folder by the listing below.
        if index_file.is_file():
            return self._discover_indexed_task_dirs(root, index_file)

        return sorted(
            path for path in root.iterdir() if path.is_dir() and not path.name.startswith(".")
        )

    def _discover_indexed_task_dirs(self, root: Path, index_file: Path) -> list[Path]:
        """Resolve the task directories listed in ``index_file``, preserving file order.

        Blank lines and lines starting with ``#`` are skipped. Every other line is
        taken as a path relative to ``root``.

        Raises:
            ValueError: If an entry is absolute or contains a ``..`` component.
            FileNotFoundError: If an entry does not resolve to a directory.
        """
        task_dirs = []
        # Decode explicitly as UTF-8 so index files with non-ASCII task names
        # load identically regardless of the platform's locale encoding.
        index_lines = index_file.read_text(encoding="utf-8").splitlines()
        for line_number, line in enumerate(index_lines, start=1):
            task_path = line.strip()
            if not task_path or task_path.startswith("#"):
                continue

            relative_path = Path(task_path)
            if relative_path.is_absolute() or ".." in relative_path.parts:
                raise ValueError(
                    f"Task directory index entry must stay under dataset root: "
                    f"{index_file}:{line_number}"
                )

            resolved_path = (root / relative_path).resolve()
            if not resolved_path.is_dir():
                raise FileNotFoundError(
                    f"Task directory index entry is not a directory: "
                    f"{index_file}:{line_number} -> {resolved_path}"
                )
            task_dirs.append(resolved_path)
        return task_dirs

    async def read(self, batch_size: Optional[int] = None, **kwargs):
        """Return the next batch of tasks.

        Raises:
            StopAsyncIteration: When the underlying reader raises ``StopIteration``.
        """
        try:
            return self._read_sync(batch_size, **kwargs)
        except StopIteration as e:
            # Re-raised as the async variant, chained to the original exception.
            raise StopAsyncIteration from e

    def _read_sync(self, batch_size: Optional[int] = None, **kwargs):
        """Read one batch synchronously and format each sample into a task.

        A falsy ``batch_size`` (``None`` or ``0``) falls back to the configured
        batch size. Extra keyword arguments are accepted but not used here.
        """
        batch_size = batch_size or self.read_batch_size
        if self.selector is not None:
            indices = self.selector.get_indices(batch_size)
            samples = self.dataset.select_batch(indices)
        else:
            samples, indices = self.dataset.read_batch(batch_size)

        tasks = []
        # zip() stops at the shorter sequence, so a truncated selector batch
        # simply yields fewer tasks.
        for sample, index in zip(samples, indices):
            task = self.formatter.format(sample)
            task.index["index"] = int(index)
            task.index["task_name"] = sample["task_name"]
            task.index["task_dir"] = sample["task_dir"]
            tasks.append(task)
        return tasks

    def state_dict(self):
        """Return the selector's state if one is configured, else the current read offset."""
        if self.selector is not None:
            return self.selector.state_dict()
        return {"current_index": self.dataset.current_offset}

    def load_state_dict(self, state_dict):
        """Restore state previously produced by ``state_dict``."""
        if self.selector is not None:
            self.selector.load_state_dict(state_dict)
        else:
            self.dataset.current_offset = state_dict["current_index"]

    def feedback(self, **pipeline_metrics):
        """Forward pipeline metrics to the selector; a no-op without a selector."""
        if self.selector is not None:
            self.selector.feedback(**pipeline_metrics)

    def __len__(self):
        """Return the number of discovered task directories."""
        return self.dataset.dataset_size
config.buffer.batch_size taskset_configs = config.buffer.explorer_input.tasksets - from trinity.buffer.reader.file_reader import FileReader + from trinity.buffer.buffer_reader import BufferReader taskset_states = explorer_state.get( "taskset_states", [{"current_index": 0}] * len(taskset_configs) @@ -132,10 +132,10 @@ def __init__(self, explorer_state: Dict, config: Config): for taskset_config, taskset_state in zip(taskset_configs, taskset_states): assert not taskset_config.is_eval # assume drop last taskset = get_buffer_reader(taskset_config) - if not isinstance(taskset, FileReader): + if not isinstance(taskset, BufferReader): raise TypeError( f"Taskset '{taskset_config.name}' has an unsupported type '{type(taskset).__name__}'." - f"Currently, only 'FileReader' is supported by TasksetScheduler." + "TasksetScheduler requires a BufferReader implementation." ) taskset.load_state_dict(taskset_state) # Restore any prior state self.tasksets.append(taskset) diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 1c85e0c5f16..3e52b75be42 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -67,6 +67,7 @@ class StorageType(CaseInsensitiveEnum): SQL = "sql" QUEUE = "queue" FILE = "file" + TASK_DIR = "task_dir" class SyncMethodEnumMeta(CaseInsensitiveEnumMeta): diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 7627cdca296..42703167865 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Workflow module""" +from trinity.common.workflows.harbor_workflow import HarborWorkflow from trinity.common.workflows.workflow import Task, Workflow from trinity.utils.registry import Registry @@ -56,6 +57,7 @@ ) __all__ = [ + "HarborWorkflow", "Task", "Workflow", "WORKFLOWS", diff --git a/trinity/common/workflows/harbor_workflow.py b/trinity/common/workflows/harbor_workflow.py new file mode 100644 index 
class HarborWorkflow(Workflow):
    """Base workflow that loads a Harbor task directory during initialization.

    This class does not implement Harbor execution. It only bridges Trinity
    folder-style tasks to Harbor's own task parser, so concrete subclasses can
    focus on rollout, verification, and experience construction.
    """

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[ModelWrapper]] = None,
    ):
        """Initialize the base workflow, then resolve and load the Harbor task.

        After construction the instance exposes ``harbor_task_dir``,
        ``harbor_task_name``, ``harbor_scanner``, ``harbor_task_paths``,
        ``harbor_task_config``, ``harbor_instruction`` and
        ``harbor_task_paths_info``.
        """
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)

        resolved_dir = self._get_harbor_task_dir(task)
        self.harbor_task_dir = resolved_dir
        self.harbor_task_name = resolved_dir.name

        loaded = self._load_harbor_task(resolved_dir)
        (
            self.harbor_scanner,
            self.harbor_task_paths,
            self.harbor_task_config,
            self.harbor_instruction,
            self.harbor_task_paths_info,
        ) = loaded

    def _get_harbor_task_dir(self, task: Task) -> Path:
        """Return the resolved, absolute task directory named by ``task.raw_task['task_dir']``.

        Raises:
            ValueError: If ``raw_task`` or its ``task_dir`` entry is missing, or the
                path is not a directory.
            FileNotFoundError: If the path does not exist.
        """
        raw_task = task.raw_task
        if raw_task is None:
            raise ValueError("HarborWorkflow requires `task.raw_task` to be configured.")

        configured_dir = raw_task.get("task_dir")
        if configured_dir is None:
            raise ValueError("HarborWorkflow requires `task.raw_task['task_dir']`.")

        resolved = Path(configured_dir).expanduser().resolve()
        if not resolved.exists():
            raise FileNotFoundError(f"Harbor task directory does not exist: {resolved}")
        if not resolved.is_dir():
            raise ValueError(f"Harbor task path must be a directory: {resolved}")
        return resolved

    def _load_harbor_task(
        self,
        task_dir: Path,
    ) -> tuple["TaskDefinitionScanner", "TaskPaths", "TaskConfig", str | None, dict[str, bool]]:
        """Load ``task_dir`` through Harbor's scanner.

        Returns:
            ``(scanner, paths, config, instruction, paths_info)``, where the scanner
            is built on ``task_dir.parent`` and queried with the folder name.

        Raises:
            ImportError: If the ``harbor`` package cannot be imported.
            ValueError: If the scanner returns no config for the task.
        """
        # harbor is an optional dependency, so it is imported only when a task is loaded.
        try:
            from harbor.models.task.paths import TaskPaths
            from harbor.viewer.task_scanner import TaskDefinitionScanner
        except ImportError as exc:
            raise ImportError(
                "HarborWorkflow requires the `harbor` package to be installed."
            ) from exc

        folder_name = task_dir.name
        scanner = TaskDefinitionScanner(task_dir.parent)

        task_config = scanner.get_task_config(folder_name)
        if task_config is None:
            raise ValueError(
                f"Failed to load Harbor task config from: {task_dir / TaskPaths.CONFIG_FILENAME}"
            )

        task_paths = TaskPaths(task_dir)
        instruction_text = scanner.get_instruction(folder_name)
        paths_summary = scanner.get_task_paths_info(folder_name)
        return scanner, task_paths, task_config, instruction_text, paths_summary