Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/buffer/task_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tests.tools import get_template_config, get_unittest_dataset_config
from trinity.buffer.reader import READER
from trinity.buffer.reader.file_reader import TaskFileReader
from trinity.buffer.selector.selector import ShuffleSelector
from trinity.buffer.task_scheduler import TasksetScheduler, get_taskset_scheduler
from trinity.common.config import DataSelectorConfig, FormatConfig, TasksetConfig
from trinity.common.workflows.workflow import Task
Expand Down Expand Up @@ -356,3 +357,17 @@ async def test_task_scheduler_simple(self):
task_scheduler_state = task_scheduler.state_dict()
self.assertEqual(len(task_scheduler_state), 1)
self.assertEqual(task_scheduler_state[0]["current_index"], 12)


class TestShuffleSelector(unittest.TestCase):
    """Regression checks for how ``ShuffleSelector`` orders indices across epochs."""

    def test_reshuffles_between_epochs(self):
        # Minimal stand-in for a data source: the only attribute this test
        # provides is ``dataset_size``.
        class _DataSource:
            dataset_size = 10

        size = _DataSource.dataset_size
        every_index = list(range(size))
        shuffler = ShuffleSelector(_DataSource(), DataSelectorConfig(seed=42))

        # Draw two full epochs back to back, one whole dataset per call.
        first_pass = shuffler.get_indices(size)
        second_pass = shuffler.get_indices(size)

        # Each epoch has to be a permutation of the complete index range ...
        for drawn in (first_pass, second_pass):
            self.assertEqual(sorted(drawn), every_index)

        # ... and the second epoch must not replay the first one's ordering.
        self.assertNotEqual(first_pass, second_pass)
2 changes: 1 addition & 1 deletion trinity/buffer/selector/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def _get_orders(self) -> List[int]:
def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
start = self.current_index % self.dataset_size
end = start + batch_size
self.current_index += batch_size
if end <= self.dataset_size:
ret = self.orders[start:end]
# At end of epoch, reshuffle for next epoch
Expand All @@ -147,7 +148,6 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
# At end of epoch, reshuffle for next epoch
self.orders = self._get_orders()
ret += self.orders[: (end - self.dataset_size)]
self.current_index += batch_size
return ret

def feedback(self, indices: List[int], values: List[float]) -> None:
Expand Down
Loading