Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ dist
.vscode
tmp/
requirements-musa.txt
logs/
logs/
target/
11 changes: 10 additions & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,13 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \
fi

COPY . /lightllm
RUN pip install -e /lightllm --no-cache-dir

ARG ENABLE_RUST=1
ENV PATH="/root/.cargo/bin:${PATH}"
RUN if [ "${ENABLE_RUST}" = "1" ]; then \
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y && \
pip install --no-cache-dir --upgrade pip setuptools-rust wheel; \
fi

RUN pip install -v -e /lightllm --no-cache-dir --no-build-isolation

2 changes: 1 addition & 1 deletion lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
"--select_p_d_node_strategy",
type=str,
default="round_robin",
choices=["random", "round_robin", "adaptive_load"],
choices=["random", "round_robin", "adaptive_load", "cache_aware"],
help="pd master use this strategy to select p d node, can be round_robin, random or adaptive_load",
)
parser.add_argument(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .pd_selector import PDSelector, RandomSelector, RoundRobinSelector, AdaptiveLoadSelector
from .pd_selector import (
PDSelector,
RandomSelector,
RoundRobinSelector,
AdaptiveLoadSelector,
LoadBalancedCacheAwareSelector,
)


def create_selector(selector_type: str, pd_manager) -> PDSelector:
Expand All @@ -8,5 +14,7 @@ def create_selector(selector_type: str, pd_manager) -> PDSelector:
return RoundRobinSelector(pd_manager)
elif selector_type == "adaptive_load":
return AdaptiveLoadSelector(pd_manager)
elif selector_type == "cache_aware":
return LoadBalancedCacheAwareSelector(pd_manager)
else:
raise ValueError(f"Invalid selector type: {selector_type}")
139 changes: 139 additions & 0 deletions lightllm/server/httpserver_for_pd_master/pd_selector/cache_aware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

import threading
from dataclasses import dataclass
from typing import List, Optional, runtime_checkable

from lightllm.server.pd_io_struct import PD_Client_Obj
from lightllm.utils.log_utils import init_logger

from .tree import Tree


logger = init_logger(__name__)


@dataclass(slots=True)
class CacheAwareConfig:
cache_threshold: float = 0.5
balance_rel_threshold: float = 1.2
eviction_interval_secs: int = 30
max_tree_size: int = 1000000


class CacheAwarePolicy:
    """Pick a worker for a request using a prefix tree plus a load-balance check.

    Every routed request text is recorded in ``self.tree`` under the chosen
    worker's URL (the tree calls this key a "tenant"). Later requests are sent
    to the tenant returned by the tree's prefix match when the match rate is
    high enough, unless the workers' loads are currently imbalanced, in which
    case the least-loaded worker is chosen instead.

    When ``config.eviction_interval_secs`` is positive, a daemon thread
    periodically asks the tree to evict entries down to ``config.max_tree_size``.
    """

    def __init__(self, config: Optional[CacheAwareConfig] = None) -> None:
        """Create the tree and, if enabled, start the periodic eviction thread.

        Args:
            config: Policy settings; defaults to ``CacheAwareConfig()``.
        """
        self.config = config or CacheAwareConfig()
        self.tree: Tree = Tree()
        self._stop_eviction = threading.Event()
        self._eviction_thread: Optional[threading.Thread] = None
        if self.config.eviction_interval_secs > 0:
            self._eviction_thread = threading.Thread(
                target=self._run_eviction_loop, name="cache-aware-eviction", daemon=True
            )
            self._eviction_thread.start()

    def _run_eviction_loop(self) -> None:
        """Run eviction every ``eviction_interval_secs`` until ``close()`` is called."""
        # Event.wait() returns True once the stop event is set, which ends the loop.
        while not self._stop_eviction.wait(self.config.eviction_interval_secs):
            logger.info("Running cache eviction...")
            try:
                self.evict_cache(self.config.max_tree_size)
                logger.info(f"Cache eviction completed.: {self.tree.get_used_size_per_tenant()}")
            except Exception:
                # One failed pass must not kill the thread, otherwise the tree
                # would never be trimmed again for the life of the process.
                logger.exception("Cache eviction failed; retrying on the next interval")

    def close(self) -> None:
        """Signal the eviction thread to stop and wait briefly for it to exit.

        Safe to call more than once, and on an instance whose ``__init__`` did
        not finish (``__del__`` calls this unconditionally).
        """
        stop_event = getattr(self, "_stop_eviction", None)
        if stop_event is not None:
            stop_event.set()

        thread = getattr(self, "_eviction_thread", None)
        if thread is None or not thread.is_alive():
            return
        # If the last reference is dropped on the eviction thread itself,
        # __del__ runs there; joining the current thread raises RuntimeError.
        if thread is threading.current_thread():
            return
        thread.join(timeout=1.0)

    def init_workers(self, workers: List[PD_Client_Obj]) -> None:
        """Register every worker in ``workers`` with the tree under an empty text."""
        for worker in workers:
            self.tree.insert("", worker.url())

    def add_worker(self, worker: PD_Client_Obj) -> None:
        """Register a single worker with the tree under an empty text."""
        self.tree.insert("", worker.url())

    def remove_worker(self, worker: PD_Client_Obj) -> None:
        """Remove the tenant keyed by ``worker.url()`` from the tree."""
        self.tree.remove_tenant(worker.url())

    def remove_worker_by_url(self, url: str) -> None:
        """Remove the tenant keyed by ``url`` from the tree."""
        self.tree.remove_tenant(url)

    def evict_cache(self, max_size: int) -> None:
        """Ask the tree to evict per-tenant entries down to ``max_size``."""
        self.tree.evict_tenant_by_size(max_size)

    def _select_worker_min_load(
        self,
        workers: List[PD_Client_Obj],
        request_text: Optional[str],
    ) -> Optional[PD_Client_Obj]:
        """Return the worker with the smallest ``load()`` and record the request.

        Args:
            workers: Non-empty list of candidate workers.
            request_text: Text to record under the chosen worker's URL; nothing
                is recorded when it is ``None``.
        """
        min_load_worker = min(workers, key=lambda worker: worker.load())

        if request_text is not None:
            self.tree.insert(request_text, min_load_worker.url())

        return min_load_worker

    def select_worker(
        self, workers: List[PD_Client_Obj], request_text: Optional[str] = None
    ) -> Optional[PD_Client_Obj]:
        """Choose a worker for ``request_text``.

        Selection order:
          1. No workers: return ``None``.
          2. Loads imbalanced (``max_load > min_load * balance_rel_threshold``):
             return the least-loaded worker.
          3. Otherwise prefix-match the text in the tree. If the match rate
             exceeds ``cache_threshold`` and the matched tenant is one of
             ``workers``, return that worker; a matched tenant that is not in
             ``workers`` is removed from the tree.
          4. Fall back to the least-loaded worker.

        In every non-``None`` outcome the request text is inserted into the
        tree under the chosen worker's URL (skipped on the min-load paths when
        ``request_text`` is ``None``).

        Args:
            workers: Candidate workers.
            request_text: Request text used for prefix matching; ``None`` is
                matched as the empty string.

        Returns:
            The selected worker, or ``None`` when ``workers`` is empty.
        """
        if not workers:
            return None

        # workers is non-empty here, so min()/max() need no fallback value.
        loads = [worker.load() for worker in workers]
        min_load = min(loads)
        max_load = max(loads)

        is_imbalanced = max_load > (min_load * self.config.balance_rel_threshold)

        logger.info(
            f"CacheAwarePolicy: min_load={min_load:.4f}, max_load={max_load:.4f}, "
            f"balance_rel_threshold={self.config.balance_rel_threshold:.4f}, "
            f"is_imbalanced={is_imbalanced}"
        )

        if is_imbalanced:
            return self._select_worker_min_load(
                workers=workers,
                request_text=request_text,
            )

        text = request_text or ""

        result = self.tree.prefix_match_with_counts(text)
        # Guard the division: an empty input counts as a zero match rate.
        match_rate = 0.0 if result.input_char_count == 0 else result.matched_char_count / result.input_char_count

        logger.info(
            f"CacheAwarePolicy: matched_char_count={result.matched_char_count}, "
            f"input_char_count={result.input_char_count}, match_rate={match_rate:.4f}, "
            f"cache_threshold={self.config.cache_threshold:.4f}"
        )

        selected_worker: Optional[PD_Client_Obj] = None
        if match_rate > self.config.cache_threshold:
            for worker in workers:
                if worker.url() == result.tenant:
                    selected_worker = worker
                    break

            if selected_worker is None:
                # If the matched tenant is not in the current workers, we can evict it from the tree
                logger.info(f"Evicting tenant: {result.tenant}")
                self.tree.remove_tenant(result.tenant)

        logger.info(
            f"CacheAwarePolicy: selected_worker={selected_worker.url() if selected_worker else None}, "
            f"match_rate={match_rate:.4f}, cache_threshold={self.config.cache_threshold:.4f}"
        )

        if selected_worker is not None:
            self.tree.insert(text, selected_worker.url())
            return selected_worker

        return self._select_worker_min_load(
            workers=workers,
            request_text=request_text,
        )

    def __del__(self) -> None:
        """Best-effort shutdown of the eviction thread on garbage collection."""
        self.close()
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
from lightllm.server.pd_io_struct import PD_Client_Obj
from lightllm.server.core.objs import SamplingParams
from lightllm.server.multimodal_params import MultimodalParams
from lightllm.utils.log_utils import init_logger

from .cache_aware import CacheAwarePolicy, CacheAwareConfig


logger = init_logger(__name__)


class PDSelector:
Expand Down Expand Up @@ -64,4 +70,43 @@ def select_p_d_node(
return p_node, d_node

def _importance_sampling(self, nodes: List[PD_Client_Obj]):
    """Return one node from ``nodes``, drawn at random with load-based weights.

    Each node's weight is ``1.0 - run_status.total_token_usage_rate``, floored
    at 0.02 so that a fully used node can still be drawn.

    Returns:
        A single node, not a list: ``random.choices`` returns a one-element
        list for the default ``k=1``, so its first element is unwrapped here.
        Callers use the result directly as a node.
    """
    weights = [max(1.0 - e.run_status.total_token_usage_rate, 0.02) for e in nodes]
    return random.choices(nodes, weights=weights)[0]


class LoadBalancedCacheAwareSelector(AdaptiveLoadSelector):
    """Select the prefill node with ``CacheAwarePolicy`` and the decode node by weighted sampling.

    refer to:
    https://github.com/sgl-project/sglang/blob/main/sgl-model-gateway/src/policies/cache_aware.rs
    """

    def __init__(self, pd_manager):
        """Create the selector with a default-configured cache-aware policy."""
        super().__init__(pd_manager)
        self.policy = CacheAwarePolicy(CacheAwareConfig())
        # Snapshot of the prefill workers currently registered with the policy.
        self.tree_workers = []

    def update_nodes(self, prefill_nodes, decode_nodes):
        """Refresh the node lists and keep the policy's tree in step with them.

        Prefill workers that appeared since the last call are registered with
        the policy; workers that disappeared have their tenant removed.
        """
        super().update_nodes(prefill_nodes, decode_nodes)

        # Diff by URL: it is the key the policy's tree uses, and it avoids
        # requiring the worker objects themselves to be hashable.
        current_by_url = {worker.url(): worker for worker in self.prefill_nodes}
        known_urls = {worker.url() for worker in self.tree_workers}

        for url, worker in current_by_url.items():
            if url not in known_urls:
                self.policy.add_worker(worker)

        for url in known_urls - current_by_url.keys():
            self.policy.remove_worker_by_url(url)

        # Keep a copy rather than an alias, so in-place changes to the node
        # list cannot hide additions or removals from the next diff.
        self.tree_workers = list(self.prefill_nodes)

    def select_p_d_node(
        self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
    ) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
        """Return the ``(prefill, decode)`` node pair for ``prompt``.

        Args:
            prompt: Request text; it must be a ``str`` because the policy
                matches on characters.
            sampling_params: Not used by this selector.
            multimodal_params: Not used by this selector.

        Returns:
            ``(p_node, d_node)``. ``p_node`` is ``None`` when the policy has no
            prefill worker to offer.
        """
        assert isinstance(prompt, str), "prompt must be a string for cache-aware selection"
        p_node = self.policy.select_worker(self.prefill_nodes, request_text=prompt)
        d_node = self._importance_sampling(self.decode_nodes)

        # select_worker returns None for an empty worker list; only account
        # for the added load when a prefill node was actually chosen.
        if p_node is not None:
            p_node.update_load(len(prompt))

        logger.info(
            f"LoadBalancedCacheAwareSelector: selected p_node={p_node.url() if p_node else None}, "
            f"d_node={d_node.url() if d_node else None}"
        )

        return p_node, d_node
Loading
Loading