From 664b23d2e2a128d29085541cde6f4130714a0ef3 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 26 Jun 2026 10:34:27 +0800 Subject: [PATCH] feat: opt autotuner --- lightllm/common/basemodel/attention/fa3/fp.py | 4 +- .../attention/nsa/flashmla_sparse.py | 2 +- lightllm/common/basemodel/basemodel.py | 18 +- lightllm/common/basemodel/cuda_graph.py | 55 ++-- .../fused_moe/grouped_fused_moe.py | 4 +- lightllm/common/triton_utils/autotuner.py | 2 + lightllm/utils/sgl_utils.py | 278 +++++++++++++++++- 7 files changed, 332 insertions(+), 31 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..57f3ab6fe3 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -3,7 +3,7 @@ from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id -from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.sgl_utils import flash_attn_with_kvcache, flash_attn_with_kvcache_autotune from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor @@ -222,7 +222,7 @@ def _normal_decode_att( k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( + o = flash_attn_with_kvcache_autotune( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index c3456f4b7a..b06cb64699 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ 
b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -162,7 +162,7 @@ def _nsa_decode_att( kv: torch.Tensor, att_control: AttControl, ) -> torch.Tensor: - from sgl_kernel.flash_attn import flash_attn_with_kvcache + from lightllm.utils.sgl_utils import flash_attn_with_kvcache nsa_dict = att_control.nsa_decode_dict topk_mem_indices = nsa_dict["topk_mem_indices"] diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index e83de684a7..c703ae0496 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -126,6 +126,7 @@ def __init__(self, kvargs): logger.info(f"use decode att backend1: {self.decode_att_backend1.__class__.__name__}") self._autotune_warmup() + self._extra_autotune() self._init_padded_req() self._init_cudagraph() self._init_prefill_cuda_graph() @@ -286,6 +287,21 @@ def _init_prefill_cuda_graph(self): else: self.prefill_graph.warmup(self) + @final + @torch.no_grad() + @post_empty_cache + def _extra_autotune(self): + if self.disable_cudagraph: + return + from lightllm.common.basemodel.cuda_graph import gen_cuda_graph_batch_sizes + from lightllm.utils.sgl_utils import fa3_decode_autotune + + cuda_graph_batch_sizes = gen_cuda_graph_batch_sizes( + max_batch_size=self.graph_max_batch_size, + tp_world_size=self.tp_world_size_, + ) + fa3_decode_autotune(self, cuda_graph_batch_sizes) + def _init_custom(self): pass @@ -1050,7 +1066,7 @@ def _autotune_warmup(self): Autotuner.start_autotune_warmup() torch.distributed.barrier() - warmup_lengths = [1, 8, 16, 32, 64, 100, 128, 256, 1024, 2048, 4096] + warmup_lengths = [1, 4, 8, 16, 32, 64, 100, 128, 256, 1024, 2048, 4096] if self.batch_max_tokens not in warmup_lengths: warmup_lengths.append(self.batch_max_tokens) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index d0ac8ead10..5fbe77f9ae 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ 
b/lightllm/common/basemodel/cuda_graph.py @@ -14,6 +14,34 @@ logger = init_logger(__name__) +def gen_cuda_graph_batch_sizes(max_batch_size=8, tp_world_size: int = 1): + args = get_env_start_args() + mtp_size = args.mtp_step + 1 + + # gen cuda graph batch_sizes + # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] + # and [graph_split_batch_size + graph_grow_step_size, ..., max_batch_size] in steps of graph_grow_step_size, + # if the mtp_step is not 0, then the batch_sizes will be multiples of (mtp_step + 1) + + graph_split_batch_size = args.graph_split_batch_size * mtp_size + graph_grow_step_size = args.graph_grow_step_size * mtp_size + + batch_sizes = [i * mtp_size for i in range(1, args.graph_split_batch_size + 1)] + for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size): + batch_sizes.append(_batch_size) + + batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size])) + batch_sizes.append(max_batch_size) + batch_sizes.sort() + if args.enable_tpsp_mix_mode: + batch_sizes = [triton.cdiv(e, tp_world_size) * tp_world_size for e in batch_sizes] + batch_sizes = list(set(batch_sizes)) + batch_sizes.sort() + + assert batch_sizes[-1] == max_batch_size + return batch_sizes + + class CudaGraph: # CudaGraph forward pass for the decoding stage. 
@@ -27,28 +55,11 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap - # gen cuda graph batch_sizes - # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] - # and [graph_split_batch_size + graph_grow_step_size, - # if the mtp_step is not 0, then the batch_sizes will be multiply of (mtp_step + 1) - - graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1) - graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1) - - batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)] - for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size): - batch_sizes.append(_batch_size) - - batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size])) - batch_sizes.append(max_batch_size) - batch_sizes.sort() - if self.args.enable_tpsp_mix_mode: - batch_sizes = [triton.cdiv(e, self.tp_world_size) * self.tp_world_size for e in batch_sizes] - batch_sizes = list(set(batch_sizes)) - batch_sizes.sort() - - self.cuda_graph_batch_sizes = batch_sizes - assert batch_sizes[-1] == self.max_batch_size + self.cuda_graph_batch_sizes = gen_cuda_graph_batch_sizes( + max_batch_size=max_batch_size, + tp_world_size=tp_world_size, + ) + assert self.cuda_graph_batch_sizes[-1] == self.max_batch_size logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") def can_run(self, batch_size, max_len_in_batch): diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index c6eeb781dc..e10adf7758 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -455,7 +455,7 @@ def 
moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, blo out tensor is a tensor that contain block schduel infos tensor. """ max_num_tokens_padded = token_num_mul_topk_num + exports_token_num.shape[0] * (block_m - 1) - max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_m) + max_num_m_blocks = min(token_num_mul_topk_num, triton.cdiv(max_num_tokens_padded, block_m)) # first is expert, second is m_index, third is token_start_index mblocks_to_tuple_info = torch.empty((max_num_m_blocks, 3), dtype=torch.int32, device="cuda") @@ -760,7 +760,7 @@ def _get_grouped_matmul_configs(): } for ns in [2, 3, 4, 5] for gm in [1, 16, 32, 64] - for nw in [4, 8] + for nw in [2, 4, 8] for bm in [16, 32, 64, 128] for bn in [16, 32, 64, 128] for bk in [32, 64, 128] diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index c62a2572ff..700e91786f 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -2,6 +2,7 @@ import orjson import os import inspect +import gc import torch import torch.distributed as dist import random @@ -275,6 +276,7 @@ def kernel_call(): return float("inf") def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size): + is_key_all_same = True if world_size > 1: all_keys = [None for _ in range(world_size)] diff --git a/lightllm/utils/sgl_utils.py b/lightllm/utils/sgl_utils.py index b48a62506d..b4d5a4f63a 100644 --- a/lightllm/utils/sgl_utils.py +++ b/lightllm/utils/sgl_utils.py @@ -1,6 +1,14 @@ +import torch + +from frozendict import frozendict +from lightllm.common.triton_utils.autotuner import AutotuneLevel, Autotuner +from lightllm.utils.envs_utils import get_triton_autotune_level from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) + +_DEFAULT_NUM_SPLITS = 0 + try: import sgl_kernel @@ -17,16 +25,280 @@ ) try: - from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + from 
sgl_kernel.flash_attn import ( + flash_attn_varlen_func, + flash_attn_with_kvcache as _flash_attn_with_kvcache, + ) flash_attn_varlen_func = flash_attn_varlen_func - flash_attn_with_kvcache = flash_attn_with_kvcache merge_state_v2 = sgl_ops.merge_state_v2 except: flash_attn_varlen_func = None - flash_attn_with_kvcache = None + _flash_attn_with_kvcache = None merge_state_v2 = None logger.warning( "sgl_kernel is not installed, or the installed version did not support fa3. \ Try to upgrade it." ) + + +def _flash_attn_kvcache_num_splits_configs(): + return [{"num_splits": num_splits} for num_splits in [0, 16, 32]] + + +def _flash_attn_kvcache_static_key(q, k_cache, v_cache, causal, window_size, softcap, sinks): + return { + "qd": str(q.dtype), + "kd": str(k_cache.dtype), + "vd": str(v_cache.dtype), + "qh": int(q.shape[-2]), + "kh": int(k_cache.shape[-2]), + "hd": int(q.shape[-1]), + "vh": int(v_cache.shape[-1]), + "pb": int(k_cache.shape[-3]), + "c": int(bool(causal)), + "wl": int(window_size[0]), + "wr": int(window_size[1]), + "sc": int(softcap > 0.0), + "sk": int(sinks is not None), + "sgl": getattr(sgl_ops, "__version__", "unknown"), + } + + +def _flash_attn_max_q_len(q, max_seqlen_q): + if max_seqlen_q is not None: + return int(max_seqlen_q) + if q.dim() >= 4: + return int(q.shape[1]) + return int(q.shape[0]) + + +def _flash_attn_kvcache_run_key(q, page_table, max_seqlen_q): + batch_size = int(page_table.shape[0]) + max_q_len = _flash_attn_max_q_len(q, max_seqlen_q) + max_kv_len = int(page_table.shape[1]) + return batch_size * 1_000_000_000_000 + max_q_len * 1_000_000 + max_kv_len + + +def _flash_attn_is_decode_like(q, page_table, max_seqlen_q=None): + if page_table is None or page_table.dim() < 2: + return False + + max_q_len = _flash_attn_max_q_len(q, max_seqlen_q) + if max_q_len <= 0 or int(page_table.shape[1]) <= max_q_len: + return False + + q_token_num = int(q.shape[0]) * int(q.shape[1]) if q.dim() >= 4 else int(q.shape[0]) + return q_token_num == 
int(page_table.shape[0]) * max_q_len + + +class _FlashAttnKvcacheAutotuner(Autotuner): + def _bench(self, *args, n_repeat=3, n_retries=3, **kwargs): + page_table = kwargs.get("page_table") + cache_seqlens = kwargs.get("cache_seqlens") + max_kv_len = int(page_table.shape[1]) if page_table is not None and page_table.dim() >= 2 else 0 + + bench_times = [] + for bench_kv_len in sorted({kv_len for kv_len in [10240, max_kv_len] if 0 < kv_len <= max_kv_len}): + bench_kwargs = kwargs.copy() + if isinstance(cache_seqlens, torch.Tensor): + bench_cache_seqlens = cache_seqlens.clone() + bench_cache_seqlens.fill_(bench_kv_len) + else: + bench_cache_seqlens = bench_kv_len + bench_kwargs["cache_seqlens"] = bench_cache_seqlens + + cu_seqlens_k_new = bench_kwargs.get("cu_seqlens_k_new") + if isinstance(cu_seqlens_k_new, torch.Tensor) and cu_seqlens_k_new.numel() != 0: + bench_cu_seqlens_k_new = torch.arange( + cu_seqlens_k_new.numel(), + device=cu_seqlens_k_new.device, + dtype=cu_seqlens_k_new.dtype, + ) + bench_cu_seqlens_k_new *= bench_kv_len + bench_kwargs["cu_seqlens_k_new"] = bench_cu_seqlens_k_new + + bench_times.append(super()._bench(*args, n_repeat=n_repeat, n_retries=n_retries, **bench_kwargs)) + + if bench_times: + return sum(bench_times) / len(bench_times) + + return super()._bench(*args, n_repeat=n_repeat, n_retries=n_retries, **kwargs) + + +if _flash_attn_with_kvcache is not None and torch.cuda.is_available(): + + @torch.no_grad() + def _flash_attn_with_kvcache_autotuned_impl( + q, + k_cache, + v_cache, + cache_seqlens=None, + page_table=None, + cu_seqlens_q=None, + cu_seqlens_k_new=None, + max_seqlen_q=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + num_splits=0, + sinks=None, + run_config=None, + **kwargs, + ): + if run_config is not None: + num_splits = run_config["num_splits"] + return _flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + 
cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + causal=causal, + window_size=window_size, + softcap=softcap, + num_splits=num_splits, + sinks=sinks, + **kwargs, + ) + + _flash_attn_with_kvcache_autotuned = _FlashAttnKvcacheAutotuner( + fn=_flash_attn_with_kvcache_autotuned_impl, + kernel_name="sgl_fa3_kvcache_ns:v1", + configs_gen_func=_flash_attn_kvcache_num_splits_configs, + static_key_func=_flash_attn_kvcache_static_key, + run_key_func=_flash_attn_kvcache_run_key, + ) + +else: + _flash_attn_with_kvcache_autotuned = None + + +flash_attn_with_kvcache = _flash_attn_with_kvcache + +if _flash_attn_with_kvcache_autotuned is not None: + + def _flash_attn_with_kvcache_autotune(q, k_cache, v_cache, **kwargs): + if ( + kwargs.get("num_splits", _DEFAULT_NUM_SPLITS) != _DEFAULT_NUM_SPLITS + or any( + kwargs.get(name) is not None + for name in ( + "k", + "v", + "out", + "qv", + "q_descale", + "k_descale", + "v_descale", + ) + ) + or not _flash_attn_is_decode_like(q, kwargs.get("page_table"), kwargs.get("max_seqlen_q")) + ): + return _flash_attn_with_kvcache(q=q, k_cache=k_cache, v_cache=v_cache, **kwargs) + + tuner = _flash_attn_with_kvcache_autotuned + call_kwargs = {"q": q, "k_cache": k_cache, "v_cache": v_cache, **kwargs} + call_kwargs.setdefault("causal", False) + call_kwargs.setdefault("window_size", (-1, -1)) + call_kwargs.setdefault("softcap", 0.0) + call_kwargs.setdefault("sinks", None) + call_kwargs.setdefault("num_splits", _DEFAULT_NUM_SPLITS) + + if get_triton_autotune_level() == AutotuneLevel.ADAPTIVE_AUTOTUNE: + static_key = frozendict(tuner._static_key(**call_kwargs)) + run_key = str(tuner._run_key(**call_kwargs)) + tuner._try_load_cache(static_key) + + if run_key not in tuner.cached_configs.get(static_key, {}) and not Autotuner.is_autotune_warmup(): + Autotuner.start_autotune_warmup() + try: + return tuner(**call_kwargs) + finally: + Autotuner.end_autotune_warmup() + + return tuner(**call_kwargs) + + flash_attn_with_kvcache_autotune = 
_flash_attn_with_kvcache_autotune +else: + flash_attn_with_kvcache_autotune = _flash_attn_with_kvcache + + +def fa3_decode_autotune(model, cuda_graph_batch_sizes): + if _flash_attn_with_kvcache_autotuned is None or get_triton_autotune_level() not in [ + AutotuneLevel.ADAPTIVE_AUTOTUNE, + AutotuneLevel.FORCE_AUTOTUNE, + ]: + return + + decode_backends = [ + model.decode_att_backend, + getattr(model, "decode_att_backend1", None), + ] + if not any(backend is not None and backend.__class__.__name__ == "Fa3AttBackend" for backend in decode_backends): + return + + need_end_warmup = not Autotuner.is_autotune_warmup() + if need_end_warmup: + Autotuner.start_autotune_warmup() + try: + max_kv_len = int(model.graph_max_len_in_batch) + if max_kv_len <= 0: + return + + k, v = model.mem_manager.get_att_input_params(layer_index=0) + k_cache = k.view(k.shape[0], 1, k.shape[1], k.shape[2]) + v_cache = v.view(v.shape[0], 1, v.shape[1], v.shape[2]) + q_head_num = int(model.config["num_attention_heads"]) // model.tp_world_size_ + head_dim = int(k.shape[-1]) + mtp_size = model.args.mtp_step + 1 + hold_token_memindex = model.mem_manager.HOLD_TOKEN_MEMINDEX + k[hold_token_memindex].zero_() + v[hold_token_memindex].zero_() + + for batch_size in cuda_graph_batch_sizes[::-1]: + att_batch_size = batch_size // mtp_size + if att_batch_size <= 0: + continue + + q = torch.zeros( + (att_batch_size * mtp_size, q_head_num, head_dim), + dtype=model.data_type, + device=k.device, + ) + page_table = torch.full( + (att_batch_size, max_kv_len), + hold_token_memindex, + dtype=torch.int32, + device=k.device, + ) + cache_seqlens = torch.full((att_batch_size,), max_kv_len, dtype=torch.int32, device=k.device) + cu_seqlens_q = torch.arange(att_batch_size + 1, dtype=torch.int32, device=k.device) * mtp_size + cu_seqlens_k = torch.arange(att_batch_size + 1, dtype=torch.int32, device=k.device) * max_kv_len + softmax_scale = 1.0 / (head_dim ** 0.5) + + flash_attn_with_kvcache_autotune( + q=q, + k_cache=k_cache, + 
v_cache=v_cache, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k, + max_seqlen_q=mtp_size, + softmax_scale=softmax_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + sinks=None, + ) + finally: + if need_end_warmup: + Autotuner.end_autotune_warmup() + return