From 852770a947d3a8cc980db4e93fd5f7b12a696bc1 Mon Sep 17 00:00:00 2001 From: Mengsu Date: Fri, 26 Jun 2026 13:24:52 +0800 Subject: [PATCH] fix(buffer): reshuffle with correct epoch in ShuffleSelector --- tests/buffer/task_scheduler_test.py | 15 +++++++++++++++ trinity/buffer/selector/selector.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py index ff6e4400e5f..2708ba4bbeb 100644 --- a/tests/buffer/task_scheduler_test.py +++ b/tests/buffer/task_scheduler_test.py @@ -8,6 +8,7 @@ from tests.tools import get_template_config, get_unittest_dataset_config from trinity.buffer.reader import READER from trinity.buffer.reader.file_reader import TaskFileReader +from trinity.buffer.selector.selector import ShuffleSelector from trinity.buffer.task_scheduler import TasksetScheduler, get_taskset_scheduler from trinity.common.config import DataSelectorConfig, FormatConfig, TasksetConfig from trinity.common.workflows.workflow import Task @@ -356,3 +357,17 @@ async def test_task_scheduler_simple(self): task_scheduler_state = task_scheduler.state_dict() self.assertEqual(len(task_scheduler_state), 1) self.assertEqual(task_scheduler_state[0]["current_index"], 12) + + +class TestShuffleSelector(unittest.TestCase): + def test_reshuffles_between_epochs(self): + class _DataSource: + dataset_size = 10 + + selector = ShuffleSelector(_DataSource(), DataSelectorConfig(seed=42)) + epoch0 = selector.get_indices(_DataSource.dataset_size) + epoch1 = selector.get_indices(_DataSource.dataset_size) + self.assertEqual(sorted(epoch0), list(range(_DataSource.dataset_size))) + self.assertEqual(sorted(epoch1), list(range(_DataSource.dataset_size))) + # consecutive epochs must not reuse the same permutation + self.assertNotEqual(epoch0, epoch1) diff --git a/trinity/buffer/selector/selector.py b/trinity/buffer/selector/selector.py index 316a87a6d84..d133cad93b0 100644 --- a/trinity/buffer/selector/selector.py +++ b/trinity/buffer/selector/selector.py @@ -137,6 +137,7 @@ def _get_orders(self) -> List[int]: def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]: start = self.current_index % self.dataset_size end = start + batch_size + self.current_index += batch_size if end <= self.dataset_size: ret = self.orders[start:end] # At end of epoch, reshuffle for next epoch @@ -147,7 +148,6 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[ # At end of epoch, reshuffle for next epoch self.orders = self._get_orders() ret += self.orders[: (end - self.dataset_size)] - self.current_index += batch_size return ret def feedback(self, indices: List[int], values: List[float]) -> None: