Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/CN/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ lightllm 支持大多数的主流的开源大语言模型以及多模态模型
-
* - `Qwen3-Moe <https://github.com/QwenLM/Qwen3>`_
-
* - `GLM-5.2 <https://huggingface.co/zai-org/GLM-5.2>`_
- 支持 BF16/FP8 和 MTP。


多模态模型
Expand Down Expand Up @@ -93,4 +95,4 @@ Reward模型
* - `internLM-reward <https://huggingface.co/internlm/internlm2-1_8b-reward>`_
- :code:`--use_reward_model`
* - `Qwen2-Reward <https://huggingface.co/Qwen/Qwen2-Reward>`_
- :code:`--use_reward_model`
- :code:`--use_reward_model`
3 changes: 2 additions & 1 deletion docs/EN/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ Large Language Models
-
* - `DeepSeek-V3.2 <https://huggingface.co/deepseek-ai/DeepSeek-V3.2>`_
-
* - `GLM-5.2 <https://huggingface.co/zai-org/GLM-5.2>`_
- Supports BF16/FP8 and MTP.

Multimodal Models
^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -94,4 +96,3 @@ Reward Models
- :code:`--use_reward_model`
* - `Qwen2-Reward <https://huggingface.co/Qwen/Qwen2-Reward>`_
- :code:`--use_reward_model`

1 change: 0 additions & 1 deletion lightllm/common/basemodel/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

# NSA backend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend

from .create_utils import (
get_prefill_att_backend_class,
Expand Down
4 changes: 0 additions & 4 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend

logger = init_logger(__name__)

Expand Down Expand Up @@ -57,9 +56,6 @@
"flashmla_sparse": NsaFlashMlaSparseAttBackend,
# Future backends: "fa3", "tilelang", "aiter"
},
"fp8kv_dsa": {
"flashmla_sparse": NsaFlashMlaFp8SparseAttBackend,
},
}


Expand Down
8 changes: 0 additions & 8 deletions lightllm/common/basemodel/attention/nsa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,9 @@
NsaFlashMlaSparsePrefillAttState,
NsaFlashMlaSparseDecodeAttState,
)
from .fp8_flashmla_sparse import (
NsaFlashMlaFp8SparseAttBackend,
NsaFlashMlaFp8SparsePrefillAttState,
NsaFlashMlaFp8SparseDecodeAttState,
)

__all__ = [
"NsaFlashMlaSparseAttBackend",
"NsaFlashMlaSparsePrefillAttState",
"NsaFlashMlaSparseDecodeAttState",
"NsaFlashMlaFp8SparseAttBackend",
"NsaFlashMlaFp8SparsePrefillAttState",
"NsaFlashMlaFp8SparseDecodeAttState",
]
9 changes: 8 additions & 1 deletion lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import dataclasses
import torch
import torch.nn.functional as F
from typing import Tuple, TYPE_CHECKING

from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
Expand Down Expand Up @@ -86,14 +87,20 @@ def _nsa_prefill_att(
if topk_mem_indices.ndim == 2:
topk_mem_indices = topk_mem_indices.unsqueeze(1)

real_head_num = q.shape[1]
head_block_size = 64
pad_head_num = (-real_head_num) % head_block_size
if pad_head_num:
q = F.pad(q, (0, 0, 0, pad_head_num))

mla_out, _, _ = flash_mla_sparse_fwd(
q=q,
kv=kv,
indices=topk_mem_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
)
return mla_out
return mla_out[:, :real_head_num, :]


@dataclasses.dataclass
Expand Down
198 changes: 0 additions & 198 deletions lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py

This file was deleted.

7 changes: 1 addition & 6 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,12 +1171,7 @@ def _init_padded_req(self):
def _gen_special_model_input(self, token_num: int):
special_model_input = {}

is_mtp_draft_model = (
"Deepseek3MTPModel" in str(self.__class__)
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
)
is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False)
if is_mtp_draft_model:
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .base_weight import BaseWeightTpl
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size
from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward
from lightllm.common.basemodel.triton_kernel.norm.fused_add_rmsnorm import fused_add_rmsnorm_forward
from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward
from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward
from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward
Expand Down Expand Up @@ -71,6 +72,21 @@ def __call__(
) -> torch.Tensor:
return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func)

def fused_add_forward(
    self,
    residual: torch.Tensor,
    x: torch.Tensor,
    eps: float,
    out: Optional[torch.Tensor] = None,
    alloc_func=torch.empty,
) -> torch.Tensor:
    """Run the fused residual-add + RMSNorm kernel using this object's weight.

    The kernel is expected to update ``residual`` in place with ``residual + x``
    and to produce ``rmsnorm(residual) * weight``, i.e. the same result as
    ``residual.add_(x)`` followed by ``__call__``, but with a single Triton
    launch. Only available on Triton-capable (CUDA/MUSA) backends.

    Args:
        residual: Running residual tensor; also fixes the shape, dtype and
            device of the output buffer when ``out`` is not supplied.
        x: Tensor to be added onto ``residual``.
        eps: Epsilon forwarded to the normalization kernel.
        out: Optional pre-allocated output buffer. When ``None`` a buffer is
            obtained from ``alloc_func``.
        alloc_func: Allocator called as
            ``alloc_func(shape, dtype=..., device=...)``; defaults to
            ``torch.empty``.

    Returns:
        Whatever ``fused_add_rmsnorm_forward`` returns for the given buffers
        (presumably the normalized output tensor -- confirm against the kernel).
    """
    result_buf = out
    if result_buf is None:
        # Allocate lazily so callers may route allocation through a custom pool.
        result_buf = alloc_func(
            residual.shape,
            dtype=residual.dtype,
            device=residual.device,
        )
    return fused_add_rmsnorm_forward(
        residual=residual,
        x=x,
        weight=self.weight,
        eps=eps,
        out=result_buf,
    )


class GatedRMSNormWeight(RMSNormWeight):
def _triton_forward(
Expand Down
Loading
Loading