From a3c503510198a7e001525fa8a7038a94834c15a1 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 22 Jun 2026 17:46:50 +0800 Subject: [PATCH] feat(visual): reserve ViT worst-case activation memory Reserve peak ViT activation memory during visual worker startup so the co-located LLM router sizes its KV pool after the visual tower has already reached its worst-case allocator high-water mark. Add a manual --visual_reserved_mem_gb override for unsupported visual models and include the reserved amount in max-length diagnostics. --- lightllm/common/basemodel/basemodel.py | 22 ++++- lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 3 +- lightllm/models/qwen2_vl/qwen2_visual.py | 3 +- lightllm/models/qwen3_vl/qwen3_visual.py | 3 +- lightllm/models/vit/model.py | 34 ++------ lightllm/server/api_cli.py | 8 ++ .../visualserver/model_infer/__init__.py | 19 ++++- .../visualserver/model_infer/mem_reserve.py | 80 +++++++++++++++++++ .../visualserver/model_infer/model_rpc.py | 41 +++++++++- .../model_infer/worst_case_reserve.py | 70 ++++++++++++++++ 10 files changed, 248 insertions(+), 35 deletions(-) create mode 100644 lightllm/server/visualserver/model_infer/mem_reserve.py create mode 100644 lightllm/server/visualserver/model_infer/worst_case_reserve.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index e83de684a7..7212b62dc9 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -199,6 +199,18 @@ def _init_mem_manager(self): def _check_mem_size(self): self.max_total_token_num = self.mem_manager.size + from lightllm.server.visualserver.model_infer.mem_reserve import read_vit_reserved_mem_for_device + from lightllm.utils.dist_utils import get_current_device_id + + device_id = get_current_device_id() + vit_reserved_bytes = read_vit_reserved_mem_for_device(self.args, device_id) + if vit_reserved_bytes > 0: + logger.info( + f"[mem] device {device_id}: co-located ViT worst-case reserved " + f"{vit_reserved_bytes / 1024 ** 3:.2f} GB; KV pool max_total_token_num=" + f"{self.max_total_token_num}" + ) + assert ( self.max_total_token_num > self.batch_max_tokens ), "max_total_token_num must be greater than batch_max_tokens" @@ -208,11 +220,18 @@ def _check_mem_size(self): # 特别大,可能能分配的 kv 容量有限,无法支持 max_seq_length 的推理。所以个人模式下 # 可以适当放宽这个限制,不做这个校验。 if self.args.performance_mode != "personal": + vit_hint = "" + if vit_reserved_bytes > 0: + vit_hint = ( + f" A co-located ViT reserved {vit_reserved_bytes / 1024 ** 3:.2f} GB on this device; " + f"lower --visual_infer_batch_size / --max_image_pixels / --max_image_token_count, " + f"reduce --mem_fraction, or move the ViT to another GPU with --visual_gpu_ids." + ) assert self.max_seq_length <= self.max_total_token_num, ( f"max_total_token_num must be >= max_seq_length, " f"got max_total_token_num={self.max_total_token_num}, " f"max_seq_length={self.max_seq_length}. " - f"Try set --max_req_total_len a smaller value < {self.max_total_token_num}." + f"Try set --max_req_total_len a smaller value < {self.max_total_token_num}.{vit_hint}" ) return @@ -604,7 +623,6 @@ def _decode( @final def _context_forward(self, infer_state: InferStateInfo): - input_embs = self.pre_infer.context_forward(infer_state.input_ids, infer_state, self.pre_post_weight) if self.args.enable_dp_prefill_balance: assert not self.args.enable_prefill_cudagraph, "not support now" diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 1b3a5f0db7..851536e048 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -16,6 +16,7 @@ from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton +from lightllm.server.visualserver.model_infer.worst_case_reserve import QwenVLWorstCaseMixin class Qwen2RMSNorm(nn.Module): @@ -135,7 +136,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class Qwen2_5_VisionTransformerPretrainedModel(nn.Module): +class Qwen2_5_VisionTransformerPretrainedModel(QwenVLWorstCaseMixin, nn.Module): def __init__( self, kvargs, diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index e02c3d9aa3..1f9922bc0e 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -36,6 +36,7 @@ from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton +from lightllm.server.visualserver.model_infer.worst_case_reserve import QwenVLWorstCaseMixin # adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -175,7 +176,7 @@ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) return hidden_states -class Qwen2VisionTransformerPretrainedModel(nn.Module): +class Qwen2VisionTransformerPretrainedModel(QwenVLWorstCaseMixin, nn.Module): def __init__( self, kvargs, diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index bab0800f26..df86b4324d 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -30,6 +30,7 @@ from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention from lightllm.utils.log_utils import init_logger +from lightllm.server.visualserver.model_infer.worst_case_reserve import QwenVLWorstCaseMixin logger = init_logger(__name__) @@ -116,7 +117,7 @@ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) return hidden_states -class Qwen3VisionTransformerPretrainedModel(nn.Module): +class Qwen3VisionTransformerPretrainedModel(QwenVLWorstCaseMixin, nn.Module): def __init__( self, kvargs, diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 0befb50166..1aa61f83a5 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -14,19 +14,19 @@ import torchvision.transforms as T from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from PIL import Image -from typing import List, Union, final +from typing import List, Union from io import BytesIO from rpyc.utils.classic import obtain from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_dp_world_size from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.server.visualserver.model_infer.worst_case_reserve import WorstCaseReserveMixin logger = init_logger(__name__) -class VisionTransformer: - +class VisionTransformer(WorstCaseReserveMixin): # weight class pre_and_post_weight_class = ViTPreAndPostLayerWeight transformer_weight_class = ViTTransformerLayerWeight @@ -53,31 +53,13 @@ def __init__(self, kvargs): self._init_quant() self._init_weights() self._init_infer_layer() - self._check_max_len_infer() return - @final - @torch.no_grad() - def _check_max_len_infer(self): - disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None - if disable_check_max_len_infer: - return - - try: - dummy_images = torch.randn( - (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type - ).cuda() - all_img_embeds = self.forward(dummy_images) - del all_img_embeds - logger.info(f"vit check max_len {self.max_batch_size} infer ok") - except (RuntimeError, torch.OutOfMemoryError) as e: - logger.exception(str(e)) - exception_str = ( - "Vit check max len infer fail, you can try:" "1.Set the --visual_infer_batch_size to a smaller value." - ) - logger.error(exception_str) - raise Exception(exception_str) - return + def build_worst_case_input(self, batch_size, max_image_pixels, max_image_token_count) -> dict: + # InternVL uses fixed-size tiles: worst case is batch_size * MAX_PATH_NUM tiles of (3, IMAGE_H, IMAGE_W). + num_tiles = int(self.MAX_PATH_NUM) * int(batch_size) + dummy_images = torch.randn((num_tiles, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type, device="cuda") + return {"pixel_values": dummy_images} def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 04e0187452..b68b225c71 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -495,6 +495,14 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2" ) + parser.add_argument( + "--visual_reserved_mem_gb", + type=float, + default=None, + help="""Override the automatic ViT worst-case activation reservation. When set, each visual rank + reserves exactly this many GB of GPU memory (held, not freed) and skips the dummy-image probe. + Use as a backstop for models without an automatic worst-case builder, or to override a bad estimate.""", + ) parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT") parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT") parser.add_argument( diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index 3e74793634..2317902ce0 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -11,8 +11,6 @@ from rpyc.utils.server import ThreadedServer from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name -from .model_rpc_client import VisualModelRpcClient -from .model_rpc import VisualModelRpcServer from ..objs import rpyc_config @@ -22,6 +20,7 @@ def _init_env(socket_path: str, success_event): setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_model_infer") import lightllm.utils.rpyc_fix_utils as _ + from .model_rpc import VisualModelRpcServer t = ThreadedServer(VisualModelRpcServer(), socket_path=socket_path, protocol_config=rpyc_config) success_event.set() @@ -31,6 +30,7 @@ def _init_env(socket_path: str, success_event): async def start_model_process(): import lightllm.utils.rpyc_fix_utils as _ + from .model_rpc_client import VisualModelRpcClient socket_path = _generate_unix_socket_path() if os.path.exists(socket_path): @@ -61,3 +61,18 @@ def _generate_unix_socket_path() -> str: """Generate a random Unix socket path""" unique_id = uuid.uuid4().hex[:8] return f"/tmp/lightllm_model_infer_{unique_id}.sock" + + +def __getattr__(name): + # Lazy re-export to preserve the package's public API without re-introducing the import cycle + # (model modules import this package's worst-case helpers; eagerly importing model_rpc here would + # form qwen2_visual -> worst_case_reserve -> model_infer/__init__ -> model_rpc -> qwen2_visual). + if name == "VisualModelRpcClient": + from .model_rpc_client import VisualModelRpcClient + + return VisualModelRpcClient + if name == "VisualModelRpcServer": + from .model_rpc import VisualModelRpcServer + + return VisualModelRpcServer + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/lightllm/server/visualserver/model_infer/mem_reserve.py b/lightllm/server/visualserver/model_infer/mem_reserve.py new file mode 100644 index 0000000000..e998fc2d81 --- /dev/null +++ b/lightllm/server/visualserver/model_infer/mem_reserve.py @@ -0,0 +1,80 @@ +import math +from typing import List, Tuple +import torch +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def get_vit_reserved_shm_name(device_id: int, global_rank: int) -> str: + return f"{get_unique_server_name()}_vit_reserved_mem_d{device_id}_r{global_rank}" + + +def publish_vit_reserved_mem(device_id: int, global_rank: int, reserved_bytes: int) -> None: + """Visual rank writes its held worst-case reservation (bytes) for cross-process discovery.""" + shm = SharedInt(get_vit_reserved_shm_name(device_id, global_rank)) + shm.set_value(int(reserved_bytes)) + + +def read_vit_reserved_mem_for_device(args, device_id: int) -> int: + """Router side: sum reservations of all visual ranks placed on `device_id`. Diagnostic only.""" + if getattr(args, "disable_vision", False) or not getattr(args, "enable_multimodal", False): + return 0 + gpu_ids = getattr(args, "visual_gpu_ids", None) or [] + total = 0 + # assumes global_rank == index into visual_gpu_ids (matching how visual ranks call publish_vit_reserved_mem) + for global_rank, dev in enumerate(gpu_ids): + if dev == device_id: + total += int(SharedInt(get_vit_reserved_shm_name(dev, global_rank)).get_value()) + return total + + +def reserve_guard_tensor(device_id: int, reserved_gb: float) -> Tuple[torch.Tensor, int]: + """Allocate and HOLD a guard tensor of `reserved_gb` GB so the allocator high-water mark persists. + Returns (tensor, nbytes). The caller MUST keep a reference to the tensor.""" + nbytes = int(reserved_gb * 1024 ** 3) + guard = torch.empty(nbytes, dtype=torch.uint8, device=f"cuda:{device_id}") + return guard, nbytes + + +def compute_qwen_worst_case_grid( + batch_size: int, + max_image_pixels: int, + max_image_token_count: int, + patch_size: int, + temporal_patch_size: int, + in_channels: int, + spatial_merge_size: int, +) -> Tuple[Tuple[int, int], List[List[int]]]: + """Pure shape math for the Qwen-VL worst case. + + Returns ((total_patches, row_width), grid_thw) where pixel_values has shape + (total_patches, row_width) and grid_thw is one [t, h, w] triple per dummy image. + Bounds each image by BOTH the per-image token cap and pixel cap (whichever is tighter), + using the smallest square grid (sides multiples of spatial_merge_size) whose patch count + is >= that cap. The side is rounded UP so the probe is an upper bound on the largest valid + request and never under-reserves (a square floor could undershoot, e.g. isqrt(32768)=181 + -> 180x180 = 32400 patches < the 32768-patch cap). + + Assumes valid inputs (max_image_token_count > 0 and max_image_pixels >= (patch_size * + spatial_merge_size)**2); smaller budgets are clamped up to a single spatial_merge_size tile. + """ + spatial_merge_unit = spatial_merge_size * spatial_merge_size + patches_by_tokens = max_image_token_count * spatial_merge_unit + patches_by_pixels = max_image_pixels // (patch_size * patch_size) + max_patches = max(1, min(patches_by_tokens, patches_by_pixels)) + + side = int(math.isqrt(max_patches)) + if side * side < max_patches: + side += 1 # ceil(sqrt) so side*side >= max_patches (never undershoot) + if side % spatial_merge_size: + side += spatial_merge_size - (side % spatial_merge_size) # round up to a merge-unit multiple + side = max(side, spatial_merge_size) # never smaller than one merge unit + + grid_h = grid_w = side + row_width = in_channels * temporal_patch_size * patch_size * patch_size + total_patches = grid_h * grid_w * batch_size + grid_thw = [[1, grid_h, grid_w] for _ in range(batch_size)] + return (total_patches, row_width), grid_thw diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 50bc12fd23..803f0f24f9 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -1,3 +1,4 @@ +import os import rpyc import torch import socket @@ -11,7 +12,6 @@ from rpyc.utils.classic import obtain from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel -from lightllm.models.internvl.internvl_visual import InternVLVisionModel from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel from lightllm.models.gemma4.gemma4_visual import Gemma4VisionModel from lightllm.models.vit.model import VisionTransformer @@ -28,6 +28,8 @@ from lightllm.server.visualserver import set_vit_att_backend from lightllm.server.embed_cache.afs_utils import SepEmbedHandler from lightllm.utils.log_utils import init_logger +from lightllm.server.visualserver.model_infer.mem_reserve import publish_vit_reserved_mem, reserve_guard_tensor +from lightllm.server.visualserver.model_infer.worst_case_reserve import WorstCaseReserveMixin logger = init_logger(__name__) @@ -95,7 +97,6 @@ def exposed_init_model(self, kvargs): self.model = LlavaVisionModel() elif self.model_type == "internvl_chat": self.model = VisionTransformer(kvargs) - # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() elif self.model_type == "gemma4": @@ -114,6 +115,7 @@ def exposed_init_model(self, kvargs): self.model.load_model(weight_dir) self.model = self.model.cuda() + self._reserve_vit_worst_case_mem() if not self.is_visual_only_mode: self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -143,6 +145,41 @@ def exposed_init_model(self, kvargs): set_random_seed(2147483647) return + def _reserve_vit_worst_case_mem(self): + args = get_env_start_args() + global_rank = self.dp_rank_id * self.vit_tp + self.tp_rank_id + reserved_bytes = 0 + if getattr(args, "visual_reserved_mem_gb", None) is not None: + # Manual override: hold an explicit guard tensor, skip the dummy probe. + self._mem_reserve_guard, reserved_bytes = reserve_guard_tensor(self.device_id, args.visual_reserved_mem_gb) + elif os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None: + # Preserved escape hatch: probe disabled. Reservation is skipped -> co-location OOM risk. + logger.warning( + "DISABLE_CHECK_MAX_LEN_INFER is set: skipping ViT worst-case reservation. " + "A co-located LLM may OOM at runtime. Unset it, or set --visual_reserved_mem_gb instead." + ) + elif isinstance(self.model, WorstCaseReserveMixin): + reserved_bytes = self.model.reserve_worst_case_activation( + self.device_id, + self.infer_max_batch_size, + args.max_image_pixels, + args.max_image_token_count, + ) + else: + logger.warning( + f"co-location OOM risk: model_type={self.model_type} has no ViT worst-case reservation. " + f"Set --visual_reserved_mem_gb to reserve headroom, or place the ViT on a separate GPU " + f"with --visual_gpu_ids." + ) + publish_vit_reserved_mem(self.device_id, global_rank, reserved_bytes) + # publishing reserved_bytes (including 0 on the skip paths) tells the router exactly how much + # this rank holds on its device; 0 means "nothing held here". + if reserved_bytes > 0: + logger.info( + f"ViT rank {global_rank} on device {self.device_id} reserved " + f"{reserved_bytes / 1024 ** 3:.2f} GB worst-case activation memory." + ) + def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threading.Event]): try: images = obtain(images) diff --git a/lightllm/server/visualserver/model_infer/worst_case_reserve.py b/lightllm/server/visualserver/model_infer/worst_case_reserve.py new file mode 100644 index 0000000000..9c06ff287c --- /dev/null +++ b/lightllm/server/visualserver/model_infer/worst_case_reserve.py @@ -0,0 +1,70 @@ +import torch +from lightllm.server.visualserver.model_infer.mem_reserve import compute_qwen_worst_case_grid +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +_RESERVE_OOM_HINT = ( + "ViT worst-case activation reservation hit OOM. Lower --visual_infer_batch_size, " + "--max_image_pixels, or --max_image_token_count, or place the ViT on a separate GPU " + "with --visual_gpu_ids." +) + + +class WorstCaseReserveMixin: + """Adds a reserve-and-HOLD worst-case activation probe to a visual model. + + Subclasses MUST implement build_worst_case_input(...). The reservation is held by + deliberately NOT calling torch.cuda.empty_cache() — the retained allocator high-water + mark is what the LLM router observes via mem_get_info during KV-pool profiling. + """ + + def build_worst_case_input(self, batch_size, max_image_pixels, max_image_token_count) -> dict: + raise NotImplementedError + + def run_worst_case_forward(self, dummy: dict): + return self.forward(**dummy) + + @torch.no_grad() + def reserve_worst_case_activation( + self, device_id: int, batch_size: int, max_image_pixels: int, max_image_token_count: int + ) -> int: + torch.cuda.set_device(device_id) + # Baseline = memory already reserved by the loaded ViT weights. We return the activation + # growth ABOVE this baseline so the published/logged value is the tunable activation + # headroom (what --visual_infer_batch_size / --max_image_* control), not weights+activation. + # The physical hold is unaffected: we still never empty_cache, so the full peak stays + # reserved and visible to the LLM's mem_get_info profiling. + baseline_reserved = torch.cuda.memory_reserved(device_id) + torch.cuda.reset_peak_memory_stats(device_id) + try: + dummy = self.build_worst_case_input(batch_size, max_image_pixels, max_image_token_count) + out = self.run_worst_case_forward(dummy) + del out, dummy + except (RuntimeError, torch.OutOfMemoryError) as e: + logger.exception(str(e)) + raise Exception(_RESERVE_OOM_HINT) + # NB: intentionally NO torch.cuda.empty_cache() here — holding the high-water mark IS the mechanism. + peak_reserved = torch.cuda.max_memory_reserved(device_id) + return int(max(0, peak_reserved - baseline_reserved)) + + +class QwenVLWorstCaseMixin(WorstCaseReserveMixin): + """Worst-case builder for Qwen2/2.5/3-VL visual towers (shared forward(hidden_states, grid_thw)).""" + + def build_worst_case_input(self, batch_size, max_image_pixels, max_image_token_count) -> dict: + (total_patches, row_width), grid_thw = compute_qwen_worst_case_grid( + batch_size=batch_size, + max_image_pixels=max_image_pixels, + max_image_token_count=max_image_token_count, + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=self.in_channels, + spatial_merge_size=self.spatial_merge_size, + ) + # Derive dtype from the loaded weights rather than self.data_type — the latter is not + # guaranteed to be a torch.dtype on every Qwen visual class; parameters() always is. + dtype = next(self.parameters()).dtype + hidden_states = torch.randn((total_patches, row_width), dtype=dtype, device="cuda") + grid_thw_t = torch.tensor(grid_thw, dtype=torch.long, device="cuda") + return {"hidden_states": hidden_states, "grid_thw": grid_thw_t}