Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions src/lean_spec/subspecs/networking/gossipsub/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,6 @@ class ProtobufDecodeError(ValueError):
"""Raised when protobuf data cannot be decoded."""


def _decode_varint_at(data: bytes, pos: int) -> tuple[int, int]:
"""
Decode varint returning (value, new_position).

Wrapper around canonical varint.decode for protobuf parsing convenience.
"""
value, consumed = decode_varint(data, pos)
return value, pos + consumed


def _decode_length_at(data: bytes, pos: int) -> tuple[int, int]:
"""Decode a length varint and validate bounds.

Expand All @@ -82,7 +72,8 @@ def _decode_length_at(data: bytes, pos: int) -> tuple[int, int]:
Raises:
ProtobufDecodeError: If the length exceeds available data.
"""
length, new_pos = _decode_varint_at(data, pos)
length, consumed = decode_varint(data, pos)
new_pos = pos + consumed
if new_pos + length > len(data):
raise ProtobufDecodeError(
f"Length field {length} at position {pos} exceeds data size {len(data)}"
Expand All @@ -102,8 +93,8 @@ def decode_tag(data: bytes, pos: int) -> tuple[int, int, int]:
Returns:
(field_number, wire_type, new_position) tuple.
"""
tag, pos = _decode_varint_at(data, pos)
return tag >> 3, tag & 0x07, pos
tag, consumed = decode_varint(data, pos)
return tag >> 3, tag & 0x07, pos + consumed


def encode_length_delimited(field_number: int, data: bytes) -> bytes:
Expand Down Expand Up @@ -159,7 +150,8 @@ def decode(cls, data: bytes) -> SubOpts:
field_num, wire_type, pos = decode_tag(data, pos)

if field_num == 1 and wire_type == WIRE_TYPE_VARINT:
value, pos = _decode_varint_at(data, pos)
value, consumed = decode_varint(data, pos)
pos += consumed
subscribe = value != 0
elif field_num == 2 and wire_type == WIRE_TYPE_LENGTH_DELIMITED:
length, pos = _decode_length_at(data, pos)
Expand Down Expand Up @@ -440,7 +432,8 @@ def decode(cls, data: bytes) -> ControlPrune:
peers.append(PrunePeerInfo.decode(data[pos : pos + length]))
pos += length
elif field_num == 3 and wire_type == WIRE_TYPE_VARINT:
backoff, pos = _decode_varint_at(data, pos)
backoff, consumed = decode_varint(data, pos)
pos += consumed
else:
pos = _skip_field(data, pos, wire_type)

Expand Down Expand Up @@ -644,7 +637,8 @@ def _skip_field(data: bytes, pos: int, wire_type: int) -> int:
group types 3/4) and cannot be skipped.
"""
if wire_type == WIRE_TYPE_VARINT:
_, pos = _decode_varint_at(data, pos)
_, consumed = decode_varint(data, pos)
pos += consumed
elif wire_type == WIRE_TYPE_LENGTH_DELIMITED:
length, pos = _decode_length_at(data, pos)
pos += length
Expand Down
Loading