From 3d96f2b3cda5cd8c3f6ac152ab44d3bcefa441f9 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Wed, 1 Jul 2026 16:20:01 -0700 Subject: [PATCH] Automated Code Change PiperOrigin-RevId: 941350765 --- envlogger/backends/backend_reader.py | 6 +++--- envlogger/backends/riegeli_backend_reader.py | 2 +- envlogger/backends/rlds_utils.py | 2 +- envlogger/backends/tfds_backend_testlib.py | 2 +- envlogger/backends/tfds_backend_writer.py | 2 +- envlogger/converters/codec.py | 8 ++++---- envlogger/reader.py | 2 +- envlogger/testing/catch_env.py | 2 +- 8 files changed, 13 insertions(+), 13 deletions(-) diff --git a/envlogger/backends/backend_reader.py b/envlogger/backends/backend_reader.py index 91f6b5f..be51554 100644 --- a/envlogger/backends/backend_reader.py +++ b/envlogger/backends/backend_reader.py @@ -42,7 +42,7 @@ def __init__(self, count: int, get_nth_item: Callable[[int], T]): self._index = 0 self._get_nth_item = get_nth_item - def __getitem__(self, index: Union[int, slice]) -> Union[T, list[T]]: + def __getitem__(self, index: Union[int, slice]) -> Union[T, list[T]]: # pyrefly: ignore[bad-override] """Retrieves items from this sequence. Args: @@ -71,7 +71,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[T]: while self._index < len(self): - yield self[self._index] + yield self[self._index] # pyrefly: ignore[invalid-yield] self._index += 1 self._index = 0 @@ -79,7 +79,7 @@ def __next__(self) -> T: if self._index < len(self): index = self._index self._index += 1 - return self[index] + return self[index] # pyrefly: ignore[bad-return] else: raise StopIteration() diff --git a/envlogger/backends/riegeli_backend_reader.py b/envlogger/backends/riegeli_backend_reader.py index 4601938..2feb377 100644 --- a/envlogger/backends/riegeli_backend_reader.py +++ b/envlogger/backends/riegeli_backend_reader.py @@ -96,7 +96,7 @@ def _get_nth_step(self, i: int) -> step_data.StepData: """Returns the timestep given by offset `i` (0-based).""" serialized_data = self._reader.serialized_step(i) data = storage_pb2.Data.FromString(serialized_data) - return self._decode_step_data(codec.decode(data)) + return self._decode_step_data(codec.decode(data)) # pyrefly: ignore[bad-argument-type] def _get_nth_episode_info(self, i: int, diff --git a/envlogger/backends/rlds_utils.py b/envlogger/backends/rlds_utils.py index 343f665..709fa7c 100644 --- a/envlogger/backends/rlds_utils.py +++ b/envlogger/backends/rlds_utils.py @@ -104,7 +104,7 @@ def maybe_recover_last_shard(builder: tfds.core.DatasetBuilder): continue logging.info('Recovering data for shard %s.', extra_shard) splits_to_update += 1 - ds = tf.data.TFRecordDataset(extra_shard) + ds = tf.data.TFRecordDataset(extra_shard) # pyrefly: ignore[bad-instantiation] num_examples = 0 num_bytes = 0 for ex in ds: diff --git a/envlogger/backends/tfds_backend_testlib.py b/envlogger/backends/tfds_backend_testlib.py index 929e8b6..e06b362 100644 --- a/envlogger/backends/tfds_backend_testlib.py +++ b/envlogger/backends/tfds_backend_testlib.py @@ -72,7 +72,7 @@ def catch_env_tfds_config( name: str = 'catch_example') -> tfds.rlds.rlds_base.DatasetConfig: """Creates a TFDS DatasetConfig for the Catch environment.""" return tfds.rlds.rlds_base.DatasetConfig( - name=name, + name=name, # pyrefly: ignore[unexpected-keyword] observation_info=tfds.features.Tensor( shape=(10, 5), dtype=tf.float32, encoding=tfds.features.Encoding.ZLIB), diff --git a/envlogger/backends/tfds_backend_writer.py b/envlogger/backends/tfds_backend_writer.py index 08450bd..6f1dd4c 100644 --- a/envlogger/backends/tfds_backend_writer.py +++ b/envlogger/backends/tfds_backend_writer.py @@ -119,7 +119,7 @@ def _record_step(self, data: step_data.StepData, self._current_episode.add_step(data) def set_episode_metadata(self, data: dict[str, Any]) -> None: - self._current_episode.metadata = data + self._current_episode.metadata = data # pyrefly: ignore[missing-attribute] def close(self) -> None: logging.info('Deleting the backend with data_dir: %r', self._data_directory) diff --git a/envlogger/converters/codec.py b/envlogger/converters/codec.py index bcd9175..7752c97 100644 --- a/envlogger/converters/codec.py +++ b/envlogger/converters/codec.py @@ -164,9 +164,9 @@ def _set_datum_values_from_scalar( elif dtype == 'uint16': values.uint16_values = uint16struct.pack(scalar) elif dtype == 'float32': - values.float_values.append(scalar) + values.float_values.append(scalar) # pyrefly: ignore[bad-argument-type] elif dtype == 'float64': - values.double_values.append(scalar) + values.double_values.append(scalar) # pyrefly: ignore[bad-argument-type] elif dtype in ['int32', 'int64', 'uint32', 'uint64']: getattr(values, f'{dtype}_values').append(scalar) else: @@ -199,7 +199,7 @@ def _set_datum_values_from_array( ]: if np.issubdtype(array.dtype, dtype): for x in array.flatten(): - vs.append(cast_type(x)) + vs.append(cast_type(x)) # pyrefly: ignore[bad-argument-type] return for key, dtype, cast_type in [ @@ -376,7 +376,7 @@ def decode_datum( ]: if vs: if is_scalar: - return dtype(converter.unpack(vs)[0]) + return dtype(converter.unpack(vs)[0]) # pyrefly: ignore[bad-return] array = np.frombuffer(vs, dtype=dtype_code).astype(dtype) if values.string_values: diff --git a/envlogger/reader.py b/envlogger/reader.py index 1fbea29..c60fc32 100644 --- a/envlogger/reader.py +++ b/envlogger/reader.py @@ -58,7 +58,7 @@ def copy(self): c = copy.copy(self) c._backend = self._backend.copy() - c._observation = self._observation_spec + c._observation = self._observation_spec # pyrefly: ignore[missing-attribute] c._action_spec = self._action_spec c._reward_spec = self._reward_spec c._discount_spec = self._discount_spec diff --git a/envlogger/testing/catch_env.py b/envlogger/testing/catch_env.py index 1cff360..afa4289 100644 --- a/envlogger/testing/catch_env.py +++ b/envlogger/testing/catch_env.py @@ -73,7 +73,7 @@ def step(self, action): self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1) # Drop the ball. - self._ball_y += 1 + self._ball_y += 1 # pyrefly: ignore[unsupported-operation] # Check for termination. if self._ball_y == self._paddle_y: