Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
89a27fd
feat(mtp): MTP verify-decode infrastructure
sufubao Jun 9, 2026
61e4dcb
feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP draft models
sufubao Jun 9, 2026
89a163a
feat(qwen3next): GDN spec-decode verify path + linear-att cache split
sufubao Jun 9, 2026
47ddb6c
feat(scheduler): MTP verify backend + accept-len transport
sufubao Jun 9, 2026
085a185
test(mtp): MTP unit tests + static benchmark
sufubao Jun 9, 2026
db50f25
Fix Qwen3Next MTP linear-att page moves
sufubao Jun 9, 2026
45ec253
revert formatting churn on pre-existing code
sufubao Jun 15, 2026
5883b41
revert(mtp): drop eagle reduced-batch draft optimization
sufubao Jun 15, 2026
82522e6
revert(mtp): run the MTP draft on upstream's grouped verify layout
sufubao Jun 15, 2026
cd6b918
clean code
sufubao Jun 16, 2026
fe9ac22
clean code
sufubao Jun 16, 2026
10473dd
refactor(mtp): GPU-resident req_to_accept_len + simplify verify-decod…
sufubao Jun 16, 2026
45831a2
revert: drop all test/ and unit_tests/ changes from this branch
sufubao Jun 16, 2026
31fa641
style: black-format fp8.py k/v_descale lines (pre-commit)
sufubao Jun 16, 2026
c4c3c2f
clean code
sufubao Jun 16, 2026
6f78b54
Merge upstream/main into qw35_mtp_feature
sufubao Jun 23, 2026
7871295
Merge remote-tracking branch 'upstream/main' into qw35_mtp_feature
sufubao Jun 29, 2026
0c2f7d0
fix
sufubao Jun 29, 2026
f75dfaf
fix
sufubao Jun 29, 2026
01447ec
Merge remote-tracking branch 'upstream/main' into qw35_mtp_feature
sufubao Jun 29, 2026
5814653
fix
sufubao Jun 29, 2026
694fbe6
fix format
sufubao Jun 29, 2026
39d3822
clean code: mtp_verify_extra_state.py
sufubao Jun 29, 2026
efda16d
clean code
sufubao Jun 29, 2026
8d51682
clean code
sufubao Jun 29, 2026
a4c79d6
clean code
sufubao Jun 29, 2026
7bc84fc
clean code
sufubao Jun 29, 2026
ebe6ae8
fix
sufubao Jun 30, 2026
fb98c5c
restore cudagraph
shihaobai Jul 1, 2026
145fb32
update infer_struct
shihaobai Jul 1, 2026
47f0e99
clean code
shihaobai Jul 1, 2026
4f85a3f
fix
sufubao Jul 1, 2026
6a05570
clean transformers layer_weight
shihaobai Jul 1, 2026
c51bc4b
Merge branch 'qw35_mtp_feature' of https://github.com/sufubao/lightll…
shihaobai Jul 1, 2026
0ee2150
clean req_manager.py
shihaobai Jul 1, 2026
d2e46e4
fix
shihaobai Jul 1, 2026
e119a68
fix
shihaobai Jul 1, 2026
c6ed3ea
clean code
shihaobai Jul 1, 2026
40a143d
clean model.py
shihaobai Jul 1, 2026
e2aab74
fix
hiworldwzj Jul 2, 2026
5671f1e
fix
hiworldwzj Jul 2, 2026
e7552cf
fix
hiworldwzj Jul 2, 2026
2beb876
fix
hiworldwzj Jul 2, 2026
e06bce6
fix
hiworldwzj Jul 2, 2026
36d65f8
fix
hiworldwzj Jul 2, 2026
104abc7
fix
hiworldwzj Jul 2, 2026
740cb2d
fix
hiworldwzj Jul 2, 2026
635e87e
fix
hiworldwzj Jul 2, 2026
041788d
fix
hiworldwzj Jul 3, 2026
44a793e
fix
hiworldwzj Jul 3, 2026
5a2fec0
fix
hiworldwzj Jul 3, 2026
744acbc
fix
hiworldwzj Jul 3, 2026
37d5e7b
fix
hiworldwzj Jul 3, 2026
29b14a0
fix
hiworldwzj Jul 3, 2026
bcc068e
fix
hiworldwzj Jul 3, 2026
a1ffb41
fix
hiworldwzj Jul 3, 2026
d19b9d8
fix
hiworldwzj Jul 3, 2026
471d1f6
fix
hiworldwzj Jul 3, 2026
de1283a
fix
hiworldwzj Jul 3, 2026
fe05478
fix
hiworldwzj Jul 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@ dist
.vscode
tmp/
requirements-musa.txt
logs/
logs/

/benchmark/
artifacts/
22 changes: 14 additions & 8 deletions lightllm/common/basemodel/attention/fa3/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ def init_state(self):
torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
)
# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)

self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

def prefill_att(
self,
Expand Down Expand Up @@ -120,16 +123,19 @@ def init_state(self):
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0

device = self.infer_state.input_ids.device
batch_size = att_batch_size
mem_manager = self.backend.model.mem_manager

offline_scales: torch.Tensor = mem_manager.scales
head_num = mem_manager.head_num

# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

return

Expand Down Expand Up @@ -180,11 +186,11 @@ def _fp8_decode_att(
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,
cache_seqlens=self.infer_state.b_seq_len,
cache_seqlens=self.b_att_seq_len,
cu_seqlens_q=self.cu_seqlens_q,
cu_seqlens_k_new=self.cu_seqlens_k,
max_seqlen_q=self.decode_max_q_seq_len,
causal=False,
causal=True,
window_size=(-1, -1),
softcap=0.0,
q_descale=q_scale.view(self.infer_state.batch_size, k_head_num),
Expand Down
7 changes: 1 addition & 6 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,12 +1171,7 @@ def _init_padded_req(self):
def _gen_special_model_input(self, token_num: int):
special_model_input = {}

is_mtp_draft_model = (
"Deepseek3MTPModel" in str(self.__class__)
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
)
is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False)
if is_mtp_draft_model:
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
Expand Down
90 changes: 54 additions & 36 deletions lightllm/common/basemodel/triton_kernel/linear_att_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

@triton.jit
def _copy_linear_att_state_to_kv_buffer(
gpu_conv_ptr, # [linear_layer_num, size_num, xdim]
gpu_ssm_ptr, # [linear_layer_num, size_num, xxdim]
cpu_kv_conv_ptr, # [size, linear_layer_num, xdim]
cpu_kv_ssm_ptr, # [size, linear_layer_num, xxdim]
gpu_conv_ptr, # uint8 view: [linear_layer_num, req_num, conv_dim, gpu_conv_row_bytes]
gpu_ssm_ptr, # uint8 view: [linear_layer_num, req_num * (mtp_step + 1), ssm_bytes]
cpu_kv_conv_ptr, # uint8 view: [buffer_num, linear_layer_num, conv_dim * cpu_conv_row_bytes]
cpu_kv_ssm_ptr, # uint8 view: [buffer_num, linear_layer_num, ssm_bytes]
b_req_idx, # [batch_size,]
big_page_buffer_ids, # [batch_size,]
gpu_conv_stride_l,
gpu_conv_stride_s,
gpu_conv_stride_c,
gpu_conv_stride_d,
gpu_ssm_stride_l,
gpu_ssm_stride_s,
Expand All @@ -24,16 +25,25 @@ def _copy_linear_att_state_to_kv_buffer(
cpu_kv_ssm_stride_l,
cpu_kv_ssm_stride_d,
mtp_step,
gpu_conv_tail_dim,
gpu_conv_dim, # number of conv rows
gpu_conv_tail_dim_bytes, # bytes copied per conv row; equals the CPU/cache row width
gpu_ssm_tail_dim,
BLOCK: tl.constexpr,
):
cur_layer = tl.program_id(0).to(tl.int64)
cur_batch = tl.program_id(1).to(tl.int64)
cpu_kv_conv_stride_s = tl.cast(cpu_kv_conv_stride_s, dtype=tl.int64)
cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, dtype=tl.int64)
gpu_conv_stride_l = tl.cast(gpu_conv_stride_l, dtype=tl.int64)
gpu_conv_stride_s = tl.cast(gpu_conv_stride_s, dtype=tl.int64)
gpu_conv_stride_c = tl.cast(gpu_conv_stride_c, dtype=tl.int64)
gpu_conv_stride_d = tl.cast(gpu_conv_stride_d, dtype=tl.int64)
gpu_ssm_stride_l = tl.cast(gpu_ssm_stride_l, dtype=tl.int64)
gpu_ssm_stride_s = tl.cast(gpu_ssm_stride_s, dtype=tl.int64)
cpu_kv_conv_stride_s = tl.cast(cpu_kv_conv_stride_s, dtype=tl.int64)
cpu_kv_conv_stride_l = tl.cast(cpu_kv_conv_stride_l, dtype=tl.int64)
cpu_kv_conv_stride_d = tl.cast(cpu_kv_conv_stride_d, dtype=tl.int64)
cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, dtype=tl.int64)
cpu_kv_ssm_stride_l = tl.cast(cpu_kv_ssm_stride_l, dtype=tl.int64)
gpu_conv_tail_dim_bytes = tl.cast(gpu_conv_tail_dim_bytes, dtype=tl.int64)

big_page_buffer_idx = tl.load(big_page_buffer_ids + cur_batch)
if big_page_buffer_idx == -1:
Expand All @@ -42,20 +52,16 @@ def _copy_linear_att_state_to_kv_buffer(
cur_req_idx = tl.load(b_req_idx + cur_batch).to(tl.int64)
cur_state_req_idx = (cur_req_idx * (mtp_step + 1)).to(tl.int64)

for i in range(tl.cdiv(gpu_conv_tail_dim, BLOCK)):
gpu_start_off = i * BLOCK + tl.arange(0, BLOCK)
mask = gpu_start_off < gpu_conv_tail_dim
conv_data = tl.load(
gpu_conv_ptr + cur_layer * gpu_conv_stride_l + cur_state_req_idx * gpu_conv_stride_s + gpu_start_off,
mask=mask,
)
dest_conv_ptr = (
cpu_kv_conv_ptr
+ big_page_buffer_idx * cpu_kv_conv_stride_s
+ cur_layer * cpu_kv_conv_stride_l
+ gpu_start_off
)
tl.store(dest_conv_ptr, conv_data, mask=mask)
gpu_conv_base = gpu_conv_ptr + cur_layer * gpu_conv_stride_l + cur_req_idx * gpu_conv_stride_s
cpu_conv_base = cpu_kv_conv_ptr + big_page_buffer_idx * cpu_kv_conv_stride_s + cur_layer * cpu_kv_conv_stride_l
conv_tail_dim = gpu_conv_dim * gpu_conv_tail_dim_bytes
for i in range(tl.cdiv(conv_tail_dim, BLOCK)):
conv_start = i * BLOCK + tl.arange(0, BLOCK)
conv_row = conv_start // gpu_conv_tail_dim_bytes
conv_col = conv_start % gpu_conv_tail_dim_bytes
mask = conv_start < conv_tail_dim
conv_data = tl.load(gpu_conv_base + conv_row * gpu_conv_stride_c + conv_col, mask=mask)
tl.store(cpu_conv_base + conv_start, conv_data, mask=mask)

for i in range(tl.cdiv(gpu_ssm_tail_dim, BLOCK)):
gpu_start_off = i * BLOCK + tl.arange(0, BLOCK)
Expand All @@ -75,36 +81,46 @@ def _copy_linear_att_state_to_kv_buffer(
def copy_linear_att_state_to_kv_buffer(
b_req_idx: torch.Tensor,
big_page_buffer_ids: torch.Tensor,
gpu_conv_state: torch.Tensor, # [linear_layer_num, s, ...]
gpu_ssm_state: torch.Tensor, # [linear_layer_num, s, ...]
cpu_kv_conv_state: torch.Tensor, # [s, linear_layer_num, ...]
cpu_kv_ssm_state: torch.Tensor, # [s, linear_layer_num, ...]
gpu_conv_state: torch.Tensor, # [linear_layer_num, req_num, conv_dim, kernel_size]
gpu_ssm_state: torch.Tensor, # [linear_layer_num, req_num * (mtp_step + 1), ...]
cpu_kv_conv_state: torch.Tensor, # [buffer_num, linear_layer_num, conv_dim, kernel_size]
cpu_kv_ssm_state: torch.Tensor, # [buffer_num, linear_layer_num, ...]
mtp_step: int,
):
# gpu_conv_state 的后两维可能是不连续的。
assert len(b_req_idx) == big_page_buffer_ids.shape[0]
BLOCK = 4096
gpu_conv_state = gpu_conv_state.view(gpu_conv_state.shape[0], gpu_conv_state.shape[1], -1).view(dtype=torch.uint8)

assert gpu_conv_state.dim() == 4, "gpu_conv_state must be [layer, s, conv_dim, widened_width]"
assert cpu_kv_conv_state.dim() == 4, "cpu_kv_conv_state must be [size, layer, conv_dim, width_narrow]"
# 因为存在mtp模式,gpu_conv_state 的最后一个维度可能存在冗余的部分,需要进行切片对齐。
gpu_conv_state = gpu_conv_state[:, :, :, :cpu_kv_conv_state.shape[-1]]
gpu_conv_state = gpu_conv_state.view(
gpu_conv_state.shape[0], gpu_conv_state.shape[1], gpu_conv_state.shape[2], -1
).view(dtype=torch.uint8)
cpu_kv_conv_state = cpu_kv_conv_state.view(
cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], -1
).view(dtype=torch.uint8)
gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8)
cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], -1).view(
dtype=torch.uint8
)
cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], cpu_kv_ssm_state.shape[1], -1).view(
dtype=torch.uint8
)
assert gpu_conv_state.shape[-1] == cpu_kv_conv_state.shape[-1]
assert gpu_ssm_state.shape[-1] == cpu_kv_ssm_state.shape[-1]

gpu_conv_dim = gpu_conv_state.shape[2]
gpu_conv_tail_dim_bytes = gpu_conv_state.shape[3]

assert gpu_conv_tail_dim_bytes * gpu_conv_dim == cpu_kv_conv_state.shape[-1]

assert (
gpu_conv_state.stride(-1)
== gpu_ssm_state.stride(-1)
== cpu_kv_conv_state.stride(-1)
== cpu_kv_ssm_state.stride(-1)
== 1
)

gpu_conv_tail_dim = gpu_conv_state.shape[-1]
gpu_ssm_tail_dim = gpu_ssm_state.shape[-1]

layer_num = gpu_conv_state.shape[0]

grid = (layer_num, b_req_idx.shape[0])

_copy_linear_att_state_to_kv_buffer[grid](
Expand All @@ -116,7 +132,8 @@ def copy_linear_att_state_to_kv_buffer(
big_page_buffer_ids=big_page_buffer_ids,
gpu_conv_stride_l=gpu_conv_state.stride(0),
gpu_conv_stride_s=gpu_conv_state.stride(1),
gpu_conv_stride_d=gpu_conv_state.stride(2),
gpu_conv_stride_c=gpu_conv_state.stride(2),
gpu_conv_stride_d=gpu_conv_state.stride(3),
gpu_ssm_stride_l=gpu_ssm_state.stride(0),
gpu_ssm_stride_s=gpu_ssm_state.stride(1),
gpu_ssm_stride_d=gpu_ssm_state.stride(2),
Expand All @@ -127,7 +144,8 @@ def copy_linear_att_state_to_kv_buffer(
cpu_kv_ssm_stride_l=cpu_kv_ssm_state.stride(1),
cpu_kv_ssm_stride_d=cpu_kv_ssm_state.stride(2),
mtp_step=mtp_step,
gpu_conv_tail_dim=gpu_conv_tail_dim,
gpu_conv_dim=gpu_conv_dim,
gpu_conv_tail_dim_bytes=gpu_conv_tail_dim_bytes,
gpu_ssm_tail_dim=gpu_ssm_tail_dim,
BLOCK=BLOCK,
)
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,7 @@ def copy_kv_buffer_to_cpu_cache(
cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1]
full_att_layer_num = gpu_kv_full_att_state.shape[-2]

assert (
full_att_layer_num
== (linear_config.all_layer_num // linear_config.full_attention_interval)
== (linear_config.all_layer_num - linear_config.linear_layer_num)
)
assert full_att_layer_num == linear_config.get_full_att_kv_layer_num_with_draft_model()
assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1]
assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1]
assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1]
Expand Down Expand Up @@ -388,7 +384,6 @@ def copy_cpu_cache_to_kv_buffer(
linear_config: LinearAttCacheConfig,
grid_num: int = 12,
):

assert len(mem_indexes) % len(page_indexes) == 0

BLOCK = 4096
Expand Down
71 changes: 71 additions & 0 deletions lightllm/common/basemodel/triton_kernel/mtp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,77 @@ def gen_b_req_mtp_start_loc(b_mtp_index: torch.Tensor, num_reqs: int):
return b_req_mtp_start_loc


@triton.jit
def _fwd_kernel_linear_att_mtp_state_index_update(
    req_to_mtp_state_index,
    b_req_mtp_start_loc,
    b_req_idx,
    b_mtp_index,
    accepted_index,
    req_mtp_all_num,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request. Each program scans that request's contiguous run of
    # batch entries and records the largest b_mtp_index whose accepted flag equals 1.
    pid = tl.program_id(0)
    program_count = tl.num_programs(axis=0)

    # The run for this request is [seg_begin, seg_end). The last program has no
    # successor entry in b_req_mtp_start_loc, so its end falls back to the total
    # number of batch entries.
    seg_begin = tl.load(b_req_mtp_start_loc + pid)
    next_pid = pid + 1
    seg_end = tl.load(b_req_mtp_start_loc + next_pid, mask=next_pid < program_count, other=req_mtp_all_num)
    seg_len = seg_end - seg_begin

    # The request slot is read from the first batch entry of the run.
    req_slot = tl.load(b_req_idx + seg_begin)

    lane = tl.arange(0, BLOCK_SIZE)
    in_segment = lane < seg_len
    token_pos = seg_begin + lane

    # Lanes beyond the run read neutral values (-1 index / 0 flag) so they can never win the max.
    mtp_idx = tl.load(b_mtp_index + token_pos, mask=in_segment, other=-1)
    accepted = tl.load(accepted_index + token_pos, mask=in_segment, other=0)

    # Entries that were not accepted are replaced by -1; if nothing was accepted the result is -1.
    candidate = tl.where(accepted == 1, mtp_idx, -1)
    best_mtp_index = tl.max(candidate, axis=0)

    tl.store(req_to_mtp_state_index + req_slot, best_mtp_index)


def linear_att_mtp_state_index_update(
    req_to_mtp_state_index: torch.Tensor,
    b_req_mtp_start_loc: torch.Tensor,
    b_req_idx: torch.Tensor,
    b_mtp_index: torch.Tensor,
    accepted_index: torch.Tensor,
    max_mtp_step: int,
):
    """
    Update req_to_mtp_state_index with the max b_mtp_index among accepted tokens per request.

    The update is done in place by launching one kernel program per request; nothing is
    returned. Each request's entries are read as the contiguous range starting at its
    b_req_mtp_start_loc value and ending at the next request's start (or at the end of
    the batch for the last request). A request with no accepted entry gets -1 written.

    Args:
        req_to_mtp_state_index: (max_req_num + 1,), written at the positions given by b_req_idx.
        b_req_mtp_start_loc: (num_reqs,), start offset of each request inside the batch.
        b_req_idx: (batch_size,)
        b_mtp_index: (batch_size,)
        accepted_index: (batch_size,), 1 means accepted, 0 means not accepted.
        max_mtp_step: max mtp step per request, typically mtp_step + 1. Must not exceed
            the kernel block size (16), because a single block scans one request.

    Raises:
        AssertionError: if max_mtp_step exceeds the block size, or if b_req_idx,
            b_mtp_index and accepted_index differ in length.
    """
    BLOCK_SIZE = 16
    # The bound is inclusive: exactly BLOCK_SIZE entries still fit in one block.
    assert (
        max_mtp_step <= BLOCK_SIZE
    ), f"max_mtp_step must be less than or equal to {BLOCK_SIZE}, got {max_mtp_step}"
    assert len(b_req_idx) == len(b_mtp_index) == len(accepted_index)

    num_reqs = b_req_mtp_start_loc.shape[0]
    # Total number of batch entries; used by the kernel as the end offset of the last request.
    req_mtp_all_num = b_req_idx.shape[0]

    grid = (num_reqs,)
    num_warps = 1
    _fwd_kernel_linear_att_mtp_state_index_update[grid](
        req_to_mtp_state_index=req_to_mtp_state_index,
        b_req_mtp_start_loc=b_req_mtp_start_loc,
        b_req_idx=b_req_idx,
        b_mtp_index=b_mtp_index,
        accepted_index=accepted_index,
        req_mtp_all_num=req_mtp_all_num,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=1,
    )


def test_mtp_verify():
req_to_next_token_ids = torch.tensor(
[[1, 2, -2, -1, -1], [1, 2, 0, -1, -1], [1, 3, 4, 4, 5]], dtype=torch.int32, device="cuda"
Expand Down
25 changes: 15 additions & 10 deletions lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,8 @@ def write_req_to_page(
dp_mems: List["Qwen3NextMemManager"],
):
conv_page, ssm_page = self.view_page_to_linear_att_state(page_index)
req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1)
for tp_index, mem in enumerate(dp_mems):
self._write_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page)
self._write_one_rank(mem, tp_index, req_idx, conv_page, ssm_page)
return

def read_page_to_req(
Expand All @@ -220,21 +219,26 @@ def read_page_to_req(
dp_mems: List["Qwen3NextMemManager"],
):
conv_page, ssm_page = self.view_page_to_linear_att_state(page_index)
req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1)
for tp_index, mem in enumerate(dp_mems):
self._read_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page)
self._read_one_rank(mem, tp_index, req_idx, conv_page, ssm_page)
return

def _get_req_state_indexes(self, req_idx: int):
    """Return the ``(conv_index, ssm_index)`` pair of state-buffer indexes for ``req_idx``.

    The conv index is ``req_idx`` itself; the ssm index is ``req_idx * (mtp_step + 1)``.
    """
    slots_per_req = get_env_start_args().mtp_step + 1
    # Conv uses a single widened slot per request.
    conv_req_idx = req_idx
    # SSM keeps the historical layout of (mtp_step + 1) slots per request.
    ssm_req_idx = req_idx * slots_per_req
    return conv_req_idx, ssm_req_idx

def _write_one_rank(
self,
mem: "Qwen3NextMemManager",
tp_index: int,
req_buffer_idx: int,
req_idx: int,
conv_page: torch.Tensor,
ssm_page: torch.Tensor,
):
conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...]
ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...]
conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx)
conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]]
ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...]
self._copy_conv_state_to_page(conv_state, conv_page, mem, tp_index)
self._copy_ssm_state_to_page(ssm_state, ssm_page, mem, tp_index)
return
Expand Down Expand Up @@ -408,12 +412,13 @@ def _read_one_rank(
self,
mem: "Qwen3NextMemManager",
tp_index: int,
req_buffer_idx: int,
req_idx: int,
conv_page: torch.Tensor,
ssm_page: torch.Tensor,
):
conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...]
ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...]
conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx)
conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]]
ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...]
self._copy_page_to_conv_state(conv_page, conv_state, mem, tp_index)
self._copy_page_to_ssm_state(ssm_page, ssm_state, mem, tp_index)
return
Loading
Loading