Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions montycat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ async def send_data(host: str, port: int, query: bytes, callback=None, stop_even
ConnectionRefusedError: If the server refuses the connection.
"""
CHUNK_SIZE = 1024 * 256
try:

if tls:
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
else:
context = None
if tls:
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
else:
context = None

writer = None
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(host, port, ssl=context),
timeout=10.0
Expand Down Expand Up @@ -56,8 +57,6 @@ async def send_data(host: str, port: int, query: bytes, callback=None, stop_even
callback(response)
data.clear()

writer.close()
await writer.wait_closed()
return None # subscription ended

else:
Expand All @@ -67,11 +66,23 @@ async def send_data(host: str, port: int, query: bytes, callback=None, stop_even
data.extend(chunk)
break
data.extend(chunk)
writer.close()
await writer.wait_closed()
return recursive_parse_orjson(data.decode().strip())
except Exception as e:
return f"Error: {e}"
finally:
# Always close the connection — including on asyncio.CancelledError
# (which doesn't inherit from Exception), so the server-side
# subscription handler sees EOF and tears down its watchers. Without
# this, the cancelled subscription leaves the TCP socket open until
# GC, and the server's sled subscribers stay alive — which then
# deadlocks any subsequent remove_keyspace/remove_store on the same
# store.
if writer is not None:
try:
writer.close()
await writer.wait_closed()
except Exception:
pass

def recursive_parse_orjson(data):
"""
Expand Down
26 changes: 10 additions & 16 deletions montycat/store_classes/inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ async def insert_custom_key(cls, custom_key: str, expire_sec: int = 0):
raise ValueError("No custom key provided for insertion.")

custom_key_converted = convert_custom_key(custom_key)
cls.command = "insert_custom_key"

query = convert_to_binary_query(cls, key=custom_key_converted, expire_sec=expire_sec)
query = convert_to_binary_query(cls, command="insert_custom_key", key=custom_key_converted, expire_sec=expire_sec)
return await cls._run_query(query)

@classmethod
Expand All @@ -103,9 +102,8 @@ async def insert_custom_key_value(cls, custom_key: str, value: dict, expire_sec:
raise ValueError("No custom key provided for insertion.")

custom_key_converted = convert_custom_key(custom_key)
cls.command = "insert_custom_key_value"

query = convert_to_binary_query(cls, key=custom_key_converted, value=value, expire_sec=expire_sec)
query = convert_to_binary_query(cls, command="insert_custom_key_value", key=custom_key_converted, value=value, expire_sec=expire_sec)
return await cls._run_query(query)

@classmethod
Expand All @@ -120,9 +118,7 @@ async def insert_value(cls, value: dict, expire_sec: int = 0):
if not value:
raise ValueError("No value provided for insertion.")

cls.command = "insert_value"

query = convert_to_binary_query(cls, value=value, expire_sec=expire_sec)
query = convert_to_binary_query(cls, command="insert_value", value=value, expire_sec=expire_sec)
return await cls._run_query(query)

@classmethod
Expand All @@ -144,6 +140,9 @@ async def update_value(cls, key: Union[str, None] = None, custom_key: Union[str,
or a string message if the update was unsuccessful.
"""

if key and custom_key:
raise ValueError("Provide either key or custom_key, not both.")

if custom_key and len(custom_key) > 0:
key = convert_custom_key(custom_key)

Expand All @@ -152,9 +151,7 @@ async def update_value(cls, key: Union[str, None] = None, custom_key: Union[str,
if not key:
raise ValueError("No key provided")

cls.command = "update_value"

query = convert_to_binary_query(cls, key=key, value=filters, expire_sec=expire_sec)
query = convert_to_binary_query(cls, command="update_value", key=key, value=filters, expire_sec=expire_sec)
return await cls._run_query(query)

@classmethod
Expand All @@ -172,8 +169,7 @@ async def insert_bulk(cls, bulk_values: list, expire_sec: int = 0):
if not bulk_values:
raise ValueError("No values provided for bulk insertion.")

cls.command = "insert_bulk"
query = convert_to_binary_query(cls, bulk_values=bulk_values, expire_sec=expire_sec)
query = convert_to_binary_query(cls, command="insert_bulk", bulk_values=bulk_values, expire_sec=expire_sec)
return await cls._run_query(query)

@classmethod
Expand All @@ -186,11 +182,9 @@ async def get_keys(cls, volumes: list = [], latest_volume: bool = False):
"""

if not latest_volume and not volumes:
raise ValueError("Please provide keys or volumes/latest volume.")

cls.command = "get_keys"
raise ValueError("Please provide volumes/latest volume.")

query = convert_to_binary_query(cls, volumes=volumes, latest_volume=latest_volume)
query = convert_to_binary_query(cls, command="get_keys", volumes=volumes, latest_volume=latest_volume)
return await cls._run_query(query)

@classmethod
Expand Down
108 changes: 60 additions & 48 deletions montycat/store_classes/kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,46 @@

class generic_kv:
store: str = ""
command: str = ""
limit_output: dict = {}
schema = None

@classmethod
async def _run_query(cls, query: str, callback=None, stop_event: Union[asyncio.Event, None] = None):
port = cls.port + 1 if callback else cls.port
async def subscribe(cls, key: Union[str, None] = None, custom_key: Union[str, None] = None, callback=None, subscription_port: Union[int, None] = None):
"""
Subscribe to real-time changes for a key or the entire keyspace.
Works for both in-memory and persistent keyspaces.

Args:
key: The key to subscribe to. If None, subscribes to the entire keyspace.
custom_key: A custom key to subscribe to. Converted to internal format.
callback: A function called with response data as it is received.
Returns:
A tuple of (task, stop_event). Call stop_event.set() to stop the subscription.
"""
if not callback:
raise ValueError("Callback function is not provided")

if key and custom_key:
raise ValueError("Provide either key or custom_key, not both.")

stop_subscription = asyncio.Event()

query_dict = {
"subscribe": True,
"store": cls.store,
"keyspace": cls.keyspace,
"key": convert_custom_key(custom_key) if custom_key else key,
"persistent": cls.persistent,
"username": cls.username,
"password": cls.password
}

query = orjson.dumps(query_dict)
task = asyncio.create_task(cls._run_query(query, callback=callback, stop_event=stop_subscription, subscription_port=subscription_port))

return task, stop_subscription

@classmethod
async def _run_query(cls, query: str, callback=None, stop_event: Union[asyncio.Event, None] = None, subscription_port: Union[int, None] = None):
port = subscription_port if subscription_port else (cls.port + 1 if callback else cls.port)
return await send_data(cls.host, port, query, callback=callback, stop_event=stop_event, tls=cls.tls)

@classmethod
Expand Down Expand Up @@ -131,18 +164,16 @@ async def get_value(cls, key: Union[str, None] = None, custom_key: Union[str, No
Returns:
The value associated with the key or custom key. Class 'str' if the get operation failed.
"""
if pointers_metadata and with_pointers:
raise ValueError("You select both pointers value and pointers metadata. Choose one")
if key and custom_key:
raise ValueError("Provide either key or custom_key, not both.")

if custom_key and len(custom_key) > 0:
key = convert_custom_key(custom_key)

if not key:
raise ValueError("No key provided")

cls.command = "get_value"

query = convert_to_binary_query(cls, key=key, with_pointers=with_pointers, key_included=key_included, pointers_metadata=pointers_metadata)
query = convert_to_binary_query(cls, command="get_value", key=key, with_pointers=with_pointers, key_included=key_included, pointers_metadata=pointers_metadata)
return await cls._run_query(query)

@classmethod
Expand All @@ -161,15 +192,16 @@ async def delete_key(cls, key: Union[str, None] = None, custom_key: Union[str, N
bool | str: Returns a boolean indicating success (True) or failure (False),
or a string message if the deletion was unsuccessful.
"""
if key and custom_key:
raise ValueError("Provide either key or custom_key, not both.")

if custom_key and len(custom_key) > 0:
key = convert_custom_key(custom_key)

if not key:
raise ValueError("No key provided")

cls.command = "delete_key"

query = convert_to_binary_query(cls, key=key)
query = convert_to_binary_query(cls, command="delete_key", key=key)
return await cls._run_query(query)

@classmethod
Expand All @@ -191,22 +223,20 @@ async def delete_bulk(cls, bulk_keys: list = [], bulk_custom_keys: list = []):

Raises:
ValueError: If both `bulk_keys` and `bulk_custom_keys` are empty.
ValueError: If both `pointers_metadata` and `with_pointers` are True.
"""
if len(bulk_custom_keys) > 0:
bulk_custom_keys = convert_custom_keys(bulk_custom_keys)
bulk_keys += bulk_custom_keys
bulk_keys = bulk_keys + bulk_custom_keys

if not bulk_keys:
raise ValueError("No keys provided for deletion.")

cls.command = "delete_bulk"
query = convert_to_binary_query(cls, bulk_keys=bulk_keys)
query = convert_to_binary_query(cls, command="delete_bulk", bulk_keys=bulk_keys)
return await cls._run_query(query)

@classmethod
async def get_bulk(
cls, bulk_keys: list = [], bulk_custom_keys: list = [], limit: list = [], with_pointers: bool = False, key_included: bool = False, pointers_metadata: bool = False, volumes: list[str] = [], latest_volume: bool = False):
cls, bulk_keys: list = [], bulk_custom_keys: list = [], limit: list[int] = [], with_pointers: bool = False, key_included: bool = False, pointers_metadata: bool = False, volumes: list[str] = [], latest_volume: bool = False):
"""
Retrieve multiple keys in bulk. Custom keys can be converted and added to the bulk retrieval list.
Additionally, a limit on the number of records to retrieve can be applied, and whether to include pointers
Expand Down Expand Up @@ -238,19 +268,17 @@ async def get_bulk(

if len(bulk_custom_keys) > 0:
bulk_custom_keys = convert_custom_keys(bulk_custom_keys)
bulk_keys += bulk_custom_keys
bulk_keys = bulk_keys + bulk_custom_keys

selected_options = sum([
bool(bulk_keys),
bool(volumes and len(volumes) > 0) or latest_volume or bool(limit and len(limit) > 0 and (limit[0] != 0 or limit[1] != 0)),
bool(volumes) or latest_volume or bool(limit and limit != [0, 0]),
])

if selected_options != 1:
raise ValueError("Please provide keys or volumes/latest volume or limit.")

cls.command = "get_bulk"
cls.limit_output = handle_limit(limit)
query = convert_to_binary_query(cls, bulk_keys=bulk_keys, with_pointers=with_pointers, key_included=key_included, pointers_metadata=pointers_metadata, volumes=volumes, latest_volume=latest_volume)
query = convert_to_binary_query(cls, command="get_bulk", limit_output=handle_limit(limit), bulk_keys=bulk_keys, with_pointers=with_pointers, key_included=key_included, pointers_metadata=pointers_metadata, volumes=volumes, latest_volume=latest_volume)
return await cls._run_query(query)

@classmethod
Expand Down Expand Up @@ -281,8 +309,7 @@ async def update_bulk(cls, bulk_keys_values: dict = {}, bulk_custom_keys_values:
bulk_custom_keys_values = convert_custom_keys_values(bulk_custom_keys_values)
bulk_keys_values = {**bulk_keys_values, **bulk_custom_keys_values}

cls.command = "update_bulk"
query = convert_to_binary_query(cls, bulk_keys_values=bulk_keys_values)
query = convert_to_binary_query(cls, command="update_bulk", bulk_keys_values=bulk_keys_values)
return await cls._run_query(query)

@classmethod
Expand All @@ -301,14 +328,7 @@ async def lookup_keys_where(cls, limit: Union[int, list] = 0, schema: Union[str,
ValueError: If no filters are provided.
"""

if schema:
cls.schema = str(schema)
else:
cls.schema = None

cls.command = "lookup_keys"
cls.limit_output = handle_limit(limit)
query = convert_to_binary_query(cls, search_criteria=filters)
query = convert_to_binary_query(cls, command="lookup_keys", limit_output=handle_limit(limit), search_criteria=filters, schema=str(schema) if schema else None)
return await cls._run_query(query)

@classmethod
Expand All @@ -330,14 +350,7 @@ async def lookup_values_where(cls, limit: Union[int, list] = 0, with_pointers: b
ValueError: If no filters are provided.
"""

if schema:
cls.schema = str(schema)
else:
cls.schema = None

cls.command = "lookup_values"
cls.limit_output = handle_limit(limit)
query = convert_to_binary_query(cls, search_criteria=filters, with_pointers=with_pointers, key_included=key_included, pointers_metadata=pointers_metadata)
query = convert_to_binary_query(cls, command="lookup_values", limit_output=handle_limit(limit), search_criteria=filters, with_pointers=with_pointers, key_included=key_included, pointers_metadata=pointers_metadata, schema=str(schema) if schema else None)
return await cls._run_query(query)

@classmethod
Expand All @@ -363,21 +376,21 @@ async def list_all_depending_keys(cls, key: Union[str, None] = None, custom_key:
ValueError: If both `key` and `custom_key` are empty, as one of them
is required to form a valid query.
"""
if key and custom_key:
raise ValueError("Provide either key or custom_key, not both.")

if custom_key and len(custom_key) > 0:
key = convert_custom_key(custom_key)

if not key:
raise ValueError("No key provided")

cls.command = "list_all_depending_keys"

query = convert_to_binary_query(cls, key=key)
query = convert_to_binary_query(cls, command="list_all_depending_keys", key=key)
return await cls._run_query(query)

@classmethod
async def get_len(cls):
cls.command = "get_len"
query = convert_to_binary_query(cls)
query = convert_to_binary_query(cls, command="get_len")
return await cls._run_query(query)

@classmethod
Expand Down Expand Up @@ -416,8 +429,7 @@ async def remove_keyspace(cls):

@classmethod
async def list_all_schemas_in_keyspace(cls):
cls.command = "list_all_schemas_in_keyspace"
query = convert_to_binary_query(cls)
query = convert_to_binary_query(cls, command="list_all_schemas_in_keyspace")
return await cls._run_query(query)

@classmethod
Expand Down
Loading
Loading