From 95756097ba852db14c3c2642d995b5eb323164c3 Mon Sep 17 00:00:00 2001 From: Nithin Tatikonda Date: Thu, 25 Jun 2026 11:45:14 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 938107452 --- .../_src/python/dataset/transformations/BUILD | 1 + .../transformations/interleave_test.py | 51 +++++++++++++++++++ .../dataset/transformations/prefetch.py | 6 ++- .../python/dataset/transformations/repeat.py | 12 ++++- 4 files changed, 66 insertions(+), 4 deletions(-) diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD index fae3190d2..af6389b9f 100644 --- a/grain/_src/python/dataset/transformations/BUILD +++ b/grain/_src/python/dataset/transformations/BUILD @@ -268,6 +268,7 @@ py_test( srcs = ["interleave_test.py"], srcs_version = "PY3", deps = [ + ":zip", "//grain/_src/python:options", "//grain/_src/python/dataset", "//grain/_src/python/dataset:base", diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py index 218569a98..69ddd8407 100644 --- a/grain/_src/python/dataset/transformations/interleave_test.py +++ b/grain/_src/python/dataset/transformations/interleave_test.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading from typing import cast + from absl.testing import absltest from absl.testing import flagsaver from absl.testing import parameterized @@ -22,6 +24,8 @@ from grain._src.python.dataset import dataset from grain._src.python.dataset.transformations import interleave from grain._src.python.dataset.transformations import prefetch +from grain._src.python.dataset.transformations import repeat +from grain._src.python.dataset.transformations import zip as zip_dataset from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint import numpy as np @@ -515,6 +519,53 @@ def test_setting_shard_state_with_exhausted_states(self): if isinstance(self, InterleaveIterDatasetTest): self.assertEqual(state["exhausted"], [0, 1]) + def test_options_propagated_with_interleaved_interleaves(self): + ds = ( + dataset.MapDataset.range(0, 1500) + .to_iter_dataset() + .filter(lambda x: False) + ) + interleave_ds = self._create_dataset([ds], cycle_length=1) + interleave_ds_2 = self._create_dataset([interleave_ds], cycle_length=1) + + filter_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) + ds_with_options = dataset.WithOptionsIterDataset( + interleave_ds_2, filter_options + ) + with self.assertRaisesRegex(ValueError, r"skipped 100\.00 %"): + list(ds_with_options) + + def test_options_propagated_with_zipped_interleaves(self): + no_filter_ds = dataset.MapDataset.range( + 1200, 1500 + ).to_iter_dataset() # 300 elements + + filter_ds = ( + dataset.MapDataset.range(0, 1500) + .to_iter_dataset() + .filter(lambda x: x >= 1200) + ) + interleave_ds1 = self._create_dataset([filter_ds], cycle_length=1) + interleave_ds2 = self._create_dataset([no_filter_ds], cycle_length=1) + zipped_ds = zip_dataset.ZipIterDataset([interleave_ds1, interleave_ds2]) + zipped_ds2 = zip_dataset.ZipIterDataset([interleave_ds2, interleave_ds1]) + + filter_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) + ds_with_options1 = dataset.WithOptionsIterDataset(zipped_ds, filter_options) + ds_with_options2 = dataset.WithOptionsIterDataset( + zipped_ds2, filter_options + ) + + with self.assertRaisesRegex( + ValueError, r"FilterDatasetIterator.*skipped 100\.00 %" + ): + list(ds_with_options1) + + with self.assertRaisesRegex( + ValueError, r"FilterDatasetIterator.*skipped 100\.00 %" + ): + list(ds_with_options2) + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 76449dd38..8e9415a01 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -555,8 +555,7 @@ def __init__( assert target_prefetch_buffer_size >= 0, target_prefetch_buffer_size self._target_prefetch_buffer_size = target_prefetch_buffer_size self.autotune_buffer_size = autotune_buffer_size - self._step_zero_state: StateT = parent.get_state() - self._state: StateT | None = self._step_zero_state + self._state: StateT | None = None self._next_index: int | None = 0 self._prefetch_thread: threading.Thread | None = None @@ -628,6 +627,9 @@ def start_prefetch(self): ) def __next__(self): + if self._state is None: + self._state = self._maybe_nonnative_parent.get_state() + timer = dataset_stats.Timer() with timer: if self._target_prefetch_buffer_size > 0: diff --git a/grain/_src/python/dataset/transformations/repeat.py b/grain/_src/python/dataset/transformations/repeat.py index e1d2cc550..3e5df9e13 100644 --- a/grain/_src/python/dataset/transformations/repeat.py +++ b/grain/_src/python/dataset/transformations/repeat.py @@ -95,7 +95,9 @@ def __init__( super().__init__(parent) self._num_epochs = num_epochs self._epoch = 0 - self._parent_starting_state = self._parent.get_state() + self._parent_starting_state = None + + def _ensure_keep_alive_flags(self): # Check for ProcessPrefetchDatasetIterator and InterleaveDatasetIterator and # ensure processes/iterators are not reset on StopIteration. This is needed # to avoid recreating the worker processes on each epoch. @@ -107,13 +109,16 @@ def __init__( if isinstance(node, interleave.InterleaveDatasetIterator): node.set_keep_iterators_after_stop_iteration(True) to_visit.extend(n for n in node._iterators_in_use if n is not None) # pylint: disable=protected-access - to_visit.extend(n for n in node._parents) + to_visit.extend(n for n in node._parents) # pylint: disable=protected-access @stats.record_next_duration_if_output def __next__(self): timer = stats.Timer() if self._epoch == self._num_epochs: raise StopIteration + if self._parent_starting_state is None: + self._parent_starting_state = self._parent.get_state() + self._ensure_keep_alive_flags() while True: try: elem = next(self._parent) @@ -131,6 +136,9 @@ def get_state(self): return {"parent": self._parent.get_state(), "epoch": self._epoch} def set_state(self, state): + if self._parent_starting_state is None: + self._parent_starting_state = self._parent.get_state() + self._ensure_keep_alive_flags() self._epoch = state["epoch"] self._parent.set_state(state["parent"])