Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from typing import Optional, TYPE_CHECKING
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.sgl_utils import flash_attn_with_kvcache, flash_attn_with_kvcache_autotune
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
Expand Down Expand Up @@ -222,7 +222,7 @@ def _normal_decode_att(
k_descale, v_descale = None, None # disable quantization
Lq = q.shape[-1]
sm_scale = 1.0 / (Lq ** 0.5)
o = flash_attn_with_kvcache(
o = flash_attn_with_kvcache_autotune(
q=q,
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]),
Expand Down
2 changes: 1 addition & 1 deletion lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _nsa_decode_att(
kv: torch.Tensor,
att_control: AttControl,
) -> torch.Tensor:
from sgl_kernel.flash_attn import flash_attn_with_kvcache
from lightllm.utils.sgl_utils import flash_attn_with_kvcache

nsa_dict = att_control.nsa_decode_dict
topk_mem_indices = nsa_dict["topk_mem_indices"]
Expand Down
18 changes: 17 additions & 1 deletion lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(self, kvargs):
logger.info(f"use decode att backend1: {self.decode_att_backend1.__class__.__name__}")

self._autotune_warmup()
self._extra_autotune()
self._init_padded_req()
self._init_cudagraph()
self._init_prefill_cuda_graph()
Expand Down Expand Up @@ -286,6 +287,21 @@ def _init_prefill_cuda_graph(self):
else:
self.prefill_graph.warmup(self)

@final
@torch.no_grad()
@post_empty_cache
def _extra_autotune(self):
    """Run the additional FA3 decode autotune pass for the CUDA-graph batch sizes.

    Does nothing when ``self.disable_cudagraph`` is truthy. Otherwise the
    batch-size list is produced by ``gen_cuda_graph_batch_sizes`` from
    ``self.graph_max_batch_size`` and ``self.tp_world_size_`` and handed,
    together with the model itself, to ``fa3_decode_autotune``.
    """
    if not self.disable_cudagraph:
        # Imported lazily, only on the path that actually needs them.
        from lightllm.common.basemodel.cuda_graph import gen_cuda_graph_batch_sizes
        from lightllm.utils.sgl_utils import fa3_decode_autotune

        decode_batch_sizes = gen_cuda_graph_batch_sizes(
            max_batch_size=self.graph_max_batch_size,
            tp_world_size=self.tp_world_size_,
        )
        fa3_decode_autotune(self, decode_batch_sizes)

def _init_custom(self):
pass

Expand Down Expand Up @@ -1050,7 +1066,7 @@ def _autotune_warmup(self):
Autotuner.start_autotune_warmup()
torch.distributed.barrier()

warmup_lengths = [1, 8, 16, 32, 64, 100, 128, 256, 1024, 2048, 4096]
warmup_lengths = [1, 4, 8, 16, 32, 64, 100, 128, 256, 1024, 2048, 4096]

if self.batch_max_tokens not in warmup_lengths:
warmup_lengths.append(self.batch_max_tokens)
Expand Down
55 changes: 33 additions & 22 deletions lightllm/common/basemodel/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,34 @@
logger = init_logger(__name__)


def gen_cuda_graph_batch_sizes(max_batch_size=8, tp_world_size: int = 1):
    """Build the sorted list of batch sizes for which CUDA graphs are generated.

    The start args supply ``mtp_step``, ``graph_split_batch_size``,
    ``graph_grow_step_size`` and ``enable_tpsp_mix_mode``. With
    ``mtp_size = mtp_step + 1`` the candidates are:

    * ``mtp_size * i`` for ``i`` in ``1..graph_split_batch_size``;
    * then every ``graph_grow_step_size * mtp_size`` starting one step past
      ``graph_split_batch_size * mtp_size``, stopping before ``max_batch_size``.

    Candidates that are not below ``max_batch_size`` are dropped and
    ``max_batch_size`` itself is always included. When
    ``enable_tpsp_mix_mode`` is set, every size is rounded up to a multiple
    of ``tp_world_size`` and duplicates are merged.

    Returns:
        The de-duplicated batch sizes in ascending order; the last entry is
        asserted to equal ``max_batch_size``.
    """
    start_args = get_env_start_args()
    mtp_size = start_args.mtp_step + 1

    dense_upper = start_args.graph_split_batch_size * mtp_size
    grow_stride = start_args.graph_grow_step_size * mtp_size

    # Dense region: one entry per request count up to graph_split_batch_size.
    candidates = {count * mtp_size for count in range(1, start_args.graph_split_batch_size + 1)}
    # Sparse region: fixed stride beyond the dense region, below max_batch_size.
    candidates.update(range(dense_upper + grow_stride, max_batch_size, grow_stride))

    selected = {size for size in candidates if size < max_batch_size}
    selected.add(max_batch_size)

    if start_args.enable_tpsp_mix_mode:
        # Round every size up to the next multiple of tp_world_size.
        selected = {triton.cdiv(size, tp_world_size) * tp_world_size for size in selected}

    ordered_sizes = sorted(selected)
    assert ordered_sizes[-1] == max_batch_size
    return ordered_sizes


class CudaGraph:
# CudaGraph forward pass for the decoding stage.

Expand All @@ -27,28 +55,11 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int =
self.graph_max_len_in_batch = max_len_in_batch
self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap

# gen cuda graph batch_sizes
# cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
# and [graph_split_batch_size + graph_grow_step_size,
# if the mtp_step is not 0, then the batch_sizes will be multiply of (mtp_step + 1)

graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1)
graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1)

batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)]
for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size):
batch_sizes.append(_batch_size)

batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size]))
batch_sizes.append(max_batch_size)
batch_sizes.sort()
if self.args.enable_tpsp_mix_mode:
batch_sizes = [triton.cdiv(e, self.tp_world_size) * self.tp_world_size for e in batch_sizes]
batch_sizes = list(set(batch_sizes))
batch_sizes.sort()

self.cuda_graph_batch_sizes = batch_sizes
assert batch_sizes[-1] == self.max_batch_size
self.cuda_graph_batch_sizes = gen_cuda_graph_batch_sizes(
max_batch_size=max_batch_size,
tp_world_size=tp_world_size,
)
assert self.cuda_graph_batch_sizes[-1] == self.max_batch_size
logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")

def can_run(self, batch_size, max_len_in_batch):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, blo
The output tensor contains the block schedule info.
"""
max_num_tokens_padded = token_num_mul_topk_num + exports_token_num.shape[0] * (block_m - 1)
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_m)
max_num_m_blocks = min(token_num_mul_topk_num, triton.cdiv(max_num_tokens_padded, block_m))
# first is expert, second is m_index, third is token_start_index
mblocks_to_tuple_info = torch.empty((max_num_m_blocks, 3), dtype=torch.int32, device="cuda")

Expand Down Expand Up @@ -760,7 +760,7 @@ def _get_grouped_matmul_configs():
}
for ns in [2, 3, 4, 5]
for gm in [1, 16, 32, 64]
for nw in [4, 8]
for nw in [2, 4, 8]
for bm in [16, 32, 64, 128]
for bn in [16, 32, 64, 128]
for bk in [32, 64, 128]
Expand Down
2 changes: 2 additions & 0 deletions lightllm/common/triton_utils/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import orjson
import os
import inspect
import gc
import torch
import torch.distributed as dist
import random
Expand Down Expand Up @@ -275,6 +276,7 @@ def kernel_call():
return float("inf")

def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):

is_key_all_same = True
if world_size > 1:
all_keys = [None for _ in range(world_size)]
Expand Down
Loading
Loading