Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ py_test(
srcs = ["interleave_test.py"],
srcs_version = "PY3",
deps = [
":zip",
"//grain/_src/python:options",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
Expand Down
51 changes: 51 additions & 0 deletions grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import cast

from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
Expand All @@ -22,6 +24,8 @@
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave
from grain._src.python.dataset.transformations import prefetch
from grain._src.python.dataset.transformations import repeat
from grain._src.python.dataset.transformations import zip as zip_dataset
from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint
import numpy as np

Expand Down Expand Up @@ -515,6 +519,53 @@ def test_setting_shard_state_with_exhausted_states(self):
if isinstance(self, InterleaveIterDatasetTest):
self.assertEqual(state["exhausted"], [0, 1])

def test_options_propagated_with_interleaved_interleaves(self):
ds = (
dataset.MapDataset.range(0, 1500)
.to_iter_dataset()
.filter(lambda x: False)
)
interleave_ds = self._create_dataset([ds], cycle_length=1)
interleave_ds_2 = self._create_dataset([interleave_ds], cycle_length=1)

filter_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
ds_with_options = dataset.WithOptionsIterDataset(
interleave_ds_2, filter_options
)
with self.assertRaisesRegex(ValueError, r"skipped 100\.00 %"):
list(ds_with_options)

def test_options_propagated_with_zipped_interleaves(self):
no_filter_ds = dataset.MapDataset.range(
1200, 1500
).to_iter_dataset() # 300 elements

filter_ds = (
dataset.MapDataset.range(0, 1500)
.to_iter_dataset()
.filter(lambda x: x >= 1200)
)
interleave_ds1 = self._create_dataset([filter_ds], cycle_length=1)
interleave_ds2 = self._create_dataset([no_filter_ds], cycle_length=1)
zipped_ds = zip_dataset.ZipIterDataset([interleave_ds1, interleave_ds2])
zipped_ds2 = zip_dataset.ZipIterDataset([interleave_ds2, interleave_ds1])

filter_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
ds_with_options1 = dataset.WithOptionsIterDataset(zipped_ds, filter_options)
ds_with_options2 = dataset.WithOptionsIterDataset(
zipped_ds2, filter_options
)

with self.assertRaisesRegex(
ValueError, r"FilterDatasetIterator.*skipped 100\.00 %"
):
list(ds_with_options1)

with self.assertRaisesRegex(
ValueError, r"FilterDatasetIterator.*skipped 100\.00 %"
):
list(ds_with_options2)


if __name__ == "__main__":
absltest.main()
6 changes: 4 additions & 2 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,7 @@ def __init__(
assert target_prefetch_buffer_size >= 0, target_prefetch_buffer_size
self._target_prefetch_buffer_size = target_prefetch_buffer_size
self.autotune_buffer_size = autotune_buffer_size
self._step_zero_state: StateT = parent.get_state()
self._state: StateT | None = self._step_zero_state
self._state: StateT | None = None
self._next_index: int | None = 0

self._prefetch_thread: threading.Thread | None = None
Expand Down Expand Up @@ -628,6 +627,9 @@ def start_prefetch(self):
)
def __next__(self):

if self._state is None:
self._state = self._maybe_nonnative_parent.get_state()

timer = dataset_stats.Timer()
with timer:
if self._target_prefetch_buffer_size > 0:
Expand Down
12 changes: 10 additions & 2 deletions grain/_src/python/dataset/transformations/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def __init__(
super().__init__(parent)
self._num_epochs = num_epochs
self._epoch = 0
self._parent_starting_state = self._parent.get_state()
self._parent_starting_state = None

def _ensure_keep_alive_flags(self):
# Check for ProcessPrefetchDatasetIterator and InterleaveDatasetIterator and
# ensure processes/iterators are not reset on StopIteration. This is needed
# to avoid recreating the worker processes on each epoch.
Expand All @@ -107,13 +109,16 @@ def __init__(
if isinstance(node, interleave.InterleaveDatasetIterator):
node.set_keep_iterators_after_stop_iteration(True)
to_visit.extend(n for n in node._iterators_in_use if n is not None) # pylint: disable=protected-access
to_visit.extend(n for n in node._parents)
to_visit.extend(n for n in node._parents) # pylint: disable=protected-access

@stats.record_next_duration_if_output
def __next__(self):
timer = stats.Timer()
if self._epoch == self._num_epochs:
raise StopIteration
if self._parent_starting_state is None:
self._parent_starting_state = self._parent.get_state()
self._ensure_keep_alive_flags()
while True:
try:
elem = next(self._parent)
Expand All @@ -131,6 +136,9 @@ def get_state(self):
return {"parent": self._parent.get_state(), "epoch": self._epoch}

def set_state(self, state):
if self._parent_starting_state is None:
self._parent_starting_state = self._parent.get_state()
self._ensure_keep_alive_flags()
self._epoch = state["epoch"]
self._parent.set_state(state["parent"])

Expand Down
Loading