diff --git a/src/lean_spec/subspecs/networking/gossipsub/rpc.py b/src/lean_spec/subspecs/networking/gossipsub/rpc.py index 95b8c4288..2dceb3dc8 100644 --- a/src/lean_spec/subspecs/networking/gossipsub/rpc.py +++ b/src/lean_spec/subspecs/networking/gossipsub/rpc.py @@ -63,16 +63,6 @@ class ProtobufDecodeError(ValueError): """Raised when protobuf data cannot be decoded.""" -def _decode_varint_at(data: bytes, pos: int) -> tuple[int, int]: - """ - Decode varint returning (value, new_position). - - Wrapper around canonical varint.decode for protobuf parsing convenience. - """ - value, consumed = decode_varint(data, pos) - return value, pos + consumed - - def _decode_length_at(data: bytes, pos: int) -> tuple[int, int]: """Decode a length varint and validate bounds. @@ -82,7 +72,8 @@ def _decode_length_at(data: bytes, pos: int) -> tuple[int, int]: Raises: ProtobufDecodeError: If the length exceeds available data. """ - length, new_pos = _decode_varint_at(data, pos) + length, consumed = decode_varint(data, pos) + new_pos = pos + consumed if new_pos + length > len(data): raise ProtobufDecodeError( f"Length field {length} at position {pos} exceeds data size {len(data)}" @@ -102,8 +93,8 @@ def decode_tag(data: bytes, pos: int) -> tuple[int, int, int]: Returns: (field_number, wire_type, new_position) tuple. """ - tag, pos = _decode_varint_at(data, pos) - return tag >> 3, tag & 0x07, pos + tag, consumed = decode_varint(data, pos) + return tag >> 3, tag & 0x07, pos + consumed def encode_length_delimited(field_number: int, data: bytes) -> bytes: @@ -159,7 +150,8 @@ def decode(cls, data: bytes) -> SubOpts: field_num, wire_type, pos = decode_tag(data, pos) if field_num == 1 and wire_type == WIRE_TYPE_VARINT: - value, pos = _decode_varint_at(data, pos) + value, consumed = decode_varint(data, pos) + pos += consumed subscribe = value != 0 elif field_num == 2 and wire_type == WIRE_TYPE_LENGTH_DELIMITED: length, pos = _decode_length_at(data, pos) @@ -440,7 +432,8 @@ def decode(cls, data: bytes) -> ControlPrune: peers.append(PrunePeerInfo.decode(data[pos : pos + length])) pos += length elif field_num == 3 and wire_type == WIRE_TYPE_VARINT: - backoff, pos = _decode_varint_at(data, pos) + backoff, consumed = decode_varint(data, pos) + pos += consumed else: pos = _skip_field(data, pos, wire_type) @@ -644,7 +637,8 @@ def _skip_field(data: bytes, pos: int, wire_type: int) -> int: group types 3/4) and cannot be skipped. """ if wire_type == WIRE_TYPE_VARINT: - _, pos = _decode_varint_at(data, pos) + _, consumed = decode_varint(data, pos) + pos += consumed elif wire_type == WIRE_TYPE_LENGTH_DELIMITED: length, pos = _decode_length_at(data, pos) pos += length