Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions envlogger/backends/backend_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, count: int, get_nth_item: Callable[[int], T]):
self._index = 0
self._get_nth_item = get_nth_item

def __getitem__(self, index: Union[int, slice]) -> Union[T, list[T]]:
def __getitem__(self, index: Union[int, slice]) -> Union[T, list[T]]: # pyrefly: ignore[bad-override]
"""Retrieves items from this sequence.

Args:
Expand Down Expand Up @@ -71,15 +71,15 @@ def __len__(self) -> int:

def __iter__(self) -> Iterator[T]:
while self._index < len(self):
yield self[self._index]
yield self[self._index] # pyrefly: ignore[invalid-yield]
self._index += 1
self._index = 0

def __next__(self) -> T:
if self._index < len(self):
index = self._index
self._index += 1
return self[index]
return self[index] # pyrefly: ignore[bad-return]
else:
raise StopIteration()

Expand Down
2 changes: 1 addition & 1 deletion envlogger/backends/riegeli_backend_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _get_nth_step(self, i: int) -> step_data.StepData:
"""Returns the timestep given by offset `i` (0-based)."""
serialized_data = self._reader.serialized_step(i)
data = storage_pb2.Data.FromString(serialized_data)
return self._decode_step_data(codec.decode(data))
return self._decode_step_data(codec.decode(data)) # pyrefly: ignore[bad-argument-type]

def _get_nth_episode_info(self,
i: int,
Expand Down
2 changes: 1 addition & 1 deletion envlogger/backends/rlds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def maybe_recover_last_shard(builder: tfds.core.DatasetBuilder):
continue
logging.info('Recovering data for shard %s.', extra_shard)
splits_to_update += 1
ds = tf.data.TFRecordDataset(extra_shard)
ds = tf.data.TFRecordDataset(extra_shard) # pyrefly: ignore[bad-instantiation]
num_examples = 0
num_bytes = 0
for ex in ds:
Expand Down
2 changes: 1 addition & 1 deletion envlogger/backends/tfds_backend_testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def catch_env_tfds_config(
name: str = 'catch_example') -> tfds.rlds.rlds_base.DatasetConfig:
"""Creates a TFDS DatasetConfig for the Catch environment."""
return tfds.rlds.rlds_base.DatasetConfig(
name=name,
name=name, # pyrefly: ignore[unexpected-keyword]
observation_info=tfds.features.Tensor(
shape=(10, 5), dtype=tf.float32,
encoding=tfds.features.Encoding.ZLIB),
Expand Down
2 changes: 1 addition & 1 deletion envlogger/backends/tfds_backend_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _record_step(self, data: step_data.StepData,
self._current_episode.add_step(data)

def set_episode_metadata(self, data: dict[str, Any]) -> None:
self._current_episode.metadata = data
self._current_episode.metadata = data # pyrefly: ignore[missing-attribute]

def close(self) -> None:
logging.info('Deleting the backend with data_dir: %r', self._data_directory)
Expand Down
8 changes: 4 additions & 4 deletions envlogger/converters/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ def _set_datum_values_from_scalar(
elif dtype == 'uint16':
values.uint16_values = uint16struct.pack(scalar)
elif dtype == 'float32':
values.float_values.append(scalar)
values.float_values.append(scalar) # pyrefly: ignore[bad-argument-type]
elif dtype == 'float64':
values.double_values.append(scalar)
values.double_values.append(scalar) # pyrefly: ignore[bad-argument-type]
elif dtype in ['int32', 'int64', 'uint32', 'uint64']:
getattr(values, f'{dtype}_values').append(scalar)
else:
Expand Down Expand Up @@ -199,7 +199,7 @@ def _set_datum_values_from_array(
]:
if np.issubdtype(array.dtype, dtype):
for x in array.flatten():
vs.append(cast_type(x))
vs.append(cast_type(x)) # pyrefly: ignore[bad-argument-type]
return

for key, dtype, cast_type in [
Expand Down Expand Up @@ -376,7 +376,7 @@ def decode_datum(
]:
if vs:
if is_scalar:
return dtype(converter.unpack(vs)[0])
return dtype(converter.unpack(vs)[0]) # pyrefly: ignore[bad-return]
array = np.frombuffer(vs, dtype=dtype_code).astype(dtype)

if values.string_values:
Expand Down
2 changes: 1 addition & 1 deletion envlogger/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def copy(self):
c = copy.copy(self)

c._backend = self._backend.copy()
c._observation = self._observation_spec
c._observation = self._observation_spec # pyrefly: ignore[missing-attribute]
c._action_spec = self._action_spec
c._reward_spec = self._reward_spec
c._discount_spec = self._discount_spec
Expand Down
2 changes: 1 addition & 1 deletion envlogger/testing/catch_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def step(self, action):
self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1)

# Drop the ball.
self._ball_y += 1
self._ball_y += 1 # pyrefly: ignore[unsupported-operation]

# Check for termination.
if self._ball_y == self._paddle_y:
Expand Down
Loading