Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ def _init_mem_manager(self):
def _check_mem_size(self):
self.max_total_token_num = self.mem_manager.size

from lightllm.server.visualserver.model_infer.mem_reserve import read_vit_reserved_mem_for_device
from lightllm.utils.dist_utils import get_current_device_id

device_id = get_current_device_id()
vit_reserved_bytes = read_vit_reserved_mem_for_device(self.args, device_id)
if vit_reserved_bytes > 0:
logger.info(
f"[mem] device {device_id}: co-located ViT worst-case reserved "
f"{vit_reserved_bytes / 1024 ** 3:.2f} GB; KV pool max_total_token_num="
f"{self.max_total_token_num}"
)

assert (
self.max_total_token_num > self.batch_max_tokens
), "max_total_token_num must be greater than batch_max_tokens"
Expand All @@ -208,11 +220,18 @@ def _check_mem_size(self):
# 特别大,可能能分配的 kv 容量有限,无法支持 max_seq_length 的推理。所以个人模式下
# 可以适当放宽这个限制,不做这个校验。
if self.args.performance_mode != "personal":
vit_hint = ""
if vit_reserved_bytes > 0:
vit_hint = (
f" A co-located ViT reserved {vit_reserved_bytes / 1024 ** 3:.2f} GB on this device; "
f"lower --visual_infer_batch_size / --max_image_pixels / --max_image_token_count, "
f"reduce --mem_fraction, or move the ViT to another GPU with --visual_gpu_ids."
)
assert self.max_seq_length <= self.max_total_token_num, (
f"max_total_token_num must be >= max_seq_length, "
f"got max_total_token_num={self.max_total_token_num}, "
f"max_seq_length={self.max_seq_length}. "
f"Try set --max_req_total_len a smaller value < {self.max_total_token_num}."
f"Try set --max_req_total_len a smaller value < {self.max_total_token_num}.{vit_hint}"
)

return
Expand Down Expand Up @@ -604,7 +623,6 @@ def _decode(

@final
def _context_forward(self, infer_state: InferStateInfo):

input_embs = self.pre_infer.context_forward(infer_state.input_ids, infer_state, self.pre_post_weight)
if self.args.enable_dp_prefill_balance:
assert not self.args.enable_prefill_cudagraph, "not support now"
Expand Down
3 changes: 2 additions & 1 deletion lightllm/models/qwen2_5_vl/qwen2_5_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from lightllm.server.visualserver import get_vit_attn_backend
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton
from lightllm.server.visualserver.model_infer.worst_case_reserve import QwenVLWorstCaseMixin


class Qwen2RMSNorm(nn.Module):
Expand Down Expand Up @@ -135,7 +136,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class Qwen2_5_VisionTransformerPretrainedModel(nn.Module):
class Qwen2_5_VisionTransformerPretrainedModel(QwenVLWorstCaseMixin, nn.Module):
def __init__(
self,
kvargs,
Expand Down
3 changes: 2 additions & 1 deletion lightllm/models/qwen2_vl/qwen2_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton
from lightllm.server.visualserver.model_infer.worst_case_reserve import QwenVLWorstCaseMixin

# adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
Expand Down Expand Up @@ -175,7 +176,7 @@ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin)
return hidden_states


class Qwen2VisionTransformerPretrainedModel(nn.Module):
class Qwen2VisionTransformerPretrainedModel(QwenVLWorstCaseMixin, nn.Module):
def __init__(
self,
kvargs,
Expand Down
3 changes: 2 additions & 1 deletion lightllm/models/qwen3_vl/qwen3_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention
from lightllm.utils.log_utils import init_logger
from lightllm.server.visualserver.model_infer.worst_case_reserve import QwenVLWorstCaseMixin

logger = init_logger(__name__)

Expand Down Expand Up @@ -116,7 +117,7 @@ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin)
return hidden_states


class Qwen3VisionTransformerPretrainedModel(nn.Module):
class Qwen3VisionTransformerPretrainedModel(QwenVLWorstCaseMixin, nn.Module):
def __init__(
self,
kvargs,
Expand Down
34 changes: 8 additions & 26 deletions lightllm/models/vit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
import torchvision.transforms as T
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from PIL import Image
from typing import List, Union, final
from typing import List, Union
from io import BytesIO
from rpyc.utils.classic import obtain
from lightllm.common.quantization import Quantcfg
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.server.visualserver.model_infer.worst_case_reserve import WorstCaseReserveMixin


logger = init_logger(__name__)


class VisionTransformer:

class VisionTransformer(WorstCaseReserveMixin):
# weight class
pre_and_post_weight_class = ViTPreAndPostLayerWeight
transformer_weight_class = ViTTransformerLayerWeight
Expand All @@ -53,31 +53,13 @@ def __init__(self, kvargs):
self._init_quant()
self._init_weights()
self._init_infer_layer()
self._check_max_len_infer()
return

@final
@torch.no_grad()
def _check_max_len_infer(self):
disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None
if disable_check_max_len_infer:
return

try:
dummy_images = torch.randn(
(self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type
).cuda()
all_img_embeds = self.forward(dummy_images)
del all_img_embeds
logger.info(f"vit check max_len {self.max_batch_size} infer ok")
except (RuntimeError, torch.OutOfMemoryError) as e:
logger.exception(str(e))
exception_str = (
"Vit check max len infer fail, you can try:" "1.Set the --visual_infer_batch_size to a smaller value."
)
logger.error(exception_str)
raise Exception(exception_str)
return
def build_worst_case_input(self, batch_size, max_image_pixels, max_image_token_count) -> dict:
    """Build the dummy worst-case input for this ViT.

    InternVL works on fixed-size tiles, so the worst case is simply
    ``batch_size * MAX_PATH_NUM`` tiles of shape ``(3, IMAGE_H, IMAGE_W)``.
    ``max_image_pixels`` and ``max_image_token_count`` are accepted for
    signature compatibility but are not read by this implementation.

    Returns:
        A dict with a single key ``"pixel_values"`` holding a random CUDA
        tensor of dtype ``self.data_type``.
    """
    tile_count = int(self.MAX_PATH_NUM) * int(batch_size)
    tile_shape = (tile_count, 3, self.IMAGE_H, self.IMAGE_W)
    return {"pixel_values": torch.randn(tile_shape, dtype=self.data_type, device="cuda")}

def _init_config(self):
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
Expand Down
8 changes: 8 additions & 0 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2"
)
parser.add_argument(
"--visual_reserved_mem_gb",
type=float,
default=None,
help="""Override the automatic ViT worst-case activation reservation. When set, each visual rank
reserves exactly this many GB of GPU memory (held, not freed) and skips the dummy-image probe.
Use as a backstop for models without an automatic worst-case builder, or to override a bad estimate.""",
)
parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT")
parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT")
parser.add_argument(
Expand Down
19 changes: 17 additions & 2 deletions lightllm/server/visualserver/model_infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from rpyc.utils.server import ThreadedServer
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name
from .model_rpc_client import VisualModelRpcClient
from .model_rpc import VisualModelRpcServer
from ..objs import rpyc_config


Expand All @@ -22,6 +20,7 @@ def _init_env(socket_path: str, success_event):
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_model_infer")

import lightllm.utils.rpyc_fix_utils as _
from .model_rpc import VisualModelRpcServer

t = ThreadedServer(VisualModelRpcServer(), socket_path=socket_path, protocol_config=rpyc_config)
success_event.set()
Expand All @@ -31,6 +30,7 @@ def _init_env(socket_path: str, success_event):

async def start_model_process():
import lightllm.utils.rpyc_fix_utils as _
from .model_rpc_client import VisualModelRpcClient

socket_path = _generate_unix_socket_path()
if os.path.exists(socket_path):
Expand Down Expand Up @@ -61,3 +61,18 @@ def _generate_unix_socket_path() -> str:
"""Generate a random Unix socket path"""
unique_id = uuid.uuid4().hex[:8]
return f"/tmp/lightllm_model_infer_{unique_id}.sock"


def __getattr__(name):
    """Lazily resolve the package's re-exported RPC classes (PEP 562 hook).

    Only ``VisualModelRpcClient`` and ``VisualModelRpcServer`` are served; any
    other name raises ``AttributeError``. The imports are deferred to call
    time so that importing this package does not eagerly import model_rpc,
    which would recreate the cycle
    qwen2_visual -> worst_case_reserve -> model_infer/__init__ -> model_rpc -> qwen2_visual.
    """
    # Guard clause first: unknown names never trigger an import.
    if name not in ("VisualModelRpcClient", "VisualModelRpcServer"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    if name == "VisualModelRpcClient":
        from .model_rpc_client import VisualModelRpcClient as exported
    else:
        from .model_rpc import VisualModelRpcServer as exported
    return exported
80 changes: 80 additions & 0 deletions lightllm/server/visualserver/model_infer/mem_reserve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import math
from typing import List, Tuple
import torch
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def get_vit_reserved_shm_name(device_id: int, global_rank: int) -> str:
    """Return the shared-memory name under which one visual rank publishes its reservation.

    The name is prefixed with the unique server name and encodes both the
    device id and the visual rank.
    """
    server_name = get_unique_server_name()
    return f"{server_name}_vit_reserved_mem_d{device_id}_r{global_rank}"


def publish_vit_reserved_mem(device_id: int, global_rank: int, reserved_bytes: int) -> None:
    """Publish a visual rank's held worst-case reservation for other processes to read.

    Args:
        device_id: GPU the visual rank runs on.
        global_rank: rank of the visual worker (part of the shared-memory name).
        reserved_bytes: size of the held reservation in bytes; coerced to int.
    """
    shm_name = get_vit_reserved_shm_name(device_id, global_rank)
    SharedInt(shm_name).set_value(int(reserved_bytes))


def read_vit_reserved_mem_for_device(args, device_id: int) -> int:
    """Router side: sum reservations of all visual ranks placed on `device_id`. Diagnostic only.

    Args:
        args: start arguments; ``disable_vision``, ``enable_multimodal`` and
            ``visual_gpu_ids`` are read via ``getattr`` with safe defaults.
        device_id: GPU whose co-located ViT reservations should be summed.

    Returns:
        Total published reservation in bytes. 0 when vision is disabled,
        multimodal is not enabled, or no visual rank is mapped to `device_id`.
        A rank whose value cannot be read contributes 0.
    """
    if getattr(args, "disable_vision", False) or not getattr(args, "enable_multimodal", False):
        return 0
    gpu_ids = getattr(args, "visual_gpu_ids", None) or []
    total = 0
    # assumes global_rank == index into visual_gpu_ids (matching how visual ranks call publish_vit_reserved_mem)
    for global_rank, dev in enumerate(gpu_ids):
        if dev != device_id:
            continue
        shm_name = get_vit_reserved_shm_name(dev, global_rank)
        try:
            total += int(SharedInt(shm_name).get_value())
        except (OSError, ValueError) as exc:
            # The value is diagnostic only, so a failed lookup must not abort the caller's
            # startup. NOTE(review): this can presumably happen when the visual rank has not
            # published yet during concurrent startup - confirm against SharedInt.
            logger.warning(
                f"[mem] could not read ViT reserved memory for device {dev} rank {global_rank} "
                f"({exc!r}); treating it as 0"
            )
    return total


def reserve_guard_tensor(device_id: int, reserved_gb: float) -> Tuple[torch.Tensor, int]:
    """Allocate and HOLD a guard tensor of `reserved_gb` GB so the allocator high-water mark persists.

    Args:
        device_id: CUDA device index the tensor is allocated on.
        reserved_gb: amount to reserve, in GB (1 GB = 1024 ** 3 bytes). Must be
            finite and non-negative; 0 yields an empty tensor.

    Returns:
        (tensor, nbytes). The caller MUST keep a reference to the tensor.

    Raises:
        ValueError: if `reserved_gb` is negative, NaN or infinite.
    """
    # Validate before touching the allocator so a bad CLI value produces a clear message
    # instead of an OverflowError from int() or an opaque allocation failure.
    if not math.isfinite(reserved_gb) or reserved_gb < 0:
        raise ValueError(f"reserved_gb must be a finite, non-negative number of GB, got {reserved_gb!r}")
    nbytes = int(reserved_gb * 1024 ** 3)
    guard = torch.empty(nbytes, dtype=torch.uint8, device=f"cuda:{device_id}")
    return guard, nbytes


def compute_qwen_worst_case_grid(
    batch_size: int,
    max_image_pixels: int,
    max_image_token_count: int,
    patch_size: int,
    temporal_patch_size: int,
    in_channels: int,
    spatial_merge_size: int,
) -> Tuple[Tuple[int, int], List[List[int]]]:
    """Shape-only computation of the Qwen-VL worst-case probe input.

    Each dummy image is limited by the tighter of two budgets: the per-image
    token cap (converted to patches via spatial_merge_size ** 2) and the pixel
    cap (converted via patch_size ** 2). The image is modelled as the smallest
    square grid, with sides a multiple of spatial_merge_size, whose patch
    count is at least that budget. Rounding the side UP keeps the probe an
    upper bound: flooring could undershoot (isqrt(32768) = 181 -> 180 x 180 =
    32400 patches, below a 32768-patch cap).

    Inputs are assumed valid (positive token cap, pixel cap of at least one
    merge tile); budgets smaller than that are clamped up to a single
    spatial_merge_size x spatial_merge_size tile.

    Returns:
        ``((total_patches, row_width), grid_thw)`` where ``pixel_values``
        would have shape ``(total_patches, row_width)`` and ``grid_thw`` holds
        one ``[t, h, w]`` triple per dummy image.
    """
    patch_area = patch_size * patch_size
    token_bound = max_image_token_count * (spatial_merge_size * spatial_merge_size)
    pixel_bound = max_image_pixels // patch_area
    patch_cap = max(1, min(token_bound, pixel_bound))

    # Ceiling of sqrt(patch_cap): edge * edge must never fall below the cap.
    edge = int(math.isqrt(patch_cap))
    if edge * edge < patch_cap:
        edge += 1

    # Snap upwards to the next multiple of the merge size.
    remainder = edge % spatial_merge_size
    if remainder:
        edge += spatial_merge_size - remainder
    # Floor at one merge unit.
    edge = max(edge, spatial_merge_size)

    pixel_values_shape = (
        edge * edge * batch_size,
        in_channels * temporal_patch_size * patch_size * patch_size,
    )
    grid_thw = [[1, edge, edge] for _ in range(batch_size)]
    return pixel_values_shape, grid_thw
41 changes: 39 additions & 2 deletions lightllm/server/visualserver/model_infer/model_rpc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import rpyc
import torch
import socket
Expand All @@ -11,7 +12,6 @@
from rpyc.utils.classic import obtain
from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer
from lightllm.models.llava.llava_visual import LlavaVisionModel
from lightllm.models.internvl.internvl_visual import InternVLVisionModel
from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel
from lightllm.models.gemma4.gemma4_visual import Gemma4VisionModel
from lightllm.models.vit.model import VisionTransformer
Expand All @@ -28,6 +28,8 @@
from lightllm.server.visualserver import set_vit_att_backend
from lightllm.server.embed_cache.afs_utils import SepEmbedHandler
from lightllm.utils.log_utils import init_logger
from lightllm.server.visualserver.model_infer.mem_reserve import publish_vit_reserved_mem, reserve_guard_tensor
from lightllm.server.visualserver.model_infer.worst_case_reserve import WorstCaseReserveMixin


logger = init_logger(__name__)
Expand Down Expand Up @@ -95,7 +97,6 @@ def exposed_init_model(self, kvargs):
self.model = LlavaVisionModel()
elif self.model_type == "internvl_chat":
self.model = VisionTransformer(kvargs)
# self.model = InternVLVisionModel()
elif self.model_type == "gemma3":
self.model = Gemma3VisionModel()
elif self.model_type == "gemma4":
Expand All @@ -114,6 +115,7 @@ def exposed_init_model(self, kvargs):

self.model.load_model(weight_dir)
self.model = self.model.cuda()
self._reserve_vit_worst_case_mem()
if not self.is_visual_only_mode:
self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
Expand Down Expand Up @@ -143,6 +145,41 @@ def exposed_init_model(self, kvargs):
set_random_seed(2147483647)
return

def _reserve_vit_worst_case_mem(self):
    """Reserve (or deliberately skip) ViT worst-case memory, then publish the held byte count.

    Exactly one of four paths is taken, in priority order:
      1. ``--visual_reserved_mem_gb`` is set: hold an explicit guard tensor of
         that size on ``self._mem_reserve_guard`` and skip the dummy probe.
      2. ``DISABLE_CHECK_MAX_LEN_INFER`` is set in the environment: skip the
         reservation and warn.
      3. The model is a ``WorstCaseReserveMixin``: delegate to its
         ``reserve_worst_case_activation`` and use the byte count it returns.
      4. Otherwise: warn that this model type has no worst-case reservation.

    The byte count (0 on the skip paths) is always published so the router can
    read how much this rank holds on its device.
    """
    start_args = get_env_start_args()
    rank = self.dp_rank_id * self.vit_tp + self.tp_rank_id
    manual_gb = getattr(start_args, "visual_reserved_mem_gb", None)
    held_bytes = 0

    if manual_gb is not None:
        # Manual override: keep the guard tensor referenced on self so it stays allocated.
        guard, held_bytes = reserve_guard_tensor(self.device_id, manual_gb)
        self._mem_reserve_guard = guard
    elif "DISABLE_CHECK_MAX_LEN_INFER" in os.environ:
        # Preserved escape hatch: the probe is disabled, so nothing is reserved.
        logger.warning(
            "DISABLE_CHECK_MAX_LEN_INFER is set: skipping ViT worst-case reservation. "
            "A co-located LLM may OOM at runtime. Unset it, or set --visual_reserved_mem_gb instead."
        )
    elif isinstance(self.model, WorstCaseReserveMixin):
        held_bytes = self.model.reserve_worst_case_activation(
            self.device_id,
            self.infer_max_batch_size,
            start_args.max_image_pixels,
            start_args.max_image_token_count,
        )
    else:
        logger.warning(
            f"co-location OOM risk: model_type={self.model_type} has no ViT worst-case reservation. "
            f"Set --visual_reserved_mem_gb to reserve headroom, or place the ViT on a separate GPU "
            f"with --visual_gpu_ids."
        )

    # Publish unconditionally: a 0 tells the router that nothing is held here.
    publish_vit_reserved_mem(self.device_id, rank, held_bytes)
    if held_bytes <= 0:
        return
    logger.info(
        f"ViT rank {rank} on device {self.device_id} reserved "
        f"{held_bytes / 1024 ** 3:.2f} GB worst-case activation memory."
    )

def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threading.Event]):
try:
images = obtain(images)
Expand Down
Loading
Loading