diff --git a/.gitignore b/.gitignore index 9b69e2eb4c..ce9052d074 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ dist .vscode tmp/ requirements-musa.txt -logs/ \ No newline at end of file +logs/ +target/ diff --git a/docker/Dockerfile b/docker/Dockerfile index 313c4c72a5..46786b581d 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -141,4 +141,13 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ fi COPY . /lightllm -RUN pip install -e /lightllm --no-cache-dir + +ARG ENABLE_RUST=1 +ENV PATH="/root/.cargo/bin:${PATH}" +RUN if [ "${ENABLE_RUST}" = "1" ]; then \ + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y && \ + pip install --no-cache-dir --upgrade pip setuptools-rust wheel; \ + fi + +RUN pip install -v -e /lightllm --no-cache-dir --no-build-isolation + diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 04e0187452..e898e439c9 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -57,7 +57,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--select_p_d_node_strategy", type=str, default="round_robin", - choices=["random", "round_robin", "adaptive_load"], + choices=["random", "round_robin", "adaptive_load", "cache_aware"], help="pd master use this strategy to select p d node, can be round_robin, random or adaptive_load", ) parser.add_argument( diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py index a48024a39d..d2c9506396 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py @@ -1,4 +1,10 @@ -from .pd_selector import PDSelector, RandomSelector, RoundRobinSelector, AdaptiveLoadSelector +from .pd_selector import ( + PDSelector, + RandomSelector, + RoundRobinSelector, + AdaptiveLoadSelector, + LoadBalancedCacheAwareSelector, +) def create_selector(selector_type: 
str, pd_manager) -> PDSelector: @@ -8,5 +14,7 @@ def create_selector(selector_type: str, pd_manager) -> PDSelector: return RoundRobinSelector(pd_manager) elif selector_type == "adaptive_load": return AdaptiveLoadSelector(pd_manager) + elif selector_type == "cache_aware": + return LoadBalancedCacheAwareSelector(pd_manager) else: raise ValueError(f"Invalid selector type: {selector_type}") diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/cache_aware.py b/lightllm/server/httpserver_for_pd_master/pd_selector/cache_aware.py new file mode 100644 index 0000000000..aea57369a7 --- /dev/null +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/cache_aware.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import threading +from dataclasses import dataclass +from typing import List, Optional, runtime_checkable + +from lightllm.server.pd_io_struct import PD_Client_Obj +from lightllm.utils.log_utils import init_logger + +from .tree import Tree + + +logger = init_logger(__name__) + + +@dataclass(slots=True) +class CacheAwareConfig: + cache_threshold: float = 0.5 + balance_rel_threshold: float = 1.2 + eviction_interval_secs: int = 30 + max_tree_size: int = 1000000 + + +class CacheAwarePolicy: + def __init__(self, config: Optional[CacheAwareConfig] = None) -> None: + self.config = config or CacheAwareConfig() + self.tree: Tree = Tree() + self._stop_eviction = threading.Event() + self._eviction_thread: Optional[threading.Thread] = None + if self.config.eviction_interval_secs > 0: + self._eviction_thread = threading.Thread( + target=self._run_eviction_loop, name="cache-aware-eviction", daemon=True + ) + self._eviction_thread.start() + + def _run_eviction_loop(self) -> None: + while not self._stop_eviction.wait(self.config.eviction_interval_secs): + logger.info("Running cache eviction...") + self.evict_cache(self.config.max_tree_size) + logger.info(f"Cache eviction completed.: {self.tree.get_used_size_per_tenant()}") + + def close(self) -> None: + 
self._stop_eviction.set() + if self._eviction_thread is not None and self._eviction_thread.is_alive(): + self._eviction_thread.join(timeout=1.0) + + def init_workers(self, workers: List[PD_Client_Obj]) -> None: + for worker in workers: + self.tree.insert("", worker.url()) + + def add_worker(self, worker: PD_Client_Obj) -> None: + self.tree.insert("", worker.url()) + + def remove_worker(self, worker: PD_Client_Obj) -> None: + self.tree.remove_tenant(worker.url()) + + def remove_worker_by_url(self, url: str) -> None: + self.tree.remove_tenant(url) + + def evict_cache(self, max_size: int) -> None: + self.tree.evict_tenant_by_size(max_size) + + def _select_worker_min_load( + self, + workers: List[PD_Client_Obj], + request_text: Optional[str], + ) -> Optional[PD_Client_Obj]: + + min_load_worker = min(workers, key=lambda worker: worker.load()) + + if request_text is not None: + self.tree.insert(request_text, min_load_worker.url()) + + return min_load_worker + + def select_worker( + self, workers: List[PD_Client_Obj], request_text: Optional[str] = None + ) -> Optional[PD_Client_Obj]: + + if not workers: + return None + + loads = [worker.load() for worker in workers] + min_load = min(loads) if loads else 0 + max_load = max(loads) if loads else 0 + + is_imbalanced = max_load > (min_load * self.config.balance_rel_threshold) + + logger.info( + f"CacheAwarePolicy: min_load={min_load:.4f}, max_load={max_load:.4f}, " + f"balance_rel_threshold={self.config.balance_rel_threshold:.4f}, " + f"is_imbalanced={is_imbalanced}" + ) + + if is_imbalanced: + return self._select_worker_min_load( + workers=workers, + request_text=request_text, + ) + + text = request_text or "" + + result = self.tree.prefix_match_with_counts(text) + match_rate = 0.0 if result.input_char_count == 0 else result.matched_char_count / result.input_char_count + + logger.info( + f"CacheAwarePolicy: matched_char_count={result.matched_char_count}, " + f"input_char_count={result.input_char_count}, 
match_rate={match_rate:.4f}, " + f"cache_threshold={self.config.cache_threshold:.4f}" + ) + + selected_worker: Optional[PD_Client_Obj] = None + if match_rate > self.config.cache_threshold: + for worker in workers: + if worker.url() == result.tenant: + selected_worker = worker + break + + if selected_worker is None: + # If the matched tenant is not in the current workers, we can evict it from the tree + logger.info(f"Evicting tenant: {result.tenant}") + self.tree.remove_tenant(result.tenant) + + logger.info( + f"CacheAwarePolicy: selected_worker={selected_worker.url() if selected_worker else None}, " + f"match_rate={match_rate:.4f}, cache_threshold={self.config.cache_threshold:.4f}" + ) + + if selected_worker is not None: + self.tree.insert(text, selected_worker.url()) + return selected_worker + else: + return self._select_worker_min_load( + workers=workers, + request_text=request_text, + ) + + def __del__(self) -> None: + self.close() diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index 2a728fb8c7..2fa98dd77d 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -3,6 +3,12 @@ from lightllm.server.pd_io_struct import PD_Client_Obj from lightllm.server.core.objs import SamplingParams from lightllm.server.multimodal_params import MultimodalParams +from lightllm.utils.log_utils import init_logger + +from .cache_aware import CacheAwarePolicy, CacheAwareConfig + + +logger = init_logger(__name__) class PDSelector: @@ -64,4 +70,43 @@ def select_p_d_node( return p_node, d_node def _importance_sampling(self, nodes: List[PD_Client_Obj]): - return random.choices(nodes, weights=[max(1.0 - e.run_status.total_token_usage_rate, 0.02) for e in nodes]) + return random.choices(nodes, weights=[max(1.0 - e.run_status.total_token_usage_rate, 0.02) for e in nodes])[0] + + +class 
LoadBalancedCacheAwareSelector(AdaptiveLoadSelector): + """refer to: https://github.com/sgl-project/sglang/blob/main/sgl-model-gateway/src/policies/cache_aware.rs""" + + def __init__(self, pd_manager): + super().__init__(pd_manager) + self.policy = CacheAwarePolicy(CacheAwareConfig()) + self.tree_workers = [] + + def update_nodes(self, prefill_nodes, decode_nodes): + super().update_nodes(prefill_nodes, decode_nodes) + + add_workers = set(self.prefill_nodes) - set(self.tree_workers) + remove_workers = set(self.tree_workers) - set(self.prefill_nodes) + + for worker in add_workers: + self.tree_workers.append(worker) + + for worker in remove_workers: + self.tree_workers.remove(worker) + + self.tree_workers = self.prefill_nodes + + def select_p_d_node( + self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams + ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + assert isinstance(prompt, str), "prompt must be a string for cache-aware selection" + p_node = self.policy.select_worker(self.prefill_nodes, request_text=prompt) + d_node = self._importance_sampling(self.decode_nodes) + + p_node.update_load(len(prompt)) + + logger.info( + f"LoadBalancedCacheAwareSelector: selected p_node={p_node.url() if p_node else None}, " + f"d_node={d_node.url() if d_node else None}" + ) + + return p_node, d_node diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/tree.py b/lightllm/server/httpserver_for_pd_master/pd_selector/tree.py new file mode 100644 index 0000000000..5b9afd95df --- /dev/null +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/tree.py @@ -0,0 +1,318 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from heapq import heappop, heappush +from itertools import count +from threading import RLock +from typing import Dict, List, Optional, Tuple + +try: + from ._pd_tree_rust import PrefixMatchResult as RustPrefixMatchResult + from ._pd_tree_rust import Tree as RustTree +except 
Exception: + RustPrefixMatchResult = None + RustTree = None + + +_EPOCH_COUNTER = count() + + +def _get_epoch() -> int: + return next(_EPOCH_COUNTER) + + +def _shared_prefix_count(a: str, b: str) -> int: + matched = 0 + for ca, cb in zip(a, b): + if ca != cb: + break + matched += 1 + return matched + + +@dataclass(slots=True) +class PrefixMatchResult: + tenant: str + matched_char_count: int + input_char_count: int + + +PythonPrefixMatchResult = PrefixMatchResult + + +@dataclass(slots=True) +class _Node: + text: str + children: Dict[str, "_Node"] = field(default_factory=dict) + tenant_last_access_time: Dict[str, int] = field(default_factory=dict) + parent: Optional["_Node"] = None + last_tenant: Optional[str] = None + + +class Tree: + """ + Python translation of the Rust cache-aware radix tree. + + Notes: + - Uses a coarse-grained lock for correctness and simpler behavior parity. + - Keeps per-tenant char counts for eviction decisions. + """ + + def __init__(self) -> None: + self.root = _Node(text="") + self.tenant_char_count: Dict[str, int] = {} + self._lock = RLock() + + def insert(self, text: str, tenant: str) -> None: + with self._lock: + self.root.tenant_last_access_time.setdefault(tenant, 0) + self.tenant_char_count.setdefault(tenant, 0) + + remaining = text + prev = self.root + + while remaining: + first_char = remaining[0] + child = prev.children.get(first_char) + + if child is None: + remaining_char_count = len(remaining) + epoch = _get_epoch() + new_node = _Node( + text=remaining, + tenant_last_access_time={tenant: epoch}, + parent=prev, + last_tenant=tenant, + ) + self.tenant_char_count[tenant] = self.tenant_char_count.get(tenant, 0) + remaining_char_count + prev.children[first_char] = new_node + return + + shared_count = _shared_prefix_count(remaining, child.text) + child_len = len(child.text) + + if shared_count < child_len: + matched_text = child.text[:shared_count] + contracted_text = child.text[shared_count:] + matched_text_count = shared_count + + 
new_node = _Node( + text=matched_text, + tenant_last_access_time=dict(child.tenant_last_access_time), + parent=prev, + last_tenant=child.last_tenant, + ) + new_node.children[contracted_text[0]] = child + + child.text = contracted_text + child.parent = new_node + prev.children[first_char] = new_node + + if tenant not in new_node.tenant_last_access_time: + self.tenant_char_count[tenant] = self.tenant_char_count.get(tenant, 0) + matched_text_count + new_node.tenant_last_access_time[tenant] = 0 + + prev = new_node + remaining = remaining[shared_count:] + else: + if tenant not in child.tenant_last_access_time: + self.tenant_char_count[tenant] = self.tenant_char_count.get(tenant, 0) + child_len + child.tenant_last_access_time[tenant] = 0 + prev = child + remaining = remaining[shared_count:] + + epoch = _get_epoch() + prev.tenant_last_access_time[tenant] = epoch + prev.last_tenant = tenant + + def prefix_match_with_counts(self, text: str) -> PrefixMatchResult: + with self._lock: + remaining = text + matched_chars = 0 + prev = self.root + + while remaining: + first_char = remaining[0] + child = prev.children.get(first_char) + if child is None: + break + + shared_count = _shared_prefix_count(remaining, child.text) + child_len = len(child.text) + + if shared_count == child_len: + matched_chars += shared_count + remaining = remaining[shared_count:] + prev = child + else: + matched_chars += shared_count + prev = child + break + + curr = prev + + if curr.last_tenant and curr.last_tenant in curr.tenant_last_access_time: + tenant = curr.last_tenant + else: + tenant = next(iter(curr.tenant_last_access_time), "empty") + curr.last_tenant = tenant + + if tenant != "empty": + curr.tenant_last_access_time[tenant] = _get_epoch() + + return PythonPrefixMatchResult( + tenant=tenant, + matched_char_count=matched_chars, + input_char_count=len(text), + ) + + def prefix_match(self, text: str) -> Tuple[str, str]: + result = self.prefix_match_with_counts(text) + return text[: 
result.matched_char_count], result.tenant + + def prefix_match_tenant(self, text: str, tenant: str) -> str: + with self._lock: + remaining = text + matched_chars = 0 + prev = self.root + + while remaining: + first_char = remaining[0] + child = prev.children.get(first_char) + if child is None: + break + if tenant not in child.tenant_last_access_time: + break + + shared_count = _shared_prefix_count(remaining, child.text) + child_len = len(child.text) + + if shared_count == child_len: + matched_chars += shared_count + remaining = remaining[shared_count:] + prev = child + else: + matched_chars += shared_count + prev = child + break + + if tenant in prev.tenant_last_access_time: + prev.tenant_last_access_time[tenant] = _get_epoch() + prev.last_tenant = tenant + + return text[:matched_chars] + + @staticmethod + def _leaf_of(node: _Node) -> List[str]: + candidates: Dict[str, bool] = {tenant: True for tenant in node.tenant_last_access_time} + for child in node.children.values(): + for tenant in child.tenant_last_access_time: + candidates[tenant] = False + return [tenant for tenant, is_leaf in candidates.items() if is_leaf] + + def evict_tenant_by_size(self, max_size: int) -> None: + with self._lock: + stack = [self.root] + pq: List[Tuple[int, str, _Node]] = [] + + while stack: + curr = stack.pop() + stack.extend(curr.children.values()) + for tenant in self._leaf_of(curr): + ts = curr.tenant_last_access_time.get(tenant) + if ts is not None: + heappush(pq, (ts, tenant, curr)) + + while pq: + _, tenant, node = heappop(pq) + used_size = self.tenant_char_count.get(tenant, 0) + if used_size <= max_size: + continue + + if tenant not in node.tenant_last_access_time: + continue + if any(tenant in child.tenant_last_access_time for child in node.children.values()): + continue + + node_len = len(node.text) + self.tenant_char_count[tenant] = max(0, self.tenant_char_count.get(tenant, 0) - node_len) + + node.tenant_last_access_time.pop(tenant, None) + if node.last_tenant == tenant: + 
node.last_tenant = next(iter(node.tenant_last_access_time), None) + + parent = node.parent + if not node.children and not node.tenant_last_access_time and parent is not None: + if node.text: + parent.children.pop(node.text[0], None) + + if parent is not None and tenant in parent.tenant_last_access_time: + has_child_with_tenant = any( + tenant in child.tenant_last_access_time for child in parent.children.values() + ) + if not has_child_with_tenant: + ts = parent.tenant_last_access_time.get(tenant) + if ts is not None: + heappush(pq, (ts, tenant, parent)) + + if self.tenant_char_count.get(tenant, 0) == 0: + self.tenant_char_count.pop(tenant, None) + + def remove_tenant(self, tenant: str) -> None: + with self._lock: + stack = [self.root] + queue: List[_Node] = [] + + while stack: + curr = stack.pop() + stack.extend(curr.children.values()) + + if tenant in curr.tenant_last_access_time: + has_child_with_tenant = any( + tenant in child.tenant_last_access_time for child in curr.children.values() + ) + if not has_child_with_tenant: + queue.append(curr) + + while queue: + curr = queue.pop(0) + curr.tenant_last_access_time.pop(tenant, None) + if curr.last_tenant == tenant: + curr.last_tenant = next(iter(curr.tenant_last_access_time), None) + + parent = curr.parent + if not curr.children and not curr.tenant_last_access_time and parent is not None: + if curr.text: + parent.children.pop(curr.text[0], None) + + if parent is not None and tenant in parent.tenant_last_access_time: + has_child_with_tenant = any( + tenant in child.tenant_last_access_time for child in parent.children.values() + ) + if not has_child_with_tenant: + queue.append(parent) + + self.tenant_char_count.pop(tenant, None) + + def get_tenant_char_count(self) -> Dict[str, int]: + with self._lock: + return dict(self.tenant_char_count) + + def get_used_size_per_tenant(self) -> Dict[str, int]: + with self._lock: + used_size_per_tenant: Dict[str, int] = {} + stack = [self.root] + while stack: + curr = stack.pop() + 
text_count = len(curr.text) + for tenant in curr.tenant_last_access_time: + used_size_per_tenant[tenant] = used_size_per_tenant.get(tenant, 0) + text_count + stack.extend(curr.children.values()) + return used_size_per_tenant + + +PythonTree = Tree + +if RustTree is not None and RustPrefixMatchResult is not None: + PrefixMatchResult = RustPrefixMatchResult + Tree = RustTree diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 1d68f81a9e..4cdb3d0a63 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -54,6 +54,7 @@ class PD_Client_Obj: start_args: object # 节点的启动参数信息,用于做匹配性的校验,防止运行过程中出现问题。 websocket: WebSocket = None # 用于通信的 websocket 连接对象 run_status: _PD_Client_RunStatus = field(default_factory=_PD_Client_RunStatus) + processed_prompt_len: int = 0 def __post_init__(self): if self.mode not in ["prefill", "decode"]: @@ -65,6 +66,15 @@ def __post_init__(self): def to_llm_url(self): return f"http://{self.client_ip_port}/pd_generate_stream" + def load(self): + return self.processed_prompt_len + + def update_load(self, prompt_len: int): + self.processed_prompt_len += prompt_len + + def url(self): + return self.client_ip_port + @dataclass class PD_Master_Obj: diff --git a/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py index bc1d00f384..b8fbd3e276 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py @@ -94,8 +94,10 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): logger.info(f"up kv status: {upkv_status}") else: await asyncio.sleep(3) + except BaseException as e: logger.error(str(e)) + await task_queue.put(upkv_status) raise e except asyncio.CancelledError: logger.info(f"up_kv_status_task {pd_master_obj} cancelled") diff --git 
a/rust/pd_tree/Cargo.lock b/rust/pd_tree/Cargo.lock new file mode 100644 index 0000000000..17f719e1a3 --- /dev/null +++ b/rust/pd_tree/Cargo.lock @@ -0,0 +1,403 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "autocfg" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" + +[[package]] +name = "bitflags" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "dashmap" +version = "6.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6361d5c062261c78a176addb82d4c821ae42bed6089de0e12603cd25de2059c" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706" +dependencies = [ + "rustversion", +] + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "lightllm-pd-tree" +version = "0.1.0" +dependencies = [ + "dashmap", + "parking_lot", + "pyo3", + "rand", + "tracing", +] + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.23.5" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbc457d0c7a0759a614551b11a6409e5951f6c7537be1f1b7682b9ae9230368" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ed6a63f02c8539c91a8685a86f4099661ba3da017932f6ebbea6de3f0fa7c90" + +[[package]] +name = "syn" +version = "2.0.118" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9ae57f904213ebb649ce6895b8a66c66f0203b9319718f69a5612a065b1422" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "wasip2" +version = "1.0.4+wasi-0.2.12" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67efb37e106e55ce722a510d6b5f9c17f083e5fc79afc2badeb12cc313d9487" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "zerocopy" +version = "0.8.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/rust/pd_tree/Cargo.toml b/rust/pd_tree/Cargo.toml new file mode 100644 index 0000000000..c0a70ac027 --- /dev/null +++ b/rust/pd_tree/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "lightllm-pd-tree" +version = "0.1.0" +edition = "2021" + +[lib] +name = "_pd_tree_rust" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.23", features = ["extension-module"] } +dashmap = "6" +parking_lot = "0.12.5" +tracing = "0.1.44" +rand = "0.9.2" diff --git a/rust/pd_tree/src/lib.rs b/rust/pd_tree/src/lib.rs new file mode 100644 index 0000000000..4bd1c8cf0b --- /dev/null +++ b/rust/pd_tree/src/lib.rs @@ -0,0 +1,2334 @@ +// modified from https://github.com/sgl-project/sglang/blob/main/sgl-model-gateway/src/policies/tree.rs +use std::{ + cmp::Reverse, + collections::{BinaryHeap, HashMap, VecDeque}, + hash::{BuildHasherDefault, Hasher}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, RwLock, + 
}, +}; + +use dashmap::{mapref::entry::Entry, DashMap}; +use tracing::debug; +use pyo3::prelude::*; + +type NodeRef = Arc; + +/// Shard counts for DashMaps to balance concurrency vs allocation overhead. +/// Default DashMap uses num_cpus * 4 shards (e.g., 256 on 64-core machines). +/// +/// Root node uses higher shard count since ALL requests pass through it. +/// Other nodes use lower count as traffic diverges through the tree. +/// +/// This reduces memory by ~90% vs default while maintaining good concurrency. +const ROOT_SHARD_COUNT: usize = 32; +const NODE_SHARD_COUNT: usize = 8; + +/// Create a children DashMap for non-root nodes +#[inline] +fn new_children_map() -> DashMap { + DashMap::with_hasher_and_shard_amount(CharHasherBuilder::default(), NODE_SHARD_COUNT) +} + +/// Create a tenant access time DashMap for non-root nodes +#[inline] +fn new_tenant_map() -> DashMap { + DashMap::with_shard_amount(NODE_SHARD_COUNT) +} + +/// Interned tenant ID to avoid repeated string allocations. +/// Using Arc allows cheap cloning and comparison. +pub type TenantId = Arc; + +/// Result of a prefix match operation, including char counts to avoid recomputation. +#[pyclass(module = "lightllm.server.httpserver_for_pd_master.pd_selector._pd_tree_rust")] +#[derive(Debug, Clone)] +pub struct PrefixMatchResult { + /// The tenant that owns the matched prefix (zero-copy) + #[pyo3(get)] + pub tenant: String, + /// Number of characters matched + #[pyo3(get)] + pub matched_char_count: usize, + /// Total number of characters in the input text + #[pyo3(get)] + pub input_char_count: usize, +} + +/// A fast identity hasher for single-character keys (used in children DashMap). +/// Since chars have good distribution already, we use identity hashing with mixing. 
+#[derive(Default)] +struct CharHasher(u64); + +impl Hasher for CharHasher { + #[inline(always)] + fn finish(&self) -> u64 { + self.0 + } + + #[inline(always)] + fn write(&mut self, bytes: &[u8]) { + // Fast path for 4-byte (char) writes - avoid loop + if bytes.len() == 4 { + let val = u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); + // Mix with golden ratio for better distribution + self.0 = (val as u64).wrapping_mul(0x9E3779B97F4A7C15); + return; + } + // Fallback for other sizes (shouldn't happen for char keys) + for &byte in bytes { + self.0 = self.0.wrapping_mul(0x100000001b3).wrapping_add(byte as u64); + } + } + + #[inline(always)] + fn write_u32(&mut self, i: u32) { + // Chars are u32 - use golden ratio multiplication for distribution + self.0 = (i as u64).wrapping_mul(0x9E3779B97F4A7C15); + } +} + +type CharHasherBuilder = BuildHasherDefault; + +/// Advance a string slice by N characters, returning the remaining slice. +/// Returns empty string if n >= char count. +/// Optimized: uses direct byte slicing for ASCII, falls back to char_indices for UTF-8. +#[inline] +fn advance_by_chars(s: &str, n: usize) -> &str { + if n == 0 { + return s; + } + if n >= s.len() { + return ""; + } + // Fast path: if first N bytes are all ASCII, we can slice directly + let bytes = s.as_bytes(); + if bytes[..n].is_ascii() { + // Safe: we verified all bytes in [0..n] are ASCII (valid UTF-8 boundary) + return &s[n..]; + } + // Slow path: UTF-8 requires char-by-char traversal + s.char_indices() + .nth(n) + .map(|(idx, _)| &s[idx..]) + .unwrap_or("") +} + +/// Get the first N characters of a string as a new String. +/// More efficient than chars().take(n).collect() for known bounds. 
+#[inline] +fn take_chars(s: &str, n: usize) -> String { + if n == 0 { + return String::new(); + } + s.char_indices() + .nth(n) + .map(|(idx, _)| s[..idx].to_string()) + .unwrap_or_else(|| s.to_string()) +} + +/// Node text with cached character count to avoid repeated O(n) chars().count() calls. +#[derive(Debug)] +struct NodeText { + /// The actual text stored in this node + text: String, + /// Cached character count (UTF-8 chars, not bytes) + char_count: usize, +} + +impl NodeText { + #[inline] + fn new(text: String) -> Self { + let char_count = text.chars().count(); + Self { text, char_count } + } + + #[inline] + fn empty() -> Self { + Self { + text: String::new(), + char_count: 0, + } + } + + #[inline] + fn char_count(&self) -> usize { + self.char_count + } + + #[inline] + fn as_str(&self) -> &str { + &self.text + } + + #[inline] + fn first_char(&self) -> Option { + self.text.chars().next() + } + + /// Split the text at a character boundary, returning the prefix and suffix. + /// This is more efficient than slice_by_chars as it computes both at once. + #[inline] + fn split_at_char(&self, char_idx: usize) -> (NodeText, NodeText) { + if char_idx == 0 { + return (NodeText::empty(), self.clone_text()); + } + if char_idx >= self.char_count { + return (self.clone_text(), NodeText::empty()); + } + + // Find byte index for the character boundary + let byte_idx = self + .text + .char_indices() + .nth(char_idx) + .map(|(i, _)| i) + .unwrap_or(self.text.len()); + + let prefix = NodeText { + text: self.text[..byte_idx].to_string(), + char_count: char_idx, + }; + let suffix = NodeText { + text: self.text[byte_idx..].to_string(), + char_count: self.char_count - char_idx, + }; + (prefix, suffix) + } + + #[inline] + fn clone_text(&self) -> NodeText { + NodeText { + text: self.text.clone(), + char_count: self.char_count, + } + } +} + +impl Clone for NodeText { + fn clone(&self) -> Self { + self.clone_text() + } +} + +/// Global epoch counter for LRU ordering. 
+/// Uses a simple incrementing counter instead of wall clock time. +/// +/// Benefits: +/// - No syscall overhead (vs SystemTime::now()) +/// - Smaller memory footprint (u64 vs u128) +/// - Perfectly monotonic (no clock skew issues) +/// +/// For LRU eviction, relative ordering is all that matters. +static EPOCH_COUNTER: AtomicU64 = AtomicU64::new(0); + +/// Get the next epoch value for LRU timestamp ordering. +/// Uses fetch_add for lock-free, monotonically increasing values. +/// Relaxed ordering is sufficient since we only need eventual consistency +/// for approximate LRU behavior. +#[inline] +fn get_epoch() -> u64 { + EPOCH_COUNTER.fetch_add(1, Ordering::Relaxed) +} + +#[derive(Debug)] +struct Node { + /// Children nodes indexed by first character. + /// Using custom hasher optimized for char keys. + children: DashMap, + /// Node text with cached character count + text: RwLock, + /// Per-tenant last access epoch for LRU ordering. Using TenantId (Arc) for cheap cloning. + tenant_last_access_time: DashMap, + /// Parent pointer for upward traversal during timestamp updates + parent: RwLock>, + /// Cached last-accessed tenant for O(1) lookup during prefix match. + /// Avoids O(shards) DashMap iteration in the common case. + last_tenant: parking_lot::RwLock>, +} + +#[pyclass(module = "lightllm.server.httpserver_for_pd_master.pd_selector._pd_tree_rust")] +#[derive(Debug)] +pub struct Tree { + root: NodeRef, + /// Per-tenant character count for size tracking. Using TenantId for consistency. 
+ pub tenant_char_count: DashMap, +} + +// For the heap + +struct EvictionEntry { + timestamp: u64, + tenant: TenantId, + node: NodeRef, +} + +impl Eq for EvictionEntry {} + +#[allow(clippy::non_canonical_partial_ord_impl)] +impl PartialOrd for EvictionEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.timestamp.cmp(&other.timestamp)) + } +} + +impl Ord for EvictionEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.timestamp.cmp(&other.timestamp) + } +} + +impl PartialEq for EvictionEntry { + fn eq(&self, other: &Self) -> bool { + self.timestamp == other.timestamp + } +} + +// For char operations +// Note that in rust, `.len()` or slice is operated on the "byte" level. It causes issues for UTF-8 characters because one character might use multiple bytes. +// https://en.wikipedia.org/wiki/UTF-8 + +/// Count matching prefix characters between two strings. +/// Returns the number of characters that match from the start. +/// Optimized: uses fast byte comparison for ASCII, falls back to char iteration for UTF-8. +#[inline] +fn shared_prefix_count(a: &str, b: &str) -> usize { + let a_bytes = a.as_bytes(); + let b_bytes = b.as_bytes(); + + // Find common byte prefix length using iterator (potentially SIMD-optimized) + let common_byte_len = a_bytes + .iter() + .zip(b_bytes) + .position(|(&a_byte, &b_byte)| a_byte != b_byte) + .unwrap_or_else(|| a_bytes.len().min(b_bytes.len())); + + // If the common byte prefix is all ASCII, byte count == char count + // Otherwise, fall back to char-by-char comparison for UTF-8 safety + if a_bytes[..common_byte_len].is_ascii() { + common_byte_len + } else { + shared_prefix_count_chars(a, b) + } +} + +/// Fallback char-by-char comparison for strings with non-ASCII characters. 
+#[inline] +fn shared_prefix_count_chars(a: &str, b: &str) -> usize { + a.chars() + .zip(b.chars()) + .take_while(|(a_char, b_char)| a_char == b_char) + .count() +} + +/// Intern a tenant string into an Arc for efficient storage and comparison. +#[inline] +fn intern_tenant(tenant: &str) -> TenantId { + Arc::from(tenant) +} + +impl Default for Tree { + fn default() -> Self { + Self::new() + } +} + +impl Tree { + /* + Thread-safe multi tenant radix tree + + 1. Storing data for multiple tenants (the overlap of multiple radix tree) + 2. Node-level lock to enable concurrent access on nodes + 3. Leaf LRU eviction based on tenant access time + + Optimizations: + - Cached character counts in NodeText to avoid O(n) chars().count() calls + - Interned tenant IDs (Arc) for cheap cloning and comparison + - Batched timestamp updates to reduce syscalls + - Custom hasher for char keys in children DashMap + */ + pub fn new() -> Self { + Tree { + // Root uses higher shard count since ALL requests pass through it + root: Arc::new(Node { + children: DashMap::with_hasher_and_shard_amount( + CharHasherBuilder::default(), + ROOT_SHARD_COUNT, + ), + text: RwLock::new(NodeText::empty()), + tenant_last_access_time: DashMap::with_shard_amount(ROOT_SHARD_COUNT), + parent: RwLock::new(None), + last_tenant: parking_lot::RwLock::new(None), + }), + tenant_char_count: DashMap::with_shard_amount(ROOT_SHARD_COUNT), + } + } + /// Return the list of tenants for which this node is a leaf. + /// A tenant is a leaf at this node if no children have that tenant. 
+ fn leaf_of(node: &NodeRef) -> Vec { + let mut candidates: HashMap = node + .tenant_last_access_time + .iter() + .map(|entry| (Arc::clone(entry.key()), true)) + .collect(); + + for child in node.children.iter() { + for tenant in child.value().tenant_last_access_time.iter() { + // Mark as non-leaf if any child has this tenant + candidates.insert(Arc::clone(tenant.key()), false); + } + } + + candidates + .into_iter() + .filter(|(_, is_leaf)| *is_leaf) + .map(|(tenant, _)| tenant) + .collect() + } + + #[allow(dead_code)] + fn node_to_string(node: &NodeRef, prefix: &str, is_last: bool) -> String { + let mut result = String::new(); + + // Add prefix and branch character + result.push_str(prefix); + result.push_str(if is_last { "└── " } else { "├── " }); + + // Add node text + let node_text = node.text.read().unwrap(); + result.push_str(&format!("'{}' [", node_text.as_str())); + + // Add tenant information with epoch values + let mut tenant_info = Vec::new(); + for entry in node.tenant_last_access_time.iter() { + let tenant_id = entry.key(); + let epoch = entry.value(); + tenant_info.push(format!("{} | epoch:{}", tenant_id, epoch)); + } + + result.push_str(&tenant_info.join(", ")); + result.push_str("]\n"); + + // Process children + let children: Vec<_> = node.children.iter().collect(); + let child_count = children.len(); + + for (i, entry) in children.iter().enumerate() { + let is_last_child = i == child_count - 1; + let new_prefix = format!("{}{}", prefix, if is_last { " " } else { "│ " }); + + result.push_str(&Tree::node_to_string( + entry.value(), + &new_prefix, + is_last_child, + )); + } + + result + } + +} + +#[pymethods] +impl Tree { + + #[new] + pub fn py_new() -> Self { + Self::new() + } + + pub fn insert(&self, text: &str, tenant: &str) { + // Insert text into tree with given tenant + // Use slice-based traversal to avoid Vec allocation + + // Intern the tenant ID once for reuse + let tenant_id = intern_tenant(tenant); + + // Ensure tenant exists at root 
(don't update timestamp - root is never evicted) + self.root + .tenant_last_access_time + .entry(Arc::clone(&tenant_id)) + .or_insert(0); + + self.tenant_char_count + .entry(Arc::clone(&tenant_id)) + .or_insert(0); + + // Track remaining text as a slice - no allocation needed + let mut remaining = text; + let mut prev = Arc::clone(&self.root); + + // Result type to carry state out of the match block + // This allows the entry guard to be dropped before we update prev + enum InsertStep { + Done, + Continue { + next_prev: NodeRef, + advance_chars: usize, + }, + } + + while !remaining.is_empty() { + let first_char = remaining.chars().next().unwrap(); + + // Use entry API for atomic check-and-insert semantics (required for thread safety) + let step = match prev.children.entry(first_char) { + Entry::Vacant(entry) => { + // No match - create new node with remaining text (this is the leaf) + // Compute remaining char count lazily - only here when creating leaf + let remaining_char_count = remaining.chars().count(); + let epoch = get_epoch(); + + let new_node = Arc::new(Node { + children: new_children_map(), + text: RwLock::new(NodeText::new(remaining.to_string())), + tenant_last_access_time: new_tenant_map(), + parent: RwLock::new(Some(Arc::clone(&prev))), + last_tenant: parking_lot::RwLock::new(Some(Arc::clone(&tenant_id))), + }); + + // Attach tenant to the new leaf node with timestamp + self.tenant_char_count + .entry(Arc::clone(&tenant_id)) + .and_modify(|count| *count += remaining_char_count) + .or_insert(remaining_char_count); + new_node + .tenant_last_access_time + .insert(Arc::clone(&tenant_id), epoch); + + entry.insert(new_node); + InsertStep::Done + } + + Entry::Occupied(mut entry) => { + let matched_node = entry.get().clone(); + + let matched_node_text = matched_node.text.read().unwrap(); + let matched_node_text_count = matched_node_text.char_count(); + let matched_node_text_str = matched_node_text.as_str(); + + // Use slice-based comparison - no allocation + 
let shared_count = shared_prefix_count(remaining, matched_node_text_str); + + if shared_count < matched_node_text_count { + // Split the matched node + let (matched_text, contracted_text) = + matched_node_text.split_at_char(shared_count); + let matched_text_count = shared_count; + + // Drop read lock before creating new node + drop(matched_node_text); + + let new_node = Arc::new(Node { + text: RwLock::new(matched_text), + children: new_children_map(), + parent: RwLock::new(Some(Arc::clone(&prev))), + tenant_last_access_time: matched_node.tenant_last_access_time.clone(), + last_tenant: parking_lot::RwLock::new( + matched_node.last_tenant.read().clone(), + ), + }); + + let first_new_char = contracted_text.first_char().unwrap(); + new_node + .children + .insert(first_new_char, Arc::clone(&matched_node)); + + entry.insert(Arc::clone(&new_node)); + + *matched_node.text.write().unwrap() = contracted_text; + *matched_node.parent.write().unwrap() = Some(Arc::clone(&new_node)); + + // Attach tenant to the new split node (intermediate - no timestamp update) + // The cloned DashMap already has the tenant; just ensure char count is correct + if !new_node + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + { + self.tenant_char_count + .entry(Arc::clone(&tenant_id)) + .and_modify(|count| *count += matched_text_count) + .or_insert(matched_text_count); + new_node + .tenant_last_access_time + .insert(Arc::clone(&tenant_id), 0); + } + + InsertStep::Continue { + next_prev: new_node, + advance_chars: shared_count, + } + } else { + // Full match - move to next node (intermediate - no timestamp update) + drop(matched_node_text); + + // Ensure tenant exists at this intermediate node + if !matched_node + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + { + self.tenant_char_count + .entry(Arc::clone(&tenant_id)) + .and_modify(|count| *count += matched_node_text_count) + .or_insert(matched_node_text_count); + matched_node + .tenant_last_access_time + 
.insert(Arc::clone(&tenant_id), 0); + } + + InsertStep::Continue { + next_prev: matched_node, + advance_chars: shared_count, + } + } + } + }; + + // Entry guard is now dropped - safe to update prev + match step { + InsertStep::Done => return, // New leaf created with timestamp, we're done + InsertStep::Continue { + next_prev, + advance_chars, + } => { + prev = next_prev; + remaining = advance_by_chars(remaining, advance_chars); + } + } + } + + // Loop exited normally (remaining empty) - prev is the leaf node + // Update its timestamp for LRU ordering + let epoch = get_epoch(); + prev.tenant_last_access_time + .insert(Arc::clone(&tenant_id), epoch); + } + + /// Performs prefix matching and returns detailed result with char counts. + /// Optimized: no string allocations, deferred char counting. + pub fn prefix_match_with_counts(&self, text: &str) -> PrefixMatchResult { + let mut remaining = text; + let mut matched_chars = 0; + let mut prev = Arc::clone(&self.root); + + while !remaining.is_empty() { + let first_char = remaining.chars().next().unwrap(); + + let child_node = prev.children.get(&first_char).map(|e| e.value().clone()); + + if let Some(matched_node) = child_node { + let matched_text_guard = matched_node.text.read().unwrap(); + let matched_node_text_count = matched_text_guard.char_count(); + + // Use slice-based comparison - no allocation + let shared_count = shared_prefix_count(remaining, matched_text_guard.as_str()); + drop(matched_text_guard); + + if shared_count == matched_node_text_count { + // Full match with current node's text, continue to next node + matched_chars += shared_count; + remaining = advance_by_chars(remaining, shared_count); + prev = matched_node; + } else { + // Partial match - still use this node for tenant selection + matched_chars += shared_count; + prev = matched_node; + break; + } + } else { + // No match found, stop here + break; + } + } + + let curr = prev; + + // Try cached tenant first (O(1)) before falling back to O(shards) 
DashMap iteration. + // The cache is valid if the tenant still exists in tenant_last_access_time. + let tenant: TenantId = { + let cached = curr.last_tenant.read(); + if let Some(ref t) = *cached { + if curr.tenant_last_access_time.contains_key(t.as_ref()) { + Arc::clone(t) + } else { + drop(cached); + // Cache stale, fall back to iteration and update cache + let t = curr + .tenant_last_access_time + .iter() + .next() + .map(|kv| Arc::clone(kv.key())) + .unwrap_or_else(|| Arc::from("empty")); + *curr.last_tenant.write() = Some(Arc::clone(&t)); + t + } + } else { + drop(cached); + // No cache, iterate and populate cache + let t = curr + .tenant_last_access_time + .iter() + .next() + .map(|kv| Arc::clone(kv.key())) + .unwrap_or_else(|| Arc::from("empty")); + *curr.last_tenant.write() = Some(Arc::clone(&t)); + t + } + }; + + // Update timestamp probabilistically (1 in 8 matches) to reduce DashMap contention. + // LRU eviction doesn't need perfect accuracy - approximate timestamps suffice. + let epoch = get_epoch(); + if epoch & 0x7 == 0 { + curr.tenant_last_access_time + .insert(Arc::clone(&tenant), epoch); + } + + // Compute input char count directly from input text. + // This is equivalent to matched_chars + remaining.chars().count() but avoids + // needing to track remaining precisely through the traversal. + let input_char_count = text.chars().count(); + + PrefixMatchResult { + tenant: tenant.to_string(), + matched_char_count: matched_chars, + input_char_count, + } + } + + /// Legacy prefix_match API for backward compatibility. + /// Note: This computes matched_text which has allocation overhead. 
+ pub fn prefix_match(&self, text: &str) -> (String, String) { + let result = self.prefix_match_with_counts(text); + let matched_text = take_chars(text, result.matched_char_count); + (matched_text, result.tenant.to_string()) + } + + #[allow(dead_code)] + pub fn prefix_match_tenant(&self, text: &str, tenant: &str) -> String { + // Use slice-based traversal - no Vec allocation + + // Intern tenant ID once for efficient lookups + let tenant_id = intern_tenant(tenant); + + let mut remaining = text; + let mut matched_chars = 0; + let mut prev = Arc::clone(&self.root); + + while !remaining.is_empty() { + let first_char = remaining.chars().next().unwrap(); + + let child_node = prev.children.get(&first_char).map(|e| e.value().clone()); + + if let Some(matched_node) = child_node { + // Only continue matching if this node belongs to the specified tenant + if !matched_node + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + { + break; + } + + let matched_text_guard = matched_node.text.read().unwrap(); + let matched_node_text_count = matched_text_guard.char_count(); + + // Use slice-based comparison - no allocation + let shared_count = shared_prefix_count(remaining, matched_text_guard.as_str()); + drop(matched_text_guard); + + if shared_count == matched_node_text_count { + // Full match with current node's text, continue to next node + matched_chars += shared_count; + remaining = advance_by_chars(remaining, shared_count); + prev = matched_node; + } else { + // Partial match - still use this node for timestamp update + matched_chars += shared_count; + prev = matched_node; + break; + } + } else { + // No match found, stop here + break; + } + } + + let curr = prev; + + // Only update timestamp if we found a match for the specified tenant. + // Update matched node only - ancestor propagation is unnecessary. 
+ if curr + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + { + let epoch = get_epoch(); + curr.tenant_last_access_time + .insert(Arc::clone(&tenant_id), epoch); + } + + // Build result from original input using char count + take_chars(text, matched_chars) + } + + pub fn evict_tenant_by_size(&self, max_size: usize) { + // Calculate used size and collect leaves + let mut stack = vec![Arc::clone(&self.root)]; + let mut pq = BinaryHeap::new(); + + while let Some(curr) = stack.pop() { + for child in curr.children.iter() { + stack.push(Arc::clone(child.value())); + } + + // Add leaves to priority queue + for tenant in Tree::leaf_of(&curr) { + if let Some(timestamp) = curr.tenant_last_access_time.get(tenant.as_ref()) { + pq.push(Reverse(EvictionEntry { + timestamp: *timestamp, + tenant: Arc::clone(&tenant), + node: Arc::clone(&curr), + })); + } + } + } + + debug!("Before eviction - Used size per tenant:"); + for entry in self.tenant_char_count.iter() { + debug!("Tenant: {}, Size: {}", entry.key(), entry.value()); + } + + // Process eviction + while let Some(Reverse(entry)) = pq.pop() { + let EvictionEntry { tenant, node, .. 
} = entry; + + if let Some(used_size) = self.tenant_char_count.get(tenant.as_ref()) { + if *used_size <= max_size { + continue; + } + } + + // Verify this node is still a leaf for this tenant (may have changed) + // A node is a leaf for a tenant if no children have that tenant + let is_still_leaf = node.tenant_last_access_time.contains_key(tenant.as_ref()) + && !node.children.iter().any(|child| { + child + .value() + .tenant_last_access_time + .contains_key(tenant.as_ref()) + }); + if !is_still_leaf { + continue; + } + + // Decrement when removing tenant from node + let node_len = node.text.read().unwrap().char_count(); + self.tenant_char_count + .entry(Arc::clone(&tenant)) + .and_modify(|count| { + *count = count.saturating_sub(node_len); + }); + + // Remove tenant from node + node.tenant_last_access_time.remove(tenant.as_ref()); + + // Get parent reference outside of the borrow scope + let parent_opt = node.parent.read().unwrap().clone(); + + // Remove empty nodes + if node.children.is_empty() && node.tenant_last_access_time.is_empty() { + if let Some(ref parent) = parent_opt { + if let Some(fc) = node.text.read().unwrap().first_char() { + parent.children.remove(&fc); + } + } + } + + // If parent has this tenant and no other children have it, + // parent becomes a new leaf - add to priority queue + if let Some(ref parent) = parent_opt { + if parent.tenant_last_access_time.contains_key(tenant.as_ref()) { + let has_child_with_tenant = parent.children.iter().any(|child| { + child + .value() + .tenant_last_access_time + .contains_key(tenant.as_ref()) + }); + + if !has_child_with_tenant { + // Add parent to priority queue as new leaf + if let Some(timestamp) = parent.tenant_last_access_time.get(tenant.as_ref()) + { + pq.push(Reverse(EvictionEntry { + timestamp: *timestamp, + tenant: Arc::clone(&tenant), + node: Arc::clone(parent), + })); + } + } + } + } + } + + debug!("After eviction - Used size per tenant:"); + for entry in self.tenant_char_count.iter() { + 
debug!("Tenant: {}, Size: {}", entry.key(), entry.value()); + } + } + + pub fn remove_tenant(&self, tenant: &str) { + // Intern tenant ID once for efficient lookups + let tenant_id = intern_tenant(tenant); + + // 1. Find all the leaves for the tenant + // A leaf is a node that has this tenant but no children have it + let mut stack = vec![Arc::clone(&self.root)]; + let mut queue = VecDeque::new(); + + while let Some(curr) = stack.pop() { + for child in curr.children.iter() { + stack.push(Arc::clone(child.value())); + } + + // Check if this node is a leaf for the tenant + if curr + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + { + let has_child_with_tenant = curr.children.iter().any(|child| { + child + .value() + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + }); + if !has_child_with_tenant { + queue.push_back(Arc::clone(&curr)); + } + } + } + + // 2. Start from the leaves and traverse up to the root, removing the tenant from each node + while let Some(curr) = queue.pop_front() { + // Remove tenant from node + curr.tenant_last_access_time.remove(tenant_id.as_ref()); + + // Get parent reference outside of the borrow scope + let parent_opt = curr.parent.read().unwrap().clone(); + + // Remove empty nodes + if curr.children.is_empty() && curr.tenant_last_access_time.is_empty() { + if let Some(ref parent) = parent_opt { + if let Some(fc) = curr.text.read().unwrap().first_char() { + parent.children.remove(&fc); + } + } + } + + // If parent has this tenant and no other children have it, + // parent becomes a new leaf - add to queue + if let Some(ref parent) = parent_opt { + if parent + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + { + let has_child_with_tenant = parent.children.iter().any(|child| { + child + .value() + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + }); + + if !has_child_with_tenant { + queue.push_back(Arc::clone(parent)); + } + } + } + } + + // 3. 
Remove the tenant from the tenant_char_count map + self.tenant_char_count.remove(tenant_id.as_ref()); + } + + #[allow(dead_code)] + pub fn get_tenant_char_count(&self) -> HashMap { + self.tenant_char_count + .iter() + .map(|entry| (entry.key().to_string(), *entry.value())) + .collect() + } + + #[allow(dead_code)] + pub fn get_used_size_per_tenant(&self) -> HashMap { + // perform a DFS to traverse all nodes and calculate the total size used by each tenant + + let mut used_size_per_tenant: HashMap = HashMap::new(); + let mut stack = vec![Arc::clone(&self.root)]; + + while let Some(curr) = stack.pop() { + // Use cached char count instead of chars().count() + let text_count = curr.text.read().unwrap().char_count(); + + for tenant in curr.tenant_last_access_time.iter() { + let size = used_size_per_tenant + .entry(tenant.key().to_string()) + .or_insert(0); + *size += text_count; + } + + for child in curr.children.iter() { + stack.push(Arc::clone(child.value())); + } + } + + used_size_per_tenant + } + + + #[allow(dead_code)] + pub fn pretty_print(&self) { + if self.root.children.is_empty() { + return; + } + + let mut result = String::new(); + let children: Vec<_> = self.root.children.iter().collect(); + let child_count = children.len(); + + for (i, entry) in children.iter().enumerate() { + let is_last = i == child_count - 1; + result.push_str(&Tree::node_to_string(entry.value(), "", is_last)); + } + + println!("{result}"); + } +} + +#[pymodule] +fn _pd_tree_rust(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + Ok(()) +} + + +// Unit tests +#[cfg(test)] +mod tests { + use std::{ + thread, + time::{Duration, Instant}, + }; + + use rand::{ + distr::{Alphanumeric, SampleString}, + rng as thread_rng, Rng, + }; + + use super::*; + + /// Helper to convert tenant_char_count to HashMap for comparison + fn get_maintained_counts(tree: &Tree) -> HashMap { + tree.tenant_char_count + .iter() + .map(|entry| (entry.key().to_string(), 
*entry.value())) + .collect() + } + + #[test] + fn test_tenant_char_count() { + let tree = Tree::new(); + + tree.insert("apple", "tenant1"); + tree.insert("apricot", "tenant1"); + tree.insert("banana", "tenant1"); + tree.insert("amplify", "tenant2"); + tree.insert("application", "tenant2"); + + let computed_sizes = tree.get_used_size_per_tenant(); + let maintained_counts = get_maintained_counts(&tree); + + println!("Phase 1 - Maintained vs Computed counts:"); + println!( + "Maintained: {:?}\nComputed: {:?}", + maintained_counts, computed_sizes + ); + assert_eq!( + maintained_counts, computed_sizes, + "Phase 1: Initial insertions" + ); + + tree.insert("apartment", "tenant1"); + tree.insert("appetite", "tenant2"); + tree.insert("ball", "tenant1"); + tree.insert("box", "tenant2"); + + let computed_sizes = tree.get_used_size_per_tenant(); + let maintained_counts = get_maintained_counts(&tree); + + println!("Phase 2 - Maintained vs Computed counts:"); + println!( + "Maintained: {:?}\nComputed: {:?}", + maintained_counts, computed_sizes + ); + assert_eq!( + maintained_counts, computed_sizes, + "Phase 2: Additional insertions" + ); + + tree.insert("zebra", "tenant1"); + tree.insert("zebra", "tenant2"); + tree.insert("zero", "tenant1"); + tree.insert("zero", "tenant2"); + + let computed_sizes = tree.get_used_size_per_tenant(); + let maintained_counts = get_maintained_counts(&tree); + + println!("Phase 3 - Maintained vs Computed counts:"); + println!( + "Maintained: {:?}\nComputed: {:?}", + maintained_counts, computed_sizes + ); + assert_eq!( + maintained_counts, computed_sizes, + "Phase 3: Overlapping insertions" + ); + + tree.evict_tenant_by_size(10); + + let computed_sizes = tree.get_used_size_per_tenant(); + let maintained_counts = get_maintained_counts(&tree); + + println!("Phase 4 - Maintained vs Computed counts:"); + println!( + "Maintained: {:?}\nComputed: {:?}", + maintained_counts, computed_sizes + ); + assert_eq!(maintained_counts, computed_sizes, "Phase 4: After 
eviction"); + } + + fn random_string(len: usize) -> String { + Alphanumeric.sample_string(&mut thread_rng(), len) + } + + #[test] + fn test_cold_start() { + let tree = Tree::new(); + + let (matched_text, tenant) = tree.prefix_match("hello"); + + assert_eq!(matched_text, ""); + assert_eq!(tenant, "empty"); + } + + #[test] + fn test_exact_match_seq() { + let tree = Tree::new(); + tree.insert("hello", "tenant1"); + tree.pretty_print(); + tree.insert("apple", "tenant2"); + tree.pretty_print(); + tree.insert("banana", "tenant3"); + tree.pretty_print(); + + let (matched_text, tenant) = tree.prefix_match("hello"); + assert_eq!(matched_text, "hello"); + assert_eq!(tenant, "tenant1"); + + let (matched_text, tenant) = tree.prefix_match("apple"); + assert_eq!(matched_text, "apple"); + assert_eq!(tenant, "tenant2"); + + let (matched_text, tenant) = tree.prefix_match("banana"); + assert_eq!(matched_text, "banana"); + assert_eq!(tenant, "tenant3"); + } + + #[test] + fn test_exact_match_concurrent() { + let tree = Arc::new(Tree::new()); + + // spawn 3 threads for insert + let tree_clone = Arc::clone(&tree); + + let texts = ["hello", "apple", "banana"]; + let tenants = ["tenant1", "tenant2", "tenant3"]; + + let mut handles = vec![]; + + for i in 0..3 { + let tree_clone = Arc::clone(&tree_clone); + let text = texts[i]; + let tenant = tenants[i]; + + let handle = thread::spawn(move || { + tree_clone.insert(text, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + + // spawn 3 threads for match + let mut handles = vec![]; + + let tree_clone = Arc::clone(&tree); + + for i in 0..3 { + let tree_clone = Arc::clone(&tree_clone); + let text = texts[i]; + let tenant = tenants[i]; + + let handle = thread::spawn(move || { + let (matched_text, matched_tenant) = tree_clone.prefix_match(text); + assert_eq!(matched_text, text); + assert_eq!(matched_tenant, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in 
handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_partial_match_concurrent() { + let tree = Arc::new(Tree::new()); + + // spawn 3 threads for insert + let tree_clone = Arc::clone(&tree); + + static TEXTS: [&str; 3] = ["apple", "apabc", "acbdeds"]; + + let mut handles = vec![]; + + for text in TEXTS.iter() { + let tree_clone = Arc::clone(&tree_clone); + let tenant = "tenant0"; + + let handle = thread::spawn(move || { + tree_clone.insert(text, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + + // spawn 3 threads for match + let mut handles = vec![]; + + let tree_clone = Arc::clone(&tree); + + for text in TEXTS.iter() { + let tree_clone = Arc::clone(&tree_clone); + let tenant = "tenant0"; + + let handle = thread::spawn(move || { + let (matched_text, matched_tenant) = tree_clone.prefix_match(text); + assert_eq!(matched_text, *text); + assert_eq!(matched_tenant, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_group_prefix_insert_match_concurrent() { + static PREFIXES: [&str; 4] = [ + "Clock strikes midnight, I'm still wide awake", + "Got dreams bigger than these city lights", + "Time waits for no one, gotta make my move", + "Started from the bottom, that's no metaphor", + ]; + let suffixes = [ + "Got too much to prove, ain't got time to lose", + "History in the making, yeah, you can't erase this", + ]; + let tree = Arc::new(Tree::new()); + + let mut handles = vec![]; + + for (i, prefix) in PREFIXES.iter().enumerate() { + for suffix in suffixes.iter() { + let tree_clone = Arc::clone(&tree); + let text = format!("{} {}", prefix, suffix); + let tenant = format!("tenant{}", i); + + let handle = thread::spawn(move || { + tree_clone.insert(&text, &tenant); + }); + + handles.push(handle); + } + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + + tree.pretty_print(); + + // check 
matching using multi threads + let mut handles = vec![]; + + for (i, prefix) in PREFIXES.iter().enumerate() { + let tree_clone = Arc::clone(&tree); + + let handle = thread::spawn(move || { + let (matched_text, matched_tenant) = tree_clone.prefix_match(prefix); + let tenant = format!("tenant{}", i); + assert_eq!(matched_text, *prefix); + assert_eq!(matched_tenant, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_mixed_concurrent_insert_match() { + // ensure it does not deadlock instead of doing correctness check + + static PREFIXES: [&str; 4] = [ + "Clock strikes midnight, I'm still wide awake", + "Got dreams bigger than these city lights", + "Time waits for no one, gotta make my move", + "Started from the bottom, that's no metaphor", + ]; + let suffixes = [ + "Got too much to prove, ain't got time to lose", + "History in the making, yeah, you can't erase this", + ]; + let tree = Arc::new(Tree::new()); + + let mut handles = vec![]; + + for (i, prefix) in PREFIXES.iter().enumerate() { + for suffix in suffixes.iter() { + let tree_clone = Arc::clone(&tree); + let text = format!("{} {}", prefix, suffix); + let tenant = format!("tenant{}", i); + + let handle = thread::spawn(move || { + tree_clone.insert(&text, &tenant); + }); + + handles.push(handle); + } + } + + // check matching using multi threads + for prefix in PREFIXES.iter() { + let tree_clone = Arc::clone(&tree); + + let handle = thread::spawn(move || { + let (_matched_text, _matched_tenant) = tree_clone.prefix_match(prefix); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_utf8_split_seq() { + // The string should be indexed and split by a utf-8 value basis instead of byte basis + // use .chars() to get the iterator of the utf-8 value + let tree = Arc::new(Tree::new()); + + static TEST_PAIRS: [(&str, &str); 3] = [ + ("你好嗎", "tenant1"), + 
("你好喔", "tenant2"), + ("你心情好嗎", "tenant3"), + ]; + + // Insert sequentially + for (text, tenant) in TEST_PAIRS.iter() { + tree.insert(text, tenant); + } + + tree.pretty_print(); + + for (text, tenant) in TEST_PAIRS.iter() { + let (matched_text, matched_tenant) = tree.prefix_match(text); + assert_eq!(matched_text, *text); + assert_eq!(matched_tenant, *tenant); + } + } + + #[test] + fn test_utf8_split_concurrent() { + let tree = Arc::new(Tree::new()); + + static TEST_PAIRS: [(&str, &str); 3] = [ + ("你好嗎", "tenant1"), + ("你好喔", "tenant2"), + ("你心情好嗎", "tenant3"), + ]; + + // Create multiple threads for insertion + let mut handles = vec![]; + + for (text, tenant) in TEST_PAIRS.iter() { + let tree_clone = Arc::clone(&tree); + + let handle = thread::spawn(move || { + tree_clone.insert(text, tenant); + }); + + handles.push(handle); + } + + // Wait for all insertions to complete + for handle in handles { + handle.join().unwrap(); + } + + tree.pretty_print(); + + // Create multiple threads for matching + let mut handles = vec![]; + + for (text, tenant) in TEST_PAIRS.iter() { + let tree_clone = Arc::clone(&tree); + + let handle = thread::spawn(move || { + let (matched_text, matched_tenant) = tree_clone.prefix_match(text); + assert_eq!(matched_text, *text); + assert_eq!(matched_tenant, *tenant); + }); + + handles.push(handle); + } + + // Wait for all matches to complete + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_simple_eviction() { + let tree = Tree::new(); + let max_size = 5; + + // Insert strings for both tenants + tree.insert("hello", "tenant1"); // size 5 + + tree.insert("hello", "tenant2"); // size 5 + thread::sleep(Duration::from_millis(10)); + tree.insert("world", "tenant2"); // size 5, total for tenant2 = 10 + + tree.pretty_print(); + + let sizes_before = tree.get_used_size_per_tenant(); + assert_eq!(sizes_before.get("tenant1").unwrap(), &5); // "hello" = 5 + assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + 
"world" = 10 + + // Evict - should remove "hello" from tenant2 as it's the oldest + tree.evict_tenant_by_size(max_size); + + tree.pretty_print(); + + let sizes_after = tree.get_used_size_per_tenant(); + assert_eq!(sizes_after.get("tenant1").unwrap(), &5); // Should be unchanged + assert_eq!(sizes_after.get("tenant2").unwrap(), &5); // Only "world" remains + + let (matched, tenant) = tree.prefix_match("world"); + assert_eq!(matched, "world"); + assert_eq!(tenant, "tenant2"); + } + + #[test] + fn test_advanced_eviction() { + let tree = Tree::new(); + + // Set limits for each tenant + let max_size: usize = 100; + + // Define prefixes + let prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]; + + // Insert strings with shared prefixes + for _i in 0..100 { + for (j, prefix) in prefixes.iter().enumerate() { + let random_suffix = random_string(10); + let text = format!("{}{}", prefix, random_suffix); + let tenant = format!("tenant{}", j + 1); + tree.insert(&text, &tenant); + } + } + + // Perform eviction + tree.evict_tenant_by_size(max_size); + + // Check sizes after eviction + let sizes_after = tree.get_used_size_per_tenant(); + for (tenant, &size) in sizes_after.iter() { + assert!( + size <= max_size, + "Tenant {} exceeds size limit. 
Current size: {}, Limit: {}", + tenant, + size, + max_size + ); + } + } + + #[test] + fn test_concurrent_operations_with_eviction() { + // Ensure eviction works fine with concurrent insert and match operations for a given period + + let tree = Arc::new(Tree::new()); + let mut handles = vec![]; + let test_duration = Duration::from_secs(10); + let start_time = Instant::now(); + let max_size = 100; // Single max size for all tenants + + // Spawn eviction thread + { + let tree = Arc::clone(&tree); + let handle = thread::spawn(move || { + while start_time.elapsed() < test_duration { + // Run eviction + tree.evict_tenant_by_size(max_size); + + // Sleep for 5 seconds + thread::sleep(Duration::from_secs(5)); + } + }); + handles.push(handle); + } + + // Spawn 4 worker threads + for thread_id in 0..4 { + let tree = Arc::clone(&tree); + let handle = thread::spawn(move || { + let mut rng = rand::rng(); + let tenant = format!("tenant{}", thread_id + 1); + let prefix = format!("prefix{}", thread_id); + + while start_time.elapsed() < test_duration { + // Random decision: match or insert (70% match, 30% insert) + if rng.random_bool(0.7) { + // Perform match operation + let random_len = rng.random_range(3..10); + let search_str = format!("{}{}", prefix, random_string(random_len)); + let (_matched, _) = tree.prefix_match(&search_str); + } else { + // Perform insert operation + let random_len = rng.random_range(5..15); + let insert_str = format!("{}{}", prefix, random_string(random_len)); + tree.insert(&insert_str, &tenant); + // println!("Thread {} inserted: {}", thread_id, insert_str); + } + + // Small random sleep to vary timing + thread::sleep(Duration::from_millis(rng.random_range(10..100))); + } + }); + handles.push(handle); + } + + // Wait for all threads to complete + for handle in handles { + handle.join().unwrap(); + } + + // final eviction + tree.evict_tenant_by_size(max_size); + + // Final size check + let final_sizes = tree.get_used_size_per_tenant(); + println!("Final 
sizes after test completion: {:?}", final_sizes); + + for (_, &size) in final_sizes.iter() { + assert!( + size <= max_size, + "Tenant exceeds size limit. Final size: {}, Limit: {}", + size, + max_size + ); + } + } + + #[test] + fn test_leaf_of() { + let tree = Tree::new(); + + // Helper to convert leaves to owned Strings so they can be compared and searched easily + let leaves_as_strings = + |leaves: &[TenantId]| -> Vec<String> { leaves.iter().map(|t| t.to_string()).collect() }; + + // Single node + tree.insert("hello", "tenant1"); + let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap()); + assert_eq!(leaves_as_strings(&leaves), vec!["tenant1"]); + + // Node with multiple tenants + tree.insert("hello", "tenant2"); + let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap()); + let leaves_str = leaves_as_strings(&leaves); + assert_eq!(leaves_str.len(), 2); + assert!(leaves_str.contains(&"tenant1".to_string())); + assert!(leaves_str.contains(&"tenant2".to_string())); + + // Non-leaf node: after "hi" is inserted the node under 'h' is expected to report no leaf tenants + tree.insert("hi", "tenant1"); + let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap()); + assert!(leaves.is_empty()); + } + + #[test] + fn test_get_used_size_per_tenant() { + let tree = Tree::new(); + + // Single tenant + tree.insert("hello", "tenant1"); + tree.insert("world", "tenant1"); + let sizes = tree.get_used_size_per_tenant(); + + tree.pretty_print(); + println!("{:?}", sizes); + assert_eq!(sizes.get("tenant1").unwrap(), &10); // "hello" + "world" + + // Multiple tenants sharing nodes + tree.insert("hello", "tenant2"); + tree.insert("help", "tenant2"); + let sizes = tree.get_used_size_per_tenant(); + + tree.pretty_print(); + println!("{:?}", sizes); + assert_eq!(sizes.get("tenant1").unwrap(), &10); + assert_eq!(sizes.get("tenant2").unwrap(), &6); // "hello" + "p" + + // UTF-8 characters + tree.insert("你好", "tenant3"); + let sizes = tree.get_used_size_per_tenant(); + tree.pretty_print(); + println!("{:?}", sizes); + assert_eq!(sizes.get("tenant3").unwrap(), &2); // 2 Chinese
characters + + tree.pretty_print(); + } + + #[test] + fn test_prefix_match_tenant() { + let tree = Tree::new(); + + // Insert overlapping prefixes for different tenants + tree.insert("hello", "tenant1"); // tenant1: hello + tree.insert("hello", "tenant2"); // tenant2: hello + tree.insert("hello world", "tenant2"); // tenant2: hello -> world + tree.insert("help", "tenant1"); // tenant1: hel -> p + tree.insert("helicopter", "tenant2"); // tenant2: hel -> icopter + + assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "hello"); // Full match for tenant1 + assert_eq!(tree.prefix_match_tenant("help", "tenant1"), "help"); // Exclusive to tenant1 + assert_eq!(tree.prefix_match_tenant("hel", "tenant1"), "hel"); // Shared prefix + assert_eq!(tree.prefix_match_tenant("hello world", "tenant1"), "hello"); // Should stop at tenant1's boundary + assert_eq!(tree.prefix_match_tenant("helicopter", "tenant1"), "hel"); // Should stop at tenant1's boundary + + assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); // Full match for tenant2 + assert_eq!( + tree.prefix_match_tenant("hello world", "tenant2"), + "hello world" + ); // Exclusive to tenant2 + assert_eq!( + tree.prefix_match_tenant("helicopter", "tenant2"), + "helicopter" + ); // Exclusive to tenant2 + assert_eq!(tree.prefix_match_tenant("hel", "tenant2"), "hel"); // Shared prefix + assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "hel"); // Should stop at tenant2's boundary + + assert_eq!(tree.prefix_match_tenant("hello", "tenant3"), ""); // Non-existent tenant + assert_eq!(tree.prefix_match_tenant("help", "tenant3"), ""); // Non-existent tenant + } + + #[test] + fn test_simple_tenant_eviction() { + let tree = Tree::new(); + + // Insert data for multiple tenants + tree.insert("hello", "tenant1"); + tree.insert("world", "tenant1"); + tree.insert("hello", "tenant2"); + tree.insert("help", "tenant2"); + + let initial_sizes = tree.get_used_size_per_tenant(); + 
assert_eq!(initial_sizes.get("tenant1").unwrap(), &10); // "hello" + "world" + assert_eq!(initial_sizes.get("tenant2").unwrap(), &6); // "hello" + "p" + + // Evict tenant1 + tree.remove_tenant("tenant1"); + + let final_sizes = tree.get_used_size_per_tenant(); + assert!( + !final_sizes.contains_key("tenant1"), + "tenant1 should be completely removed" + ); + assert_eq!( + final_sizes.get("tenant2").unwrap(), + &6, + "tenant2 should be unaffected" + ); + + assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), ""); + assert_eq!(tree.prefix_match_tenant("world", "tenant1"), ""); + + assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); + assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "help"); + } + + #[test] + fn test_complex_tenant_eviction() { + let tree = Tree::new(); + + // Create a more complex tree structure with shared prefixes + tree.insert("apple", "tenant1"); + tree.insert("application", "tenant1"); + tree.insert("apple", "tenant2"); + tree.insert("appetite", "tenant2"); + tree.insert("banana", "tenant1"); + tree.insert("banana", "tenant2"); + tree.insert("ball", "tenant2"); + + let initial_sizes = tree.get_used_size_per_tenant(); + println!("Initial sizes: {:?}", initial_sizes); + tree.pretty_print(); + + // Evict tenant1 + tree.remove_tenant("tenant1"); + + let final_sizes = tree.get_used_size_per_tenant(); + println!("Final sizes: {:?}", final_sizes); + tree.pretty_print(); + + assert!( + !final_sizes.contains_key("tenant1"), + "tenant1 should be completely removed" + ); + + assert_eq!(tree.prefix_match_tenant("apple", "tenant1"), ""); + assert_eq!(tree.prefix_match_tenant("application", "tenant1"), ""); + assert_eq!(tree.prefix_match_tenant("banana", "tenant1"), ""); + + assert_eq!(tree.prefix_match_tenant("apple", "tenant2"), "apple"); + assert_eq!(tree.prefix_match_tenant("appetite", "tenant2"), "appetite"); + assert_eq!(tree.prefix_match_tenant("banana", "tenant2"), "banana"); + assert_eq!(tree.prefix_match_tenant("ball", 
"tenant2"), "ball"); + + let tenant2_size = final_sizes.get("tenant2").unwrap(); + assert_eq!(tenant2_size, &(5 + 5 + 6 + 2)); // "apple" + "etite" + "banana" + "ll" + } + + // ==================== Edge Case Tests ==================== + + #[test] + fn test_empty_string_input() { + let tree = Tree::new(); + + // Insert empty string + tree.insert("", "tenant1"); + + // Match empty string + let (matched, tenant) = tree.prefix_match(""); + assert_eq!(matched, ""); + assert_eq!(tenant, "tenant1"); + + // Insert non-empty, then match empty + tree.insert("hello", "tenant2"); + let (matched, tenant) = tree.prefix_match(""); + assert_eq!(matched, ""); + assert_eq!(tenant, "tenant1"); + } + + #[test] + fn test_single_character_operations() { + let tree = Tree::new(); + + // Insert single characters + tree.insert("a", "tenant1"); + tree.insert("b", "tenant2"); + tree.insert("c", "tenant1"); + + let (matched, tenant) = tree.prefix_match("a"); + assert_eq!(matched, "a"); + assert_eq!(tenant, "tenant1"); + + let (matched, tenant) = tree.prefix_match("b"); + assert_eq!(matched, "b"); + assert_eq!(tenant, "tenant2"); + + // Match with longer string starting with single char + let (matched, tenant) = tree.prefix_match("abc"); + assert_eq!(matched, "a"); + assert_eq!(tenant, "tenant1"); + } + + #[test] + fn test_prefix_is_subset_of_existing() { + let tree = Tree::new(); + + // Insert longer string first + tree.insert("application", "tenant1"); + + // Now insert prefix of existing + tree.insert("app", "tenant2"); + + // Match the prefix - both tenants own "app" node + let (matched, tenant) = tree.prefix_match("app"); + assert_eq!(matched, "app"); + assert!(tenant == "tenant1" || tenant == "tenant2"); + + // Match longer string + let (matched, tenant) = tree.prefix_match("application"); + assert_eq!(matched, "application"); + assert_eq!(tenant, "tenant1"); + + // Match "apple" - matches "app" + "l" from the child node = "appl" + // Then 'e' doesn't match 'i' in the remaining suffix, 
so stops at 4 chars + let (matched, _tenant) = tree.prefix_match("apple"); + assert_eq!(matched, "appl"); + } + + #[test] + fn test_existing_is_prefix_of_new() { + let tree = Tree::new(); + + // Insert shorter string first + tree.insert("app", "tenant1"); + + // Now insert longer string with same prefix + tree.insert("application", "tenant2"); + + let (matched, tenant) = tree.prefix_match("app"); + assert_eq!(matched, "app"); + assert!(tenant == "tenant1" || tenant == "tenant2"); + + let (matched, tenant) = tree.prefix_match("application"); + assert_eq!(matched, "application"); + assert_eq!(tenant, "tenant2"); + + // "applesauce" matches "app" + "l" from the child node = "appl" + // Then 'e' in "esauce" doesn't match 'i' in the suffix, so matching stops + let (matched, _tenant) = tree.prefix_match("applesauce"); + assert_eq!(matched, "appl"); + } + + // ==================== prefix_match_with_counts Tests ==================== + + #[test] + fn test_prefix_match_with_counts_accuracy() { + let tree = Tree::new(); + + tree.insert("hello world", "tenant1"); + + // Exact match + let result = tree.prefix_match_with_counts("hello world"); + assert_eq!(result.matched_char_count, 11); + assert_eq!(result.input_char_count, 11); + assert_eq!(&*result.tenant, "tenant1"); + + // Partial match + let result = tree.prefix_match_with_counts("hello"); + assert_eq!(result.matched_char_count, 5); + assert_eq!(result.input_char_count, 5); + + // Extended match + let result = tree.prefix_match_with_counts("hello world and more"); + assert_eq!(result.matched_char_count, 11); + assert_eq!(result.input_char_count, 20); + + // No match + let result = tree.prefix_match_with_counts("goodbye"); + assert_eq!(result.matched_char_count, 0); + assert_eq!(result.input_char_count, 7); + } + + #[test] + fn test_prefix_match_with_counts_utf8() { + let tree = Tree::new(); + + // UTF-8 string: 5 characters, more bytes + tree.insert("你好世界呀", "tenant1"); + + let result = 
tree.prefix_match_with_counts("你好世界呀"); + assert_eq!(result.matched_char_count, 5); + assert_eq!(result.input_char_count, 5); + + let result = tree.prefix_match_with_counts("你好"); + assert_eq!(result.matched_char_count, 2); + assert_eq!(result.input_char_count, 2); + + // Mixed ASCII and UTF-8 + tree.insert("hello你好", "tenant2"); + let result = tree.prefix_match_with_counts("hello你好世界"); + assert_eq!(result.matched_char_count, 7); // "hello你好" = 7 chars + assert_eq!(result.input_char_count, 9); // "hello你好世界" = 9 chars + } + + // ==================== Node Splitting Edge Cases ==================== + + #[test] + fn test_split_at_first_character() { + let tree = Tree::new(); + + // Insert "abc" + tree.insert("abc", "tenant1"); + + // Insert "aXX" - should split at first char + tree.insert("aXX", "tenant2"); + + let (matched, tenant) = tree.prefix_match("abc"); + assert_eq!(matched, "abc"); + assert_eq!(tenant, "tenant1"); + + let (matched, tenant) = tree.prefix_match("aXX"); + assert_eq!(matched, "aXX"); + assert_eq!(tenant, "tenant2"); + + let (matched, _) = tree.prefix_match("a"); + assert_eq!(matched, "a"); + } + + #[test] + fn test_split_at_last_character() { + let tree = Tree::new(); + + // Insert "abcd" + tree.insert("abcd", "tenant1"); + + // Insert "abcX" - should split at last char of shared prefix + tree.insert("abcX", "tenant2"); + + let (matched, tenant) = tree.prefix_match("abcd"); + assert_eq!(matched, "abcd"); + assert_eq!(tenant, "tenant1"); + + let (matched, tenant) = tree.prefix_match("abcX"); + assert_eq!(matched, "abcX"); + assert_eq!(tenant, "tenant2"); + + let (matched, _) = tree.prefix_match("abc"); + assert_eq!(matched, "abc"); + } + + #[test] + fn test_multiple_splits_same_path() { + let tree = Tree::new(); + + // Create a chain of splits + tree.insert("abcdefgh", "tenant1"); + tree.insert("abcdef", "tenant2"); + tree.insert("abcd", "tenant3"); + tree.insert("ab", "tenant4"); + + // Verify all paths work + 
assert_eq!(tree.prefix_match("abcdefgh").0, "abcdefgh"); + assert_eq!(tree.prefix_match("abcdef").0, "abcdef"); + assert_eq!(tree.prefix_match("abcd").0, "abcd"); + assert_eq!(tree.prefix_match("ab").0, "ab"); + assert_eq!(tree.prefix_match("a").0, "a"); + } + + // ==================== High Contention Stress Tests ==================== + + #[test] + fn test_high_contention_same_prefix() { + let tree = Arc::new(Tree::new()); + let num_threads = 16; + let ops_per_thread = 100; + let mut handles = vec![]; + + // All threads operate on strings with same prefix + for thread_id in 0..num_threads { + let tree = Arc::clone(&tree); + let handle = thread::spawn(move || { + let tenant = format!("tenant{}", thread_id); + for i in 0..ops_per_thread { + let text = format!("shared_prefix_{}", i); + tree.insert(&text, &tenant); + + // Immediately try to match + let (matched, _) = tree.prefix_match(&text); + assert!( + matched.starts_with("shared_prefix_"), + "Match should start with shared_prefix_" + ); + } + }); + handles.push(handle); + } + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + // Verify tree is still consistent + let sizes = tree.get_used_size_per_tenant(); + assert!(!sizes.is_empty(), "Tree should have entries"); + } + + #[test] + fn test_rapid_insert_remove_cycles() { + let tree = Arc::new(Tree::new()); + let num_cycles = 50; + + for cycle in 0..num_cycles { + let tenant = format!("tenant{}", cycle % 5); + + // Insert several entries + for i in 0..10 { + let text = format!("cycle{}entry{}", cycle, i); + tree.insert(&text, &tenant); + } + + // Remove the tenant + tree.remove_tenant(&tenant); + + // Verify tenant is gone + let sizes = tree.get_used_size_per_tenant(); + assert!( + !sizes.contains_key(&tenant), + "Tenant {} should be removed after cycle {}", + tenant, + cycle + ); + } + } + + // ==================== ASCII/UTF-8 Consistency Tests ==================== + + #[test] + fn test_ascii_utf8_consistency() { + let tree = Tree::new(); + 
+ // Insert ASCII + tree.insert("hello", "tenant1"); + + // Insert UTF-8 with same logical prefix (none) + tree.insert("你好", "tenant2"); + + // Insert mixed + tree.insert("hello你好", "tenant3"); + + // All should be retrievable + assert_eq!(tree.prefix_match("hello").0, "hello"); + assert_eq!(tree.prefix_match("你好").0, "你好"); + assert_eq!(tree.prefix_match("hello你好").0, "hello你好"); + + // Counts should be correct + let result = tree.prefix_match_with_counts("hello"); + assert_eq!(result.matched_char_count, 5); + assert_eq!(result.input_char_count, 5); + + let result = tree.prefix_match_with_counts("你好"); + assert_eq!(result.matched_char_count, 2); + assert_eq!(result.input_char_count, 2); + + let result = tree.prefix_match_with_counts("hello你好"); + assert_eq!(result.matched_char_count, 7); + assert_eq!(result.input_char_count, 7); + } + + #[test] + fn test_emoji_handling() { + let tree = Tree::new(); + + // Emoji are multi-byte UTF-8 + tree.insert("hello 👋", "tenant1"); + tree.insert("hello 👋🌍", "tenant2"); + + let (matched, tenant) = tree.prefix_match("hello 👋"); + assert_eq!(matched, "hello 👋"); + assert_eq!(tenant, "tenant1"); + + let (matched, tenant) = tree.prefix_match("hello 👋🌍"); + assert_eq!(matched, "hello 👋🌍"); + assert_eq!(tenant, "tenant2"); + + // Verify char count (not byte count) + let result = tree.prefix_match_with_counts("hello 👋"); + assert_eq!(result.matched_char_count, 7); + assert_eq!(result.input_char_count, 7); // h-e-l-l-o-space-emoji + } + + // ==================== Eviction Edge Cases ==================== + + #[test] + fn test_eviction_empty_tree() { + let tree = Tree::new(); + + // Should not panic on empty tree + tree.evict_tenant_by_size(100); + + let sizes = tree.get_used_size_per_tenant(); + assert!(sizes.is_empty()); + } + + #[test] + fn test_eviction_zero_max_size() { + let tree = Tree::new(); + + tree.insert("hello", "tenant1"); + tree.insert("world", "tenant1"); + + // Evict with max_size = 0 should remove everything + 
tree.evict_tenant_by_size(0); + + let sizes = tree.get_used_size_per_tenant(); + assert!( + sizes.is_empty() || sizes.values().all(|&v| v == 0), + "All tenants should be evicted or have zero size" + ); + } + + #[test] + fn test_eviction_single_tenant_all_entries() { + let tree = Tree::new(); + + // Insert many entries for single tenant + for i in 0..100 { + let text = format!("entry{:03}", i); + tree.insert(&text, "tenant1"); + } + + let initial_size = *tree.get_used_size_per_tenant().get("tenant1").unwrap(); + assert!(initial_size > 50, "Should have significant size"); + + // Evict to small size + tree.evict_tenant_by_size(50); + + let final_size = *tree.get_used_size_per_tenant().get("tenant1").unwrap_or(&0); + assert!( + final_size <= 50, + "Size {} should be <= 50 after eviction", + final_size + ); + } + + // ==================== Last Tenant Cache Tests ==================== + + #[test] + fn test_last_tenant_cache_update() { + let tree = Tree::new(); + + // Insert for tenant1 + tree.insert("hello", "tenant1"); + + // First match should return tenant1 + let (_, tenant) = tree.prefix_match("hello"); + assert_eq!(tenant, "tenant1"); + + // Insert for tenant2 on same path + tree.insert("hello", "tenant2"); + + // Match again - should still work (cache or iteration) + let (matched, _) = tree.prefix_match("hello"); + assert_eq!(matched, "hello"); + } + + #[test] + fn test_stale_cache_after_tenant_removal() { + let tree = Tree::new(); + + tree.insert("hello", "tenant1"); + tree.insert("hello", "tenant2"); + + // Access to populate cache + let _ = tree.prefix_match("hello"); + + // Remove tenant1 + tree.remove_tenant("tenant1"); + + // Should still work with tenant2 + let (matched, tenant) = tree.prefix_match("hello"); + assert_eq!(matched, "hello"); + assert_eq!(tenant, "tenant2"); + } + + // ==================== Consistency Verification Tests ==================== + + #[test] + fn test_char_count_consistency_after_operations() { + let tree = Tree::new(); + + // Helper 
to verify consistency + let verify_consistency = |tree: &Tree| { + let maintained = get_maintained_counts(tree); + let computed = tree.get_used_size_per_tenant(); + assert_eq!( + maintained, computed, + "Maintained counts should match computed counts" + ); + }; + + // Insert phase + for i in 0..50 { + tree.insert(&format!("prefix{}", i), "tenant1"); + tree.insert(&format!("other{}", i), "tenant2"); + } + verify_consistency(&tree); + + // Overlapping inserts + for i in 0..25 { + tree.insert(&format!("prefix{}", i), "tenant2"); + } + verify_consistency(&tree); + + // Eviction + tree.evict_tenant_by_size(100); + verify_consistency(&tree); + + // Tenant removal + tree.remove_tenant("tenant1"); + verify_consistency(&tree); + } + + #[test] + fn test_tree_structure_integrity_after_stress() { + let tree = Arc::new(Tree::new()); + let num_threads = 8; + let mut handles = vec![]; + + for thread_id in 0..num_threads { + let tree = Arc::clone(&tree); + let handle = thread::spawn(move || { + let mut rng = rand::rng(); + let tenant = format!("tenant{}", thread_id); + + for _ in 0..200 { + let op: u8 = rng.random_range(0..10); + let key = format!("key{}", rng.random_range(0..50)); + + match op { + 0..=6 => { + // Insert (70%) + tree.insert(&key, &tenant); + } + 7..=8 => { + // Match (20%) + let _ = tree.prefix_match(&key); + } + _ => { + // Match with counts (10%) + let _ = tree.prefix_match_with_counts(&key); + } + } + } + }); + handles.push(handle); + } + + for handle in handles { + handle.join().expect("Thread panicked during stress test"); + } + + // Verify tree is still functional + let sizes = tree.get_used_size_per_tenant(); + for (tenant, size) in sizes.iter() { + assert!(*size > 0, "Tenant {} should have positive size", tenant); + } + + // Verify char count consistency + let maintained = get_maintained_counts(&tree); + let computed = tree.get_used_size_per_tenant(); + assert_eq!( + maintained, computed, + "Counts should be consistent after stress test" + ); + } + + // 
==================== Boundary Condition Tests ==================== + + #[test] + fn test_very_long_strings() { + let tree = Tree::new(); + + // Create a very long string (10KB) + let long_string: String = (0..10000) + .map(|i| ((i % 26) as u8 + b'a') as char) + .collect(); + + tree.insert(&long_string, "tenant1"); + + let (matched, tenant) = tree.prefix_match(&long_string); + assert_eq!(matched.len(), long_string.len()); + assert_eq!(tenant, "tenant1"); + + // Partial match of long string + let partial = &long_string[..5000]; + let (matched, _) = tree.prefix_match(partial); + assert_eq!(matched, partial); + } + + #[test] + fn test_many_tenants_same_path() { + let tree = Tree::new(); + + // 100 tenants all insert same string + for i in 0..100 { + tree.insert("shared_path", &format!("tenant{}", i)); + } + + // Match should return one of them + let (matched, _) = tree.prefix_match("shared_path"); + assert_eq!(matched, "shared_path"); + + // Verify all tenants are tracked + let sizes = tree.get_used_size_per_tenant(); + assert_eq!(sizes.len(), 100, "Should have 100 tenants"); + } + + #[test] + fn test_special_characters() { + let tree = Tree::new(); + + // Various special characters + let test_cases = vec![ + ("hello\nworld", "tenant1"), // newline + ("hello\tworld", "tenant2"), // tab + ("hello\0world", "tenant3"), // null byte + ("hello\u{A0}world", "tenant4"), // non-breaking space + ("path/to/file", "tenant5"), // slashes + ("query?param=value", "tenant6"), // URL-like + ]; + + for (text, tenant) in &test_cases { + tree.insert(text, tenant); + } + + for (text, tenant) in &test_cases { + let (matched, matched_tenant) = tree.prefix_match(text); + assert_eq!(matched, *text, "Failed for: {:?}", text); + assert_eq!(matched_tenant, *tenant); + } + } +} diff --git a/setup.py b/setup.py index 94c5b192e6..7b0664b9ff 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,22 @@ from setuptools import setup, find_packages +try: + from setuptools_rust import Binding, RustExtension 
+except ImportError: + Binding = None + RustExtension = None + package_data = {"lightllm": ["common/all_kernel_configs/*/*.json", "common/triton_utils/*/*/*/*/*.json"]} +rust_extensions = [] +if RustExtension is not None and Binding is not None: + rust_extensions = [ + RustExtension( + "lightllm.server.httpserver_for_pd_master.pd_selector._pd_tree_rust", + path="rust/pd_tree/Cargo.toml", + binding=Binding.PyO3, + debug=False, + ) + ] setup( name="lightllm", version="1.1.0", @@ -29,4 +45,6 @@ "orjson", ], package_data=package_data, + rust_extensions=rust_extensions, + zip_safe=False, )