diff --git a/README-ZH.md b/README-ZH.md
index b5592d1e9..9106c4153 100644
--- a/README-ZH.md
+++ b/README-ZH.md
@@ -17,6 +17,8 @@ mllm
## 最新动态
+- [2026 年 6 月 8 日] `pymllm` 已覆盖 Qwen3、Qwen3-VL 与 Qwen3.5 在 Jetson Orin 上的 W4A16 / W8A8 serving;Qwen3-VL-2B W8A8 在 AGX Orin 32GB 上最高达到 3.12x prefill 加速比,decode 吞吐整体与 llama.cpp 接近。
+- [2026 年 4 月 30 日] `pymllm` 新增面向 Jetson 的 Qwen3 / Qwen3-VL BF16、W4A16 和 W8A8 serving 支持,覆盖 compressed-tensors AWQ 与 W8A8 INT8 路径。
- [2026 年 3 月 18 日] 🔥🔥🔥 `pymllm` 已支持在 Jetson Orin 和 Jetson Thor 设备上使用 CUDA(实验特性,仍在持续开发中)。
- [2026 年 2 月 3 日] 🔥🔥🔥 MLLM Qnn AOT 已支持在 NPU 上全图执行, [技术报告](https://chenghuawang.github.io/News/2026-01-29-mllm-qnn-aot-support/)
- [2025 年 11 月 27 日] Android Demo 更新:通过一种全新的 In-App Go 服务架构,在 Android 上实现了 Qwen3 和 DeepSeek-OCR 的稳定流式推理。
@@ -29,6 +31,29 @@ mllm
- 更加完善、精细的工程实现
- [2025 年 7 月 30 日] 为 QNN 后端模型新增旋转量化(Rotation Quantization)方法,并支持 Qwen-2-VL 2B(ViT 性能分析将在 v2 中集成)
+## Jetson Orin CUDA Runtime
+
+`pymllm` 现已支持 Qwen3、Qwen3-VL 与 Qwen3.5 在 Jetson Orin 上运行,覆盖 BF16 serving 以及 W4A16、W8A8 两种量化 serving 路径。其中,W4A16 使用 AWQ compressed tensors 与 Marlin GEMM,W8A8 使用 Triton per-token activation quantization 与 CUTLASS INT8 GEMM。
+
+在 `input_len=2048`、`output_len=128` 的测速口径下,`pymllm` 在 Jetson Orin 上的 prefill 性能相对 llama.cpp 有明显提升。Qwen3-VL-2B W8A8 在 AGX Orin 32GB 上最高达到 **3.12x prefill 加速比**,prefill 吞吐约 **12243 tok/s**。decode 吞吐整体与 llama.cpp 接近,不同模型、设备和量化格式下会有小幅领先或回落。
+
+
+

+
+
+
+

+
+
+对于多模态 prefill,`bench_one_batch --image` 测量“视觉编码 + 图像/文本 token prefill”的完整路径。下表使用 `input_len=2048`,TPS 为多次运行的 mean latency 计算结果。
+
+| Device | Model | FP16 | W4A16 | W8A8 |
+|---|---|---:|---:|---:|
+| AGX Orin 32GB | Qwen3-VL-2B | 4875.75 | 4700.28 | 6443.59 |
+| AGX Orin 32GB | Qwen3-VL-4B | - | 2499.46 | 3837.07 |
+| Orin NX 16GB | Qwen3-VL-2B | 2438.27 | 2494.89 | 3200.40 |
+| Orin NX 16GB | Qwen3-VL-4B | - | 1231.21 | 1673.93 |
+
## Android Demo & Architecture
我们已对 Android 端实现进行了重构,采用了一种稳健的、完全在设备端运行的 **Client-Server** 架构。
@@ -75,17 +100,21 @@ mllm 框架可以与主流社区框架的模型检查点无缝集成。通过 ml
### mllm v2
-| Model(v2) | CPU | Hexagon NPU
INT8 |
-|-----------------------------------------------------------------------------|------|-----------------------|
-| [Qwen3-0.6B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-0.6B-w4a32kai) | |
-| [Qwen3-1.7B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-1.7B-w4a8-i8mm-kai) | [W4A16-SM8650](https://modelscope.cn/models/mllmTeam/Qwen3-1.7B-Qnn-AOT-SM8650/summary) |
-| [Qwen3-4B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-4B-w4a8-i8mm-kai) | |
-| [DeepSeek-OCR](https://github.com/deepseek-ai/DeepSeek-OCR) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/DeepSeek-OCR-w4a8-i8mm-kai) | |
-| [SmolLM3](https://huggingface.co/blog/smollm3)| [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/SmolLM3-3B-w4a8-i8mm-kai) | |
-| [Qwen2-VL-2B-Instruct](https://qwenlm.github.io/zh/blog/qwen2-vl/)|[✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2-VL-2B-Instruct-w4a32kai) ||
-| [Qwen2-VL-7B-Instruct](https://qwenlm.github.io/zh/blog/qwen2-vl/)|[✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2-VL-7B-Instruct-w4a32kai)||
-| [Qwen2.5-VL-3B-Instruct](https://qwenlm.github.io/blog/qwen2.5-vl/)|[✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2.5-VL-3B-Instruct-w4a32kai)||
-| [Qwen2.5-VL-7B-Instruct](https://qwenlm.github.io/blog/qwen2.5-vl/)|[✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2.5-VL-7B-Instruct-w4a32kai)||
+| Model(v2) | CPU | Jetson Orin CUDA | Hexagon NPU
INT8 |
+|-----------------------------------------------------------------------------|------|------------------|-----------------------|
+| [Qwen3-0.6B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-0.6B-w4a32kai) | | |
+| [Qwen3-1.7B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-1.7B-w4a8-i8mm-kai) | | [W4A16-SM8650](https://modelscope.cn/models/mllmTeam/Qwen3-1.7B-Qnn-AOT-SM8650/summary) |
+| [Qwen3-4B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-4B-w4a8-i8mm-kai) | | |
+| Qwen3.5-2B | | ✔️ W4A16 / W8A8 | |
+| Qwen3.5-4B | | ✔️ W4A16 / W8A8 | |
+| Qwen3-VL-2B-Instruct | | ✔️ W4A16 / W8A8 | |
+| Qwen3-VL-4B-Instruct | | ✔️ W4A16 / W8A8 | |
+| [DeepSeek-OCR](https://github.com/deepseek-ai/DeepSeek-OCR) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/DeepSeek-OCR-w4a8-i8mm-kai) | | |
+| [SmolLM3](https://huggingface.co/blog/smollm3)| [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/SmolLM3-3B-w4a8-i8mm-kai) | | |
+| [Qwen2-VL-2B-Instruct](https://qwenlm.github.io/zh/blog/qwen2-vl/) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2-VL-2B-Instruct-w4a32kai) | | |
+| [Qwen2-VL-7B-Instruct](https://qwenlm.github.io/zh/blog/qwen2-vl/) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2-VL-7B-Instruct-w4a32kai) | | |
+| [Qwen2.5-VL-3B-Instruct](https://qwenlm.github.io/blog/qwen2.5-vl/) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2.5-VL-3B-Instruct-w4a32kai) | | |
+| [Qwen2.5-VL-7B-Instruct](https://qwenlm.github.io/blog/qwen2.5-vl/) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2.5-VL-7B-Instruct-w4a32kai) | | |
### mllm v1
diff --git a/README.md b/README.md
index 7356742cc..8d43cc472 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,8 @@ mllm
## Latest News
+- [2026 Jun 08] `pymllm` now covers Qwen3, Qwen3-VL, and Qwen3.5 on Jetson Orin with W4A16 / W8A8 serving; Qwen3-VL-2B W8A8 reaches up to 3.12x prefill speedup on AGX Orin 32GB, while decode throughput stays broadly close to llama.cpp.
+- [2026 Apr 30] `pymllm` adds Jetson-oriented Qwen3 / Qwen3-VL BF16, W4A16, and W8A8 serving support, including compressed-tensors AWQ and W8A8 INT8 paths.
- [2026 Mar 18] 🔥🔥🔥 `pymllm` now supports CUDA on Jetson Orin and Jetson Thor devices (experimental; still under active development).
- [2026 Feb 03] 🔥🔥🔥 MLLM Qnn AOT Support for Full Graph Execution on NPU! [Quick Start](https://ubiquitouslearning.github.io/mllm/qnn_backend/aot_execute.html), [Technical Report](https://chenghuawang.github.io/News/2026-01-29-mllm-qnn-aot-support-en/)
- [2025 Nov 27] Android Demo Update: Enabled stable Qwen3 and DeepSeek-OCR streaming on Android via a novel In-App Go Server Architecture.
@@ -28,6 +30,29 @@ mllm
- A more refined engineering implementation
- [2025 Jul 30] Add Rotation Quantization method for QNN backend models and support Qwen-2-VL 2B(ViT profiling will integrate in v2)
+## Jetson Orin CUDA Runtime
+
+`pymllm` now supports Qwen3, Qwen3-VL, and Qwen3.5 on Jetson Orin with BF16 serving plus W4A16 and W8A8 quantized serving. The W4A16 path uses AWQ compressed tensors and Marlin GEMM. The W8A8 path uses Triton per-token activation quantization and CUTLASS INT8 GEMM.
+
+For `input_len=2048` and `output_len=128`, `pymllm` shows strong prefill gains over llama.cpp on Jetson Orin. Qwen3-VL-2B W8A8 reaches up to **3.12x prefill speedup** on AGX Orin 32GB and about **12243 tok/s** prefill throughput. Decode throughput is generally close to llama.cpp, with small wins or losses depending on model, device, and quantization.
+
+
+

+
+
+
+

+
+
+For multimodal prefill, `bench_one_batch --image` measures the full path of vision encoding plus image/text token prefill. The table below uses `input_len=2048` and reports mean TPS across repeated runs.
+
+| Device | Model | FP16 | W4A16 | W8A8 |
+|---|---|---:|---:|---:|
+| AGX Orin 32GB | Qwen3-VL-2B | 4875.75 | 4700.28 | 6443.59 |
+| AGX Orin 32GB | Qwen3-VL-4B | - | 2499.46 | 3837.07 |
+| Orin NX 16GB | Qwen3-VL-2B | 2438.27 | 2494.89 | 3200.40 |
+| Orin NX 16GB | Qwen3-VL-4B | - | 1231.21 | 1673.93 |
+
## Android Demo & Architecture
We have refactored the Android implementation to use a robust **Client-Server** architecture entirely on-device.
diff --git a/assets/jetson/pymllm-jetson-prefill-throughput-2048.jpg b/assets/jetson/pymllm-jetson-prefill-throughput-2048.jpg
new file mode 100644
index 000000000..1ea55c57f
Binary files /dev/null and b/assets/jetson/pymllm-jetson-prefill-throughput-2048.jpg differ
diff --git a/assets/jetson/pymllm-jetson-speedup-summary-2048.jpg b/assets/jetson/pymllm-jetson-speedup-summary-2048.jpg
new file mode 100644
index 000000000..f3f79aa3c
Binary files /dev/null and b/assets/jetson/pymllm-jetson-speedup-summary-2048.jpg differ
diff --git a/bench_assets/two_cats.jpg b/bench_assets/two_cats.jpg
new file mode 100644
index 000000000..2af1b8317
Binary files /dev/null and b/bench_assets/two_cats.jpg differ
diff --git a/bench_assets/two_cats_480p.jpg b/bench_assets/two_cats_480p.jpg
new file mode 100644
index 000000000..40fb7e7af
Binary files /dev/null and b/bench_assets/two_cats_480p.jpg differ
diff --git a/docs/pymllm_runtime/developer_guide.rst b/docs/pymllm_runtime/developer_guide.rst
index 47b528659..47dbe0742 100644
--- a/docs/pymllm_runtime/developer_guide.rst
+++ b/docs/pymllm_runtime/developer_guide.rst
@@ -4,13 +4,13 @@ pymllm Developer Guide
总览
----------------------------------------
-本文档面向希望为 ``pymllm`` 增加模型、量化格式、kernel 或性能优化的开发者。当前代码处在
-快速演进阶段,推荐遵循“小步验证、边界清晰、先单测后服务级验证”的工作方式。
+这份文档写给想给 ``pymllm`` 加模型、加量化格式、加 kernel 或做性能优化的开发者。代码还在快速
+演进,建议的工作方式是“小步验证、边界清晰、先单测再服务级验证”。
开发环境建议
----------------------------------------
-推荐使用 editable install,便于修改 Python 代码后直接验证:
+推荐用 editable install,改完 Python 代码能直接验证:
.. code-block:: bash
@@ -28,9 +28,9 @@ pymllm Developer Guide
print("ok")
PY
-``mllm-kernel`` 的 JIT 编译产物会写入 ``~/.cache/mllm_kernel``。正常修改后重新运行
-会触发相应 kernel 的加载或编译;只有在验证首次编译行为、排查失败缓存、或更换 CUTLASS
-等外部头文件来源时,才需要清理对应缓存:
+``mllm-kernel`` 的 JIT 编译产物写在 ``~/.cache/mllm_kernel``。正常改完代码重新跑,会按需触发
+对应 kernel 的加载或编译;只有在验证首次编译行为、排查失败缓存、或者换了 CUTLASS 之类外部头
+文件来源时,才需要手动清对应缓存:
.. code-block:: bash
@@ -39,17 +39,17 @@ pymllm Developer Guide
新增模型
----------------------------------------
-新增模型时,优先复用现有 ``pymllm.layers`` 和 ``pymllm.executor`` 约定,而不是把
-HuggingFace 模型直接包进服务。
+加模型时,优先复用现有的 ``pymllm.layers`` 和 ``pymllm.executor`` 约定,别把 HuggingFace 模型
+整个塞进服务。
推荐步骤:
1. 新增 ``pymllm/models/.py``。
2. 在 ``pymllm/models/__init__.py`` 注册 architecture 字符串。
-3. 实现模型类,保持 ``forward(input_ids, positions, forward_batch)`` 风格。
+3. 实现模型类,保持 ``forward(input_ids, positions, forward_batch)`` 的风格。
4. 所有 linear layer 都接受 ``quant_method``。
-5. 实现 ``load_weights``,处理 checkpoint key、stacked projection 和 tied embedding。
-6. 增加最小单测。
+5. 实现 ``load_weights``,处理好 checkpoint key、stacked projection 和 tied embedding。
+6. 补最小单测。
7. 最后做服务级 smoke test。
最小测试建议:
@@ -63,38 +63,38 @@ HuggingFace 模型直接包进服务。
新增量化 scheme
----------------------------------------
-新增量化路径时,不建议在模型文件里写格式判断。推荐保持以下分层:
+加量化路径时,别在模型文件里写格式判断。保持这三层:
.. code-block:: text
QuantizationConfig
- parses checkpoint config
- decides whether a layer is quantized
+ 解析 checkpoint config
+ 决定某个 layer 是否量化
LinearMethod
- owns linear layer lifecycle
+ 承接 linear layer 生命周期
Scheme
- owns checkpoint-facing params
- owns post-load layout conversion
- owns kernel apply path
+ 管 checkpoint-facing 参数
+ 管 post-load layout 转换
+ 管 kernel apply 路径
-``create_weights`` 应注册 checkpoint-facing 参数名。``process_weights_after_loading`` 应作为
-checkpoint layout 到 runtime kernel layout 的唯一转换边界。``apply`` 中只做 forward 必需的
-runtime 计算,不应重复做权重 repack。
+``create_weights`` 注册 checkpoint-facing 的参数名。``process_weights_after_loading`` 是
+checkpoint layout 转 runtime kernel layout 的唯一边界。``apply`` 里只做 forward 必需的 runtime
+计算,不要重复做权重 repack。
-新增量化路径至少需要覆盖:
+新增量化路径至少要覆盖:
- config 解析测试。
- ``ignore`` / prefix 匹配测试。
-- 参数注册 shape/dtype 测试。
+- 参数注册的 shape / dtype 测试。
- post-load layout 转换测试。
- forward correctness 或 smoke test。
新增 CUDA JIT kernel
----------------------------------------
-若 kernel 适合走 ``mllm-kernel`` 的 TVM-FFI JIT 路径,推荐结构如下:
+如果 kernel 适合走 ``mllm-kernel`` 的 TVM-FFI JIT 路径,推荐这个结构:
.. code-block:: text
@@ -103,28 +103,29 @@ runtime 计算,不应重复做权重 repack。
mllm-kernel/tests/test_.py
mllm-kernel/benchmarks/bench_.py
-Python wrapper 应负责:
+Python wrapper 负责:
-- 校验输入 shape、dtype、device。
+- 校验输入的 shape、dtype、device。
- 分配输出 tensor。
-- 调用 ``@jit`` 包装后的 compiled module。
-- 暴露稳定、简洁的 Python API。
+- 调 ``@jit`` 包好的 compiled module。
+- 对外暴露一个稳定、干净的 Python API。
-CUDA/C++ source 应尽量只表达 kernel 语义,不混入 checkpoint 配置解析或模型层逻辑。
+CUDA / C++ source 尽量只表达 kernel 语义,别混进 checkpoint 配置解析或模型层逻辑。
-如果 kernel 依赖 CUTLASS 等重模板库,可以先做编译 spike。确认 Jetson 目标设备上的编译时间、
-缓存路径、include 来源和内存占用后,再决定使用 TVM-FFI JIT、torch extension JIT 或 AOT 构建。
+如果 kernel 依赖 CUTLASS 这种重模板库,建议先做一次编译 spike:把 Jetson 目标设备上的编译
+时间、缓存路径、include 来源和内存占用摸清楚,再决定用 TVM-FFI JIT、torch extension JIT 还是
+AOT 构建。
服务级验证
----------------------------------------
-服务级 smoke test 应覆盖:
+服务级 smoke test 至少要覆盖:
-- ``/v1/models`` 可返回。
-- 文本 ``/v1/chat/completions`` 可完成。
-- 图文模型能处理容器内图片绝对路径。
-- streaming 与 non-streaming 至少各测一次。
-- 中止请求或客户端断连不会泄漏 running request。
+- ``/v1/models`` 能返回。
+- 文本 ``/v1/chat/completions`` 能跑完。
+- 图文模型能处理容器内的图片绝对路径。
+- streaming 和 non-streaming 各测一次。
+- 中止请求或客户端断连时不会泄漏 running request。
示例:
@@ -145,7 +146,7 @@ CUDA/C++ source 应尽量只表达 kernel 语义,不混入 checkpoint 配置
性能验证
----------------------------------------
-性能数据需要固定口径,否则不同记录之间很难比较。建议记录:
+性能数据一定要固定口径,否则不同记录之间根本没法比。建议每次都记下:
- commit hash。
- JetPack / L4T 版本。
@@ -154,11 +155,11 @@ CUDA/C++ source 应尽量只表达 kernel 语义,不混入 checkpoint 配置
- 模型路径和量化格式。
- 启动命令。
- prompt token 数、max tokens、temperature。
-- 是否启用 radix cache、CUDA Graph、shared queue。
+- 有没有开 radix cache、CUDA Graph、shared queue。
- 是否包含首次 JIT 编译。
-对服务级请求,建议丢弃第一次 warmup 结果,记录第 2/3 次请求的 prefill/decode 统计。
-对 kernel microbench,建议单独记录 warmup、重复次数、输入 shape 和 dtype。
+服务级请求建议丢掉第一次 warmup 的结果,记第 2 / 3 次请求的 prefill / decode 统计。kernel
+microbench 则要单独记 warmup、重复次数、输入 shape 和 dtype。
常见问题定位
----------------------------------------
@@ -166,55 +167,55 @@ CUDA/C++ source 应尽量只表达 kernel 语义,不混入 checkpoint 配置
启动失败
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-优先确认:
+先看:
-- ``pymllm`` 和 ``mllm_kernel`` 是否来自预期源码目录或安装版本。
-- ``model_path`` 和 ``tokenizer_path`` 是否在容器内可见。
-- ``transformers`` 是否能读取目标 ``config.json``。
-- CUDA 是否可用,``torch.cuda.get_device_capability()`` 是否符合量化 kernel 要求。
+- ``pymllm`` 和 ``mllm_kernel`` 是不是来自预期的源码目录或安装版本。
+- ``model_path`` 和 ``tokenizer_path`` 在容器内能不能看到。
+- ``transformers`` 能不能读目标 ``config.json``。
+- CUDA 可不可用,``torch.cuda.get_device_capability()`` 满不满足量化 kernel 的要求。
W8A8 编译失败
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-优先确认:
+先看:
-- ``CUTLASS_HOME`` 是否设置正确。
-- ``flashinfer`` 是否包含 bundled CUTLASS。
-- ``~/.cache/mllm_kernel/cutlass_int8_scaled_mm/`` 是否存在旧的失败缓存。
-- 当前 GPU 是否为 SM80-SM89。
+- ``CUTLASS_HOME`` 设没设对。
+- ``flashinfer`` 里有没有 bundled CUTLASS。
+- ``~/.cache/mllm_kernel/cutlass_int8_scaled_mm/`` 是不是有旧的失败缓存。
+- 当前 GPU 是不是 SM80–SM89。
请求卡住或 CPU 占用高
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-优先确认:
+先看:
-- scheduler 是否启用了 idle sleep。
-- tokenizer / scheduler / detokenizer 子进程是否全部存活。
-- 是否有请求已经断连但未 abort。
-- ``max_total_tokens`` 是否过小导致 KV allocation 反复失败和 eviction。
+- scheduler 有没有启用 idle sleep。
+- tokenizer / scheduler / detokenizer 子进程是不是都还活着。
+- 是不是有请求已经断连但没 abort。
+- ``max_total_tokens`` 是不是太小,导致 KV allocation 反复失败和 eviction。
输出异常
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-优先确认:
+先看:
-- tokenizer chat template 是否符合目标模型。
-- EOS token 是否从 config、generation_config 或 tokenizer 中正确解析。
-- 量化模型的 ``ignore`` 是否覆盖视觉分支、embedding、norm 和 lm_head 等不应量化模块。
-- ``process_weights_after_loading`` 是否已执行。
+- tokenizer 的 chat template 对不对得上目标模型。
+- EOS token 有没有从 config、generation_config 或 tokenizer 里正确解析出来。
+- 量化模型的 ``ignore`` 有没有覆盖视觉分支、embedding、norm、lm_head 这些不该量化的模块。
+- ``process_weights_after_loading`` 跑没跑。
贡献建议
----------------------------------------
-开发时尽量保持以下边界:
+开发时尽量守住这些边界:
-- 服务协议变化放在 ``pymllm/server``。
-- 请求/响应结构放在 ``pymllm/engine/io_struct.py``。
-- 调度策略放在 ``pymllm/orchestrator/scheduler_process.py``。
-- GPU 资源和 forward 逻辑放在 ``pymllm/executor``。
-- 模型结构放在 ``pymllm/models``。
-- 基础层放在 ``pymllm/layers``。
-- 量化格式放在 ``pymllm/quantization``。
-- 自定义 kernel 放在 ``mllm-kernel``。
+- 服务协议变化放 ``pymllm/server``。
+- 请求 / 响应结构放 ``pymllm/engine/io_struct.py``。
+- 调度策略放 ``pymllm/orchestrator/scheduler_process.py``。
+- GPU 资源和 forward 逻辑放 ``pymllm/executor``。
+- 模型结构放 ``pymllm/models``。
+- 基础层放 ``pymllm/layers``。
+- 量化格式放 ``pymllm/quantization``。
+- 自定义 kernel 放 ``mllm-kernel``。
-这样可以避免把一次模型适配写成跨层补丁,也方便后续把同一能力复用到更多模型和设备。
+守住这些边界,一次模型适配就不会写成跨层补丁,后面把同一份能力复用到更多模型和设备也更省事。
diff --git a/docs/pymllm_runtime/kernels_and_acceleration.rst b/docs/pymllm_runtime/kernels_and_acceleration.rst
index d5d09c30c..849c5a5e2 100644
--- a/docs/pymllm_runtime/kernels_and_acceleration.rst
+++ b/docs/pymllm_runtime/kernels_and_acceleration.rst
@@ -4,28 +4,28 @@ pymllm Kernels and Acceleration
总览
----------------------------------------
-``pymllm`` 的性能路径由多类加速组件共同组成:
+``pymllm`` 的性能由几类加速组件分工撑起来,它们解决的不是同一个问题:
-- FlashInfer:paged KV cache attention。
-- CUDA Graph:decode 阶段减少 CPU launch overhead。
-- Triton:W8A8 per-token activation quantization。
-- CUTLASS:W8A8 INT8 Tensor Core GEMM。
-- ``mllm-kernel``:基于 TVM-FFI / torch extension 的 JIT kernel 工具包。
+- **FlashInfer**:paged KV cache attention。
+- **CUDA Graph**:减少 decode 阶段的 CPU launch overhead。
+- **Triton**:W8A8 的 per-token activation quantization。
+- **CUTLASS**:W8A8 的 INT8 Tensor Core GEMM。
+- **mllm-kernel**:基于 TVM-FFI / torch extension 的 JIT kernel 工具包。
-这些组件不是彼此替代关系,而是在不同层次承担职责。attention backend 解决 KV cache
-attention;CUDA Graph 解决重复 decode step 的 launch overhead;Triton 和 CUTLASS 解决量化
-linear 的核心计算;``mllm-kernel`` 为项目内自定义 CUDA/C++ kernel 提供封装、缓存和工具。
+简单说:attention backend 管 KV cache attention,CUDA Graph 管重复 decode step 的 launch
+开销,Triton 和 CUTLASS 管量化 linear 的核心计算,``mllm-kernel`` 则为项目内自定义的
+CUDA / C++ kernel 提供封装、缓存和工具。
mllm-kernel
----------------------------------------
-``mllm-kernel`` 是 mllm 项目中的高性能 kernel 包。当前 Python 侧主要包含:
+``mllm-kernel`` 是 mllm 里的高性能 kernel 包,Python 侧目前主要是:
- ``mllm_kernel.cuda.jit``:CUDA JIT kernel wrapper。
- ``mllm_kernel.cpu.jit``:CPU JIT kernel wrapper。
- ``mllm_kernel.jit_utils``:JIT 编译、缓存、注册表和工具函数。
-CUDA JIT kernel 的典型结构是:
+一个 CUDA JIT kernel 的典型结构:
.. code-block:: text
@@ -33,54 +33,54 @@ CUDA JIT kernel 的典型结构是:
-> @jit(...)
-> include CUDA/C++ source
-> export TVM-FFI typed function
- -> compile on first use
- -> reuse cached shared library
+ -> 首次使用时编译
+ -> 之后复用缓存的 shared library
-默认 JIT 缓存目录为:
+默认 JIT 缓存目录:
.. code-block:: text
~/.cache/mllm_kernel/
-``mllm-kernel`` 的 JIT 路径与 SGLang 的 ``jit_kernel`` 设计关系更直接:二者都强调轻量
-JIT、运行时选择模板实例、避免大型 AOT torch extension 带来的长编译周期。与此同时,SGLang
-的 ``sgl-kernel`` AOT kernel 仍然是重要参考,尤其适合对照量化 GEMM 的语义和性能。
+``mllm-kernel`` 的 JIT 思路和 SGLang 的 ``jit_kernel`` 关系更近:都强调轻量 JIT、运行时选模板
+实例、避开大型 AOT torch extension 那种动辄几分钟的编译。同时 SGLang 的 ``sgl-kernel`` AOT
+kernel 仍是重要参考,对照量化 GEMM 的语义和性能时尤其有用。
TVM-FFI JIT 路径
----------------------------------------
-``mllm_kernel.jit_utils.jit`` decorator 会将 Python 函数包装成一个按需编译的 kernel 调用。
-它负责:
+``mllm_kernel.jit_utils.jit`` 这个 decorator 把一个 Python 函数包成按需编译的 kernel 调用,
+负责:
-- 根据 tensor device 推断 CPU/CUDA 目标。
-- 将 Python 参数转换为 C++ template 参数。
-- 拼接 C++/CUDA source 和 export wrapper。
-- 调用 TVM-FFI 编译并加载 shared library。
-- 将编译结果缓存到 ``~/.cache/mllm_kernel``。
+- 根据 tensor device 推断 CPU / CUDA 目标。
+- 把 Python 参数转成 C++ template 参数。
+- 拼 C++/CUDA source 和 export wrapper。
+- 调 TVM-FFI 编译并加载 shared library。
+- 把编译结果缓存到 ``~/.cache/mllm_kernel``。
-这种方式适合小而明确的自定义 kernel,例如:
+这种方式适合小而明确的自定义 kernel,比如:
-- ``create_kv_indices``:构造 FlashInfer KV index metadata。
-- ``store_cache``:将 K/V 写入 KVPool。
-- ``gptq_marlin_repack``:Marlin weight layout 转换。
+- ``create_kv_indices``:构造 FlashInfer 的 KV index metadata。
+- ``store_cache``:把 K/V 写进 KVPool。
+- ``gptq_marlin_repack``:Marlin 权重 layout 转换。
- ``gptq_marlin_gemm``:W4A16 Marlin GEMM。
-W8A8 CUTLASS kernel 当前使用 ``torch.utils.cpp_extension.load`` 编译。这是因为 CUTLASS
-模板和 include 体系较重,当前以稳定通过 Jetson SM87 编译为优先。
+W8A8 的 CUTLASS kernel 目前是个例外,用 ``torch.utils.cpp_extension.load`` 编译——CUTLASS 的
+模板和 include 体系太重,现阶段优先保证它能在 Jetson SM87 上稳定编过。
FlashInfer Attention
----------------------------------------
-``pymllm.layers.attention.flashinfer_backend.FlashInferAttnBackend`` 封装 FlashInfer 的 paged
-KV cache attention。它负责:
+``pymllm.layers.attention.flashinfer_backend.FlashInferAttnBackend`` 封装了 FlashInfer 的
+paged KV cache attention,负责:
- 为 prefill 和 decode 准备 ``kv_indptr``、``kv_indices``、``kv_last_page_len`` 等 metadata。
- 管理全局 workspace buffer。
-- 根据是否存在 sliding window 选择 wrapper dispatch。
-- 在 decode 中根据 GQA group size 和 KV dtype 决定是否使用 tensor core 路径。
-- 为 CUDA Graph capture / replay 提供专用 metadata 初始化接口。
+- 根据有没有 sliding window 选 wrapper dispatch。
+- decode 时按 GQA group size 和 KV dtype 决定走不走 tensor core 路径。
+- 给 CUDA Graph capture / replay 提供专用的 metadata 初始化接口。
-prefill 和 decode 使用不同 wrapper:
+prefill 和 decode 用不同 wrapper:
.. code-block:: text
@@ -91,22 +91,22 @@ prefill 和 decode 使用不同 wrapper:
decode
BatchDecodeWithPagedKVCacheWrapper
-attention backend 只负责 attention 计算和 metadata,不负责请求调度和 KV slot 生命周期。KV slot
-的分配、释放和 prefix cache 命中由 scheduler / model runner 侧完成。
+attention backend 只管 attention 计算和 metadata,不碰请求调度和 KV slot 生命周期。KV slot
+的分配、释放、prefix cache 命中是 scheduler / model runner 那边的事。
CUDA Graph
----------------------------------------
-``pymllm.executor.cuda_graph_runner.CudaGraphRunner`` 用于 decode step 的 CUDA Graph capture
-和 replay。它的目标是减少小 batch decode 中 CPU launch overhead。
+``pymllm.executor.cuda_graph_runner.CudaGraphRunner`` 负责 decode step 的 CUDA Graph capture
+和 replay,目的就是把小 batch decode 里的 CPU launch overhead 压下去。
-初始化阶段会按一组离散 batch size 捕获 graph:
+初始化时按一组离散 batch size 捕获 graph:
.. code-block:: text
[1, 2, 4, 8, 12, 16, 24, 32, ...]
-每个 captured graph 复用预分配输入 buffer:
+每个 captured graph 复用预分配好的输入 buffer:
- ``input_ids``
- ``req_pool_indices``
@@ -115,36 +115,36 @@ CUDA Graph
- ``positions``
- ``mrope_position_deltas``
-replay 时,真实 batch 会被 padding 到最近的 captured batch size。attention backend 会走专用
-``init_forward_metadata_replay_cuda_graph`` 路径,避免使用普通动态 metadata 初始化。
+replay 时真实 batch 会 padding 到最近的 captured batch size,attention backend 走专用的
+``init_forward_metadata_replay_cuda_graph`` 路径,而不是普通的动态 metadata 初始化。
-CUDA Graph 只覆盖 decode 主路径。调试模型、调试 attention metadata 或定位 shape 问题时,可以
-使用 ``--server.disable_cuda_graph`` 暂时关闭。
+CUDA Graph 只覆盖 decode 主路径。调试模型、查 attention metadata 或定位 shape 问题时,可以
+用 ``--server.disable_cuda_graph`` 临时关掉。
W4A16 Marlin
----------------------------------------
-W4A16 路径复用 Marlin kernel。checkpoint 权重先以 ``weight_packed`` 和 ``weight_scale``
-加载,然后在 post-load 阶段转换为 Marlin runtime layout。
+W4A16 复用 Marlin kernel。checkpoint 权重先以 ``weight_packed`` 和 ``weight_scale`` 加载,
+再在 post-load 阶段转成 Marlin 的 runtime layout。
关键 kernel:
- ``mllm_kernel.cuda.jit.gptq_marlin_repack``
- ``mllm_kernel.cuda.jit.gptq_marlin``
-执行约束包括:
+执行约束:
- SM80+
-- output partition 可被 64 整除
-- input partition 可被 128 整除
-- group size 当前主路径为 32
+- output partition 能被 64 整除
+- input partition 能被 128 整除
+- group size 主路径目前是 32
-这种路径适合 AWQ / W4A16 类权重量化模型,activation 保持 FP16/BF16。
+这条路径适合 AWQ / W4A16 这类权重量化模型,activation 保持 FP16/BF16。
W8A8 Triton + CUTLASS
----------------------------------------
-W8A8 路径包含两个核心 kernel:
+W8A8 有两个核心 kernel:
1. ``pymllm.quantization.kernels.int8_activation_triton.per_token_quant_int8``
2. ``mllm_kernel.cuda.jit.int8_scaled_mm_cutlass.int8_scaled_mm``
@@ -156,37 +156,32 @@ W8A8 路径包含两个核心 kernel:
[M, K] fp16/bf16 activation
-> Triton per-token absmax + round + int8 cast
-> [M, K] int8 + [M, 1] fp32 scale
- -> CUTLASS int8 GEMM with per-row/per-col scales
+ -> CUTLASS int8 GEMM(per-row / per-col scale)
-> [M, N] fp16/bf16 output
-CUTLASS kernel 要求 ``mat_b`` 为 ``[K, N]`` column-major,因此 W8A8 scheme 会在
-``process_weights_after_loading`` 中把 checkpoint 的 ``[N, K]`` INT8 weight 转成对应布局。
+CUTLASS kernel 要求 ``mat_b`` 是 ``[K, N]`` column-major,所以 W8A8 scheme 会在
+``process_weights_after_loading`` 里把 checkpoint 的 ``[N, K]`` INT8 weight 转成对应布局。
-当前 CUTLASS include 查找顺序为:
-
-1. ``CUTLASS_HOME/include``
-2. ``flashinfer`` bundled CUTLASS
-3. 系统 include 目录
-
-如果找不到 CUTLASS 头文件,W8A8 初始化会失败。生产环境建议在镜像中固定 CUTLASS 来源,避免
-不同节点使用不同版本头文件。
+CUTLASS 头文件默认用 ``flashinfer`` bundled 的那份;要换版本就设 ``CUTLASS_HOME``。如果头文件
+找不到,W8A8 初始化会直接失败。生产环境建议在镜像里把 CUTLASS 来源固定下来,免得不同节点用上
+不同版本的头文件。
GDN decode kernel
----------------------------------------
-Qwen3.5 等 hybrid 模型可能包含 GDN / linear attention 层。``pymllm`` 为这类模型保留了:
+Qwen3.5 这类 hybrid 模型可能带 GDN / linear attention 层。``pymllm`` 给它们预留了:
- ``pymllm.layers.attention.gdn_backend``
- ``pymllm.layers.attention.hybrid_backend``
- ``mllm_kernel.cuda.jit.gdn_decode``
- ``MambaRadixCache`` / GDN state cache 相关结构
-当前文档重点覆盖 Qwen3 / Qwen3-VL 主路径。GDN 相关路径仍应以具体模型和测试结果为准。
+本文档重点还是 Qwen3 / Qwen3-VL 主路径,GDN 相关路径以具体模型和测试结果为准。
调试与观测
----------------------------------------
-常用检查命令:
+几条常用检查命令:
.. code-block:: bash
@@ -194,10 +189,6 @@ Qwen3.5 等 hybrid 模型可能包含 GDN / linear attention 层。``pymllm``
python3 -m mllm_kernel show-config
python3 -m pymllm show-config
-当首次运行时间异常长时,应区分:
-
-- 模型权重加载时间。
-- FlashInfer / CUDA context 初始化时间。
-- CUTLASS JIT 编译时间。
-- CUDA Graph capture 时间。
-- 实际 prefill/decode 时间。
+首次运行特别慢的时候,要分清楚时间花在哪:模型权重加载、FlashInfer / CUDA context 初始化、
+CUTLASS JIT 编译、CUDA Graph capture,还是真正的 prefill / decode。别把首次 JIT 或 kernel
+初始化的开销当成稳态瓶颈。
diff --git a/docs/pymllm_runtime/models_and_quantization.rst b/docs/pymllm_runtime/models_and_quantization.rst
index e7d92dd1e..cb1c885d6 100644
--- a/docs/pymllm_runtime/models_and_quantization.rst
+++ b/docs/pymllm_runtime/models_and_quantization.rst
@@ -4,15 +4,16 @@ pymllm Models and Quantization
总览
----------------------------------------
-``pymllm`` 的模型实现遵循 PyTorch ``nn.Module`` 风格,并通过 HuggingFace
-``config.architectures`` 字段选择模型类。当前重点支持 Qwen3 family:
+``pymllm`` 的模型实现就是标准的 PyTorch ``nn.Module`` 写法,运行时按 HuggingFace config
+里的 ``architectures`` 字段挑模型类。当前重点是 Qwen3 family:
- ``Qwen3ForCausalLM``:文本模型,例如 Qwen3-0.6B。
- ``Qwen3VLForConditionalGeneration``:图文模型,例如 Qwen3-VL-2B-Instruct。
-- ``Qwen3_5ForCausalLM`` 和 ``Qwen3_5ForConditionalGeneration``:hybrid attention / GDN
- 相关模型骨架。
+- ``Qwen3_5ForCausalLM`` / ``Qwen3_5ForConditionalGeneration``:hybrid attention / GDN
+ 方向的模型骨架。
-量化系统以 linear layer 为核心,使用插件式 ``LinearMethodBase`` 生命周期:
+量化系统围绕 linear layer 展开,用一套插件式的 ``LinearMethodBase`` 生命周期把格式细节
+和模型主逻辑隔开:
.. code-block:: text
@@ -26,8 +27,8 @@ pymllm Models and Quantization
模型注册
----------------------------------------
-模型注册表位于 ``pymllm/models/__init__.py``。运行时会根据 HuggingFace config 中的
-architecture 字符串懒加载模型类:
+模型注册表在 ``pymllm/models/__init__.py``。运行时按 HuggingFace config 里的 architecture
+字符串懒加载对应模型类:
.. code-block:: text
@@ -40,52 +41,59 @@ architecture 字符串懒加载模型类:
"Qwen3_5ForCausalLM"
-> pymllm.models.qwen3_5.Qwen3_5ForCausalLM
-这种注册方式让服务启动阶段只导入目标模型所需的代码,避免在命令行工具或轻量检查中提前加载
-大量 PyTorch/CUDA 依赖。
+懒加载的好处是:服务启动时只导入目标模型用到的代码,命令行工具或轻量检查不会被迫提前拉起
+一大堆 PyTorch / CUDA 依赖。
Qwen3 文本模型
----------------------------------------
-``Qwen3ForCausalLM`` 使用标准 decoder-only 结构:
-
-- token embedding
-- 多层 decoder block
-- Q/K Norm
-- 1D RoPE
-- MLP
-- final norm
-- lm head
-
-它复用 ``RadixAttention``、``RMSNorm``、``MLP``、``ColumnParallelLinear`` 和
-``RowParallelLinear`` 等基础层。与 Qwen3-VL 文本分支相比,Qwen3 文本模型使用 1D RoPE,
-不需要多模态 M-RoPE 的三维 position 逻辑。
+``Qwen3ForCausalLM`` 是标准的 decoder-only 结构:token embedding、多层 decoder block、
+Q/K Norm、1D RoPE、MLP、final norm、lm head。它复用 ``RadixAttention``、``RMSNorm``、
+``MLP``、``ColumnParallelLinear``、``RowParallelLinear`` 这些基础层。和 Qwen3-VL 的文本分支
+比,区别在于这里用的是 1D RoPE,不需要多模态 M-RoPE 那套三维 position 逻辑。
Qwen3-VL 图文模型
----------------------------------------
-``Qwen3VLForConditionalGeneration`` 在文本 decoder 外增加视觉输入处理和 M-RoPE 位置编码。
-在一次图文请求中:
+``Qwen3VLForConditionalGeneration`` 在文本 decoder 之外多了视觉输入处理和 M-RoPE 位置编码。
+一次图文请求大致是这样走的:
1. tokenizer / processor 处理 messages 和图片路径。
-2. ``TokenizerProcess`` 生成 token ids 和多模态输入 tensor。
+2. ``TokenizerProcess`` 产出 token ids 和多模态输入 tensor。
3. 多模态 tensor 通过 ZMQ 或 shared queue 送到 scheduler。
-4. 模型 forward 中先处理视觉侧输入,再进入语言模型 prefill/decode。
-5. decode 阶段使用每个请求保存的 ``mrope_position_delta`` 修正位置。
+4. 模型 forward 里先过视觉侧输入,再进语言模型的 prefill / decode。
+5. decode 阶段用每个请求保存的 ``mrope_position_delta`` 修正位置。
当前 W8A8 量化主要覆盖语言 decoder 的线性层;视觉 encoder、embedding、LayerNorm 和
``lm_head`` 保持全精度。
+Fused projection 与 shard-aware loading
+----------------------------------------
+
+Qwen3 / Qwen3-VL 的 text decoder 用了 fused QKV projection 和 fused gate/up projection。
+对非量化模型,这减少了 projection 层的 module 边界;对 W8A8 和 W4A16 路径,它还顺手省掉了
+把同一层拆成多次 activation quant、GEMM 或 Marlin 调用的开销。
+
+checkpoint 里的权重往往还是 HuggingFace 常见的分离形式,比如 ``q_proj``、``k_proj``、
+``v_proj`` 和 ``gate_proj``、``up_proj``。``MergedLinear`` 用 shard-aware 的 ``weight_loader``
+把这些分离 tensor 写进 fused 参数,运行时布局保持 ``[Q, K, V]`` 或 ``[gate, up]``。权重加载
+完之后,``process_weights_after_loading`` 再去做 W8A8 layout 转换或 W4A16 Marlin repack。
+
+Qwen3 / Qwen3-VL decoder 还用 residual-carry 的形式组织 RMSNorm 的 fused add 路径。在
+Qwen3-VL 里,如果需要注入 deepstack embedding,运行时会先把当前 residual sum 物化出来,再
+执行注入并重置 carry,避免破坏图文 prefill 的语义。
+
量化配置解析
----------------------------------------
-服务启动时,``ModelRunner`` 会解析量化配置。优先级为:
+服务启动时 ``ModelRunner`` 解析量化配置,优先级是:
1. 命令行 ``--quantization.method``。
-2. checkpoint 目录中的量化配置文件。
-3. ``config.json`` 中的 ``quantization_config`` 字段。
+2. checkpoint 目录里的量化配置文件。
+3. ``config.json`` 里的 ``quantization_config`` 字段。
-``compressed-tensors`` 路径使用 ``pymllm.quantization.methods.compressed_tensors``,
-当前支持两类签名:
+``compressed-tensors`` 路径走 ``pymllm.quantization.methods.compressed_tensors``,目前支持
+两类签名:
.. list-table::
:header-rows: 1
@@ -106,13 +114,12 @@ Qwen3-VL 图文模型
- INT8 dynamic per-token activation
- Triton quant + CUTLASS INT8 GEMM
-``ignore`` 字段会让匹配前缀的模块跳过量化。例如 Qwen3-VL 的视觉分支通常保留为全精度。
+``ignore`` 字段会让前缀匹配上的模块跳过量化,比如 Qwen3-VL 的视觉分支通常整体保留全精度。
W4A16 / AWQ Marlin 路径
----------------------------------------
-W4A16 路径面向 ``compressed-tensors`` 的 ``pack-quantized`` checkpoint。当前支持的
-约束是:
+W4A16 面向 ``compressed-tensors`` 的 ``pack-quantized`` checkpoint。当前的约束是:
- ``format == "pack-quantized"``
- ``weights.num_bits == 4``
@@ -121,7 +128,7 @@ W4A16 路径面向 ``compressed-tensors`` 的 ``pack-quantized`` checkpoint。
- ``actorder == null``
- GPU capability 不低于 SM80
-权重加载和执行分为三个阶段:
+权重加载和执行分三步:
.. code-block:: text
@@ -132,20 +139,20 @@ W4A16 路径面向 ``compressed-tensors`` 的 ``pack-quantized`` checkpoint。
process_weights_after_loading()
gptq_marlin_repack()
marlin_permute_scales()
- create runtime-only zero/g_idx placeholders
+ 建好 runtime-only 的 zero / g_idx 占位
│
▼
apply()
gptq_marlin_gemm()
-``create_weights`` 注册与 checkpoint 对齐的参数名,保证 safetensors 加载逻辑可以按名称写入。
-``process_weights_after_loading`` 是 checkpoint layout 到 runtime kernel layout 的边界,repack
-不应放在通用权重加载器或每次 forward 中。
+``create_weights`` 注册和 checkpoint 对齐的参数名,让 safetensors 加载逻辑能按名字写进去。
+``process_weights_after_loading`` 是 checkpoint layout 转 runtime kernel layout 的那条边界,
+repack 只该放在这里,不该塞进通用权重加载器,更不该每次 forward 都做。
W8A8 INT8 路径
----------------------------------------
-W8A8 路径面向 ``compressed-tensors`` 的 ``int-quantized`` checkpoint。当前支持的约束是:
+W8A8 面向 ``compressed-tensors`` 的 ``int-quantized`` checkpoint。当前的约束是:
- ``format == "int-quantized"``
- ``weights.num_bits == 8``
@@ -158,10 +165,10 @@ W8A8 路径面向 ``compressed-tensors`` 的 ``int-quantized`` checkpoint。当
- ``input_activations.strategy == "token"``
- ``input_activations.dynamic == true``
- ``input_activations.symmetric == true``
-- W8A8 CUTLASS 路径当前支持 Ampere / SM8x GPU(SM80-SM89)。已验证目标为
- Jetson Orin SM87;Hopper / SM90 暂不包含在当前支持范围内。
+- W8A8 CUTLASS 路径当前支持 Ampere / SM8x(SM80–SM89)。已验证目标是 Jetson Orin SM87;
+ Hopper / SM90 暂不在支持范围内。
-执行链路如下:
+执行链路:
.. code-block:: text
@@ -178,49 +185,47 @@ W8A8 路径面向 ``compressed-tensors`` 的 ``int-quantized`` checkpoint。当
│
└── output(fp16/bf16)
-checkpoint 中的 INT8 权重通常是 ``[N, K]`` row-major。``process_weights_after_loading``
-会将其转换为 ``[K, N]`` column-major 视图并整理 ``weight_scale``,以满足 CUTLASS kernel
-接口约定。
+checkpoint 里的 INT8 权重通常是 ``[N, K]`` row-major。``process_weights_after_loading`` 会把它
+转成 ``[K, N]`` column-major 视图并整理 ``weight_scale``,以满足 CUTLASS kernel 的接口约定。
LinearMethod 生命周期
----------------------------------------
-所有 linear layer 都持有一个 ``quant_method``:
+每个 linear layer 都持有一个 ``quant_method``:
-- 未量化时使用 ``UnquantizedLinearMethod``,注册普通 ``weight`` 并调用 ``F.linear``。
+- 不量化时用 ``UnquantizedLinearMethod``,注册普通 ``weight`` 并调 ``F.linear``。
- 量化时由 ``QuantizationConfig.get_quant_method(layer, prefix)`` 返回具体方法。
典型生命周期:
-1. 模型构造时,linear layer 调用 ``quant_method.create_weights`` 注册参数。
-2. ``model.load_weights`` 根据参数名和 ``weight_loader`` 写入 checkpoint tensor。
-3. 所有权重加载完成后,``ModelRunner`` 遍历模块并调用
- ``process_weights_after_loading``。
-4. forward 时,linear layer 委托 ``quant_method.apply`` 执行。
+1. 模型构造时,linear layer 调 ``quant_method.create_weights`` 注册参数。
+2. ``model.load_weights`` 按参数名和 ``weight_loader`` 写进 checkpoint tensor。
+3. 权重全部加载完,``ModelRunner`` 遍历模块调 ``process_weights_after_loading``。
+4. forward 时 linear layer 委托 ``quant_method.apply`` 执行。
-这个边界使新增量化方法时不需要改动模型主逻辑,只需要实现新的 config 和 scheme。
+有了这条边界,新增量化方法时基本不用碰模型主逻辑,只要实现新的 config 和 scheme。
新增模型的建议流程
----------------------------------------
-新增模型时建议遵循以下顺序:
+新增模型时建议按这个顺序来:
-1. 在 ``pymllm/models/`` 中新增模型文件。
+1. 在 ``pymllm/models/`` 加模型文件。
2. 在 ``pymllm/models/__init__.py`` 注册 HuggingFace architecture 字符串。
3. 实现最小 forward 接口:``forward(input_ids, positions, forward_batch)``。
4. 复用现有基础层,并确保 linear layer 接受 ``quant_method``。
-5. 实现 ``load_weights``,处理 checkpoint 前缀、stacked projection 和 tied embedding。
-6. 增加 registry、weight loading、forward timing 的单元测试。
+5. 实现 ``load_weights``,处理好 checkpoint 前缀、stacked projection 和 tied embedding。
+6. 补 registry、weight loading、forward timing 的单元测试。
7. 最后再做服务级 smoke test。
新增量化方法的建议流程
----------------------------------------
-新增量化方法时建议保持三层结构:
+新增量化方法时保持三层结构:
1. ``QuantizationConfig``:解析 checkpoint 配置,决定某个 layer 是否量化。
2. ``LinearMethod``:承接 layer 生命周期。
3. ``Scheme``:处理具体格式的参数注册、post-load 转换和 kernel apply。
-不要把 checkpoint 格式判断写入模型类,也不要把 runtime repack 隐藏在通用
-``weight_loader`` 中。这样可以保证模型结构、权重格式和 kernel layout 三者的边界清晰。
+不要把 checkpoint 格式判断写进模型类,也不要把 runtime repack 藏在通用 ``weight_loader``
+里。守住这条,模型结构、权重格式、kernel layout 三者的边界才不会糊在一起。
diff --git a/docs/pymllm_runtime/runtime_design.rst b/docs/pymllm_runtime/runtime_design.rst
index 309ea7a21..7f1c31f66 100644
--- a/docs/pymllm_runtime/runtime_design.rst
+++ b/docs/pymllm_runtime/runtime_design.rst
@@ -4,13 +4,12 @@ pymllm Runtime Design
总览
----------------------------------------
-``pymllm`` 是 mllm 的 Python serving runtime。它不是传统意义上的 mllm C++
-Backend,而是一套围绕 PyTorch/CUDA 生态构建的在线推理服务运行时。当前实现面向
-Jetson Orin 等边缘 GPU 设备,重点支持 Qwen3、Qwen3-VL 和 Qwen3.5 系列模型。
+``pymllm`` 是 mllm 的 Python serving runtime。它不是 mllm 的 C++ Backend,而是一套
+围绕 PyTorch / CUDA 生态搭起来的在线推理服务运行时,目标设备是 Jetson Orin 这类边缘
+GPU,重点支持 Qwen3、Qwen3-VL、Qwen3.5 系列。
-它的设计参考了 SGLang serving runtime 的核心分层,但进行了明显收缩:当前主路径以
-单机单 GPU 为目标,优先保证在 Jetson 上可运行、可调试、可扩展,而不是覆盖大规模
-分布式 serving 的全部复杂度。
+它的分层借鉴了 SGLang serving runtime,但做了明显收缩:主路径只盯单机单 GPU,优先保证
+在 Jetson 上跑得起来、调得动、改得动,而不是去覆盖大规模分布式 serving 的全部复杂度。
.. figure:: ../_static/img/pymllm-arch.png
:width: 100%
@@ -22,23 +21,23 @@ Jetson Orin 等边缘 GPU 设备,重点支持 Qwen3、Qwen3-VL 和 Qwen3.5 系
整体分层
----------------------------------------
-从开发者视角看,``pymllm`` 可以分为五层:
+从开发者视角看,``pymllm`` 大致分五层:
1. **服务入口层**:FastAPI HTTP server,提供 OpenAI-compatible API 和原生
``/generate`` API。
-2. **配置层**:``ServerConfig``、``ModelConfig``、``QuantizationConfig`` 统一解析
- 模型路径、dtype、调度参数、缓存参数、量化参数和加速开关。
-3. **控制面**:``Engine`` 启动 tokenizer、scheduler、detokenizer 子进程,并在主进程中
- 维护 request/response 状态。
-4. **数据面**:scheduler 持有 GPU-owning ``ModelRunnerProcess``,负责 batch 构造、
- KV cache 分配、prefix cache 命中、forward 和 sampling。
+2. **配置层**:``ServerConfig``、``ModelConfig``、``QuantizationConfig`` 统一解析模型
+ 路径、dtype、调度参数、缓存参数、量化参数和各类加速开关。
+3. **控制面**:``Engine`` 拉起 tokenizer、scheduler、detokenizer 子进程,主进程里维护
+ request/response 状态。
+4. **数据面**:scheduler 持有 GPU 的 ``ModelRunnerProcess``,负责 batch 构造、KV cache
+ 分配、prefix cache 命中、forward 和 sampling。
5. **加速层**:FlashInfer、CUDA Graph、Triton、CUTLASS 和 ``mllm-kernel`` 提供 attention、
- quantization、GEMM 和缓存写入等高频算子。
+ 量化、GEMM、缓存写入这些高频算子。
进程拓扑
----------------------------------------
-``Engine`` 在启动时创建三个子进程,并在主进程中保留 request/response 管理逻辑:
+``Engine`` 启动时创建三个子进程,request/response 的管理逻辑留在主进程:
.. code-block:: text
@@ -64,34 +63,33 @@ Jetson Orin 等边缘 GPU 设备,重点支持 Qwen3、Qwen3-VL 和 Qwen3.5 系
▼
RequestResponseProcess
-这个拓扑的核心取舍是:GPU 资源由 scheduler 进程内的 ``ModelRunnerProcess`` 直接持有。
-这样 scheduler 可以在同一进程中完成调度、KV cache 资源释放、prefix cache 更新和模型
-forward,避免再引入 model worker 进程之间的 GPU 资源同步。
+这里最关键的一个取舍是:GPU 资源由 scheduler 进程内的 ``ModelRunnerProcess`` 直接持有。
+这样调度、KV cache 释放、prefix cache 更新和模型 forward 都在同一个进程里完成,省掉了
+model worker 进程之间同步 GPU 资源的那套复杂度。
请求生命周期
----------------------------------------
-一次 chat completion 请求的典型路径如下:
-
-1. HTTP server 接收请求并转换为 ``GenerateReqInput``。
-2. ``RequestResponseProcess`` 为请求分配 request id,并把请求送入 tokenizer。
-3. ``TokenizerProcess`` 调用 tokenizer / processor,生成 ``TokenizedGenerateReqInput``。
-4. ``SchedulerProcess`` 接收 tokenized request,创建 ``Req``,放入等待队列。
-5. scheduler 根据 token budget、running request 数量和 prefill/decode 状态构造
- ``ScheduleBatch``。
-6. ``ModelRunnerProcess`` 为 batch 分配 request slot 和 KV slot,执行 prefix matching。
-7. ``ModelRunner`` 构造 ``ForwardBatch``,初始化 attention backend metadata,调用模型
- ``forward``,并对 logits 做 sampling。
+一次 chat completion 请求的典型路径:
+
+1. HTTP server 收到请求,转成 ``GenerateReqInput``。
+2. ``RequestResponseProcess`` 分配 request id,把请求送进 tokenizer。
+3. ``TokenizerProcess`` 调 tokenizer / processor,产出 ``TokenizedGenerateReqInput``。
+4. ``SchedulerProcess`` 接到 tokenized request,创建 ``Req``,放进等待队列。
+5. scheduler 按 token budget、running request 数和 prefill/decode 状态构造 ``ScheduleBatch``。
+6. ``ModelRunnerProcess`` 为 batch 分配 request slot 和 KV slot,做 prefix matching。
+7. ``ModelRunner`` 构造 ``ForwardBatch``,初始化 attention backend metadata,调模型
+ ``forward``,再对 logits 做 sampling。
8. scheduler 更新每个 ``Req`` 的输出 token、finished reason 和 timing 字段。
-9. ``DetokenizerProcess`` 将 token id 转回文本。
-10. HTTP server 以普通 JSON 或 SSE streaming 形式返回结果。
+9. ``DetokenizerProcess`` 把 token id 转回文本。
+10. HTTP server 以普通 JSON 或 SSE streaming 返回。
控制面:Engine 与配置
----------------------------------------
-``pymllm.configs.server_config.ServerConfig`` 是服务运行时的主配置对象。它覆盖:
+``pymllm.configs.server_config.ServerConfig`` 是服务运行时的主配置对象,覆盖几类参数:
-- 模型和 tokenizer:``model_path``、``tokenizer_path``、``load_format``、``dtype``。
+- 模型与 tokenizer:``model_path``、``tokenizer_path``、``load_format``、``dtype``。
- HTTP server:``host``、``port``、``api_key``、``served_model_name``。
- 调度与内存:``max_running_requests``、``max_total_tokens``、``max_prefill_tokens``、
``mem_fraction_static``。
@@ -101,80 +99,79 @@ forward,避免再引入 model worker 进程之间的 GPU 资源同步。
``cuda_ipc_pool_size_mb``。
- 观测与调试:``log_level``、``decode_log_interval``。
-``Engine`` 启动前会加载 HuggingFace config,解析 EOS token、默认输出长度和 dtype,并确保
-model/tokenizer 路径可用。启动后,``Engine`` 会监控子进程健康状态;任一核心子进程异常退出,
-服务会被标记为 unhealthy。
+``Engine`` 启动前会先加载 HuggingFace config,解析 EOS token、默认输出长度和 dtype,并
+确认 model / tokenizer 路径可用。启动之后它会盯着子进程的健康状态,任何一个核心子进程异常
+退出,整个服务都会被标记为 unhealthy。
调度器
----------------------------------------
-``SchedulerProcess`` 是 pymllm 的中心调度组件。它负责:
+``SchedulerProcess`` 是 pymllm 的中心调度组件,干这几件事:
-- 接收 tokenized requests。
-- 将输入请求转换为内部 ``Req`` 状态。
-- 根据 prefill/decode 状态构造 ``ScheduleBatch``。
-- 控制 ``max_running_requests``、``max_total_tokens``、``max_prefill_tokens`` 等资源约束。
-- 在请求结束或中止时释放 request slot 和 KV slot。
-- 将 decode token 发送给 detokenizer。
+- 接收 tokenized request。
+- 把输入请求转成内部 ``Req`` 状态。
+- 按 prefill / decode 状态构造 ``ScheduleBatch``。
+- 守住 ``max_running_requests``、``max_total_tokens``、``max_prefill_tokens`` 这些资源约束。
+- 请求结束或中止时释放 request slot 和 KV slot。
+- 把 decode token 发给 detokenizer。
-当前调度策略以 FCFS 和单 GPU 资源约束为主。``max_prefill_tokens`` 用于限制一轮调度
-可接纳的 prefill token 数;长 prompt 的运行时 chunked prefill 切分仍待后续接入。
+当前调度策略以 FCFS 加单 GPU 资源约束为主。``max_prefill_tokens`` 用来限制一轮调度能接纳的
+prefill token 数;长 prompt 的运行时 chunked prefill 切分还没接进来,是后续的事。
ModelRunner
----------------------------------------
-``ModelRunner`` 是真正执行模型 forward 的组件。它在初始化阶段完成:
+``ModelRunner`` 是真正跑模型 forward 的组件。初始化阶段它会:
1. 设置 CUDA device 和默认 dtype。
2. 加载模型类和 safetensors 权重。
-3. 解析模型 metadata,例如 layer 数、head 数、head dim、context length。
+3. 解析模型 metadata,比如 layer 数、head 数、head dim、context length。
4. 初始化 request-to-token pool、token-to-KV pool 和 KV allocator。
5. 初始化 attention backend。
6. 预热 cuBLAS。
-7. 按配置捕获 decode CUDA Graph。
+7. 按配置 capture decode CUDA Graph。
-forward 阶段分为 extend 和 decode 两类:
+forward 分 extend 和 decode 两类:
-- **extend / prefill**:处理 prompt token,写入 KV cache,并返回每个请求最后一个 token 的
- logits。
+- **extend / prefill**:处理 prompt token,写 KV cache,返回每个请求最后一个 token 的 logits。
- **decode**:每个请求生成一个新 token,复用已有 KV cache 和 attention metadata。
KV cache 与 prefix cache
----------------------------------------
-``pymllm.mem_cache.memory_pool`` 中的 KV 管理采用三层结构:
+``pymllm.mem_cache.memory_pool`` 里的 KV 管理是三层结构:
.. code-block:: text
ReqToTokenPool
- maps (request slot, position) -> kv index
+ (request slot, position) -> kv index
TokenToKVPoolAllocator
- manages free integer KV slots
+ 管理空闲的整数 KV slot
KVPool
- stores per-layer K/V tensors on GPU
+ 在 GPU 上存每层的 K/V tensor
-``TokenToKVPoolAllocator`` 使用 free-list 管理 KV slot,并通过批量释放接口降低大量请求结束或
-prefix cache eviction 时的开销。``KVPool`` 在条件满足时会调用 ``mllm-kernel`` 的
-``store_cache`` JIT kernel 写入 K/V;否则回退到 PyTorch indexing。
+``TokenToKVPoolAllocator`` 用 free-list 管理 KV slot,并提供批量释放接口,在大量请求结束或
+prefix cache eviction 时降低开销。``KVPool`` 在条件满足时调用 ``mllm-kernel`` 的
+``store_cache`` JIT kernel 写 K/V,否则回退到 PyTorch indexing。
-Prefix cache 当前有三种实现:
+prefix cache 目前有三种实现:
-- ``RadixCache``:标准 radix-tree prefix cache。
-- ``ChunkCache``:关闭 radix cache 时使用的简单缓存路径。
-- ``MambaRadixCache``:为包含 GDN / Mamba-like 状态的 hybrid 模型预留的状态缓存路径。
+- ``RadixCache``:标准的 radix-tree prefix cache。
+- ``ChunkCache``:关掉 radix cache 时用的简单缓存路径。
+- ``MambaRadixCache``:给带 GDN / Mamba-like 状态的 hybrid 模型预留的状态缓存路径。
-当启用 ``RadixCache`` 时,extend batch 会先执行 prefix matching。命中的 prefix token 不再
-重复计算,但对应 radix tree 节点会被 lock,直到请求结束或资源释放时再 unlock。
+开 ``RadixCache`` 时,extend batch 会先做 prefix matching:命中的 prefix token 不再重复计算,
+但对应的 radix tree 节点会被 lock 住,直到请求结束或资源释放才 unlock。
IPC 与多模态数据传输
----------------------------------------
-普通控制消息通过 ZMQ 传输。多模态请求中的大 tensor 可以走 shared queue fast path,
-由 ``enable_shared_queue`` 和 ``tensor_transport_mode`` 控制。
+普通控制消息走 ZMQ。多模态请求里的大 tensor 可以走 shared queue 这条 fast path,由
+``enable_shared_queue`` 和 ``tensor_transport_mode`` 控制。
-``tensor_transport_mode`` 支持三种模式:
+``tensor_transport_mode`` 有三种模式:
.. list-table::
:header-rows: 1
@@ -183,22 +180,22 @@ IPC 与多模态数据传输
- 行为
- 适用场景
* - ``default``
- - GPU tensor 先拷到 CPU,再放入 POSIX shared memory。
+ - GPU tensor 先拷回 CPU,再放进 POSIX shared memory。
- 最稳妥,调试优先。
* - ``cuda_ipc``
- GPU tensor 通过 CUDA IPC handle 跨进程共享。
- - 避免 GPU->CPU 拷贝,但长服务中可能有 PyTorch IPC 生命周期问题。
+ - 省掉 GPU→CPU 拷贝,但长时间服务里可能踩到 PyTorch IPC 的生命周期问题。
* - ``cuda_ipc_pool``
- - 使用预分配 GPU workspace,发送方回收 chunk。
- - 面向生产服务的推荐 GPU tensor 传输方式。
+ - 用预分配的 GPU workspace,发送方回收 chunk。
+ - 面向生产服务推荐的 GPU tensor 传输方式。
与 mllm C++ Backend 的关系
----------------------------------------
-``pymllm`` 和 ``cpu_backend``、``qnn_backend``、``ascend_backend`` 的层级不同:
+``pymllm`` 和 ``cpu_backend``、``qnn_backend``、``ascend_backend`` 不在同一个层级:
-- C++ Backend 接入的是 mllm C++ 的 Tensor、Op、Module、Dispatcher 和设备 allocator。
-- ``pymllm`` 接入的是 Python/PyTorch serving pipeline,主要服务于在线推理、模型加载、
- KV cache、调度和 CUDA kernel 集成。
-- ``mllm-kernel`` 是两者可以共享思想的低层 kernel 工具包,但当前 ``pymllm`` 更直接依赖
- 其中的 Python JIT CUDA kernel。
+- C++ Backend 接的是 mllm C++ 那套 Tensor、Op、Module、Dispatcher 和设备 allocator。
+- ``pymllm`` 接的是 Python / PyTorch serving pipeline,服务在线推理、模型加载、KV cache、
+ 调度和 CUDA kernel 集成。
+- ``mllm-kernel`` 是两边都可以借鉴的低层 kernel 工具包,不过目前 ``pymllm`` 更直接依赖其中
+ 的 Python JIT CUDA kernel。
diff --git a/docs/pymllm_runtime/setup_and_usage.rst b/docs/pymllm_runtime/setup_and_usage.rst
index 3097bbbbf..b42a75e11 100644
--- a/docs/pymllm_runtime/setup_and_usage.rst
+++ b/docs/pymllm_runtime/setup_and_usage.rst
@@ -4,19 +4,16 @@ pymllm Setup and Usage
总览
----------------------------------------
-``pymllm`` 是 mllm 面向 Python 生态的推理服务运行时,主要面向 NVIDIA Jetson
-Orin 系列边缘 GPU 设备,例如 Jetson Orin NX 与 Jetson AGX Orin。它覆盖
-Qwen3 / Qwen3-VL 的 BF16、W4A16 和 W8A8 推理路径,并提供 OpenAI-compatible
-HTTP API。
+``pymllm`` 是 mllm 面向 Python / CUDA 生态的推理服务运行时,主要跑在 NVIDIA Jetson
+Orin 系列边缘 GPU(Orin NX / AGX Orin)上。它针对 Orin Ampere Tensor Core 的 INT8
+算力做了系统级适配,支持 BF16 原生推理以及 W4A16、W8A8_INT8 两种量化方案,兼顾推理
+速度与模型精度,目前已完成对 Qwen3、Qwen3-VL、Qwen3.5 的支持,并对外提供一套
+OpenAI-compatible 的 HTTP API。
环境要求
----------------------------------------
-当前推荐基于 `jetson-containers `_
-提供的 Jetson PyTorch/CUDA 基础镜像进行开发。这样可以避免在 Jetson 上手工处理
-PyTorch、CUDA、cuDNN、Python ABI 等基础依赖。
-
-已验证环境如下:
+下面是当前已经跑通的一组版本:
.. list-table::
:header-rows: 1
@@ -47,7 +44,7 @@ PyTorch、CUDA、cuDNN、Python ABI 等基础依赖。
安装依赖
----------------------------------------
-在 Jetson 容器中克隆仓库后,进入仓库根目录安装 ``pymllm`` 和 ``mllm-kernel``:
+克隆仓库后,进入根目录安装 ``pymllm`` 和 ``mllm-kernel``:
.. code-block:: bash
@@ -55,67 +52,47 @@ PyTorch、CUDA、cuDNN、Python ABI 等基础依赖。
SKBUILD_WHEEL_CMAKE=false python3 -m pip install -e .
python3 -m pip install -e /mllm-kernel --no-deps --no-build-isolation
-``transformers`` 可按项目需要自行安装。``triton`` 和 ``flashinfer`` 可以从
-Jetson AI Lab 的 wheel 源安装,也可以从官方 PyPI 或对应上游项目安装:
+``triton`` 和 ``flashinfer`` 有两个来源,任选其一:
.. code-block:: bash
- # 方式一:从 Jetson AI Lab 安装 Jetson wheel。
+ # 方式一:从 Jetson AI Lab 装 Jetson wheel。
python3 -m pip install --extra-index-url https://pypi.jetson-ai-lab.io/ triton flashinfer
- # 方式二:从官方 PyPI 固定 Triton,再单独安装 FlashInfer。
+ # 方式二:从官方 PyPI 固定 Triton 版本,FlashInfer 仍从 Jetson AI Lab 装。
python3 -m pip install --index-url https://pypi.org/simple triton==3.6.0
python3 -m pip install --extra-index-url https://pypi.jetson-ai-lab.io/ flashinfer
-在 Jetson / aarch64 上,Triton wheel 的可用性会受到 wheel 来源、CUDA 路径和
-``ptxas`` / ``cuda.h`` 查找路径影响。Jetson AI Lab 源提供面向 JetPack 6 /
-CUDA 12.6 的 Triton wheel;在已验证环境中,官方 PyPI 的 ``triton==3.6.0``
-manylinux aarch64 wheel 更接近开箱即用。若使用 Jetson AI Lab wheel 遇到
-``ptxas`` 或 CUDA 头文件查找问题,可显式设置 ``TRITON_PTXAS_PATH`` 和
-``CPATH`` 后重试。无论选择哪个来源,都建议用最小 Triton kernel 或
-``per_token_quant_int8`` 做 smoke test。
-
-最小导入检查:
-
-.. code-block:: bash
+在 aarch64 上,Triton wheel 能不能开箱即用,主要取决于 wheel 来源以及
+``ptxas`` / ``cuda.h`` 的查找路径。在上面这组已验证环境里,官方 PyPI 的
+``triton==3.6.0`` manylinux aarch64 wheel 更接近开箱即用;如果用 Jetson AI Lab
+的 wheel 碰到 ``ptxas`` 或 CUDA 头文件找不到的问题,显式设置 ``TRITON_PTXAS_PATH``
+和 ``CPATH`` 再重试通常能解决。装完后建议用 ``per_token_quant_int8`` 之类的最小
+kernel 跑一次 smoke test,确认 Triton 真的能编译。
- python3 - <<'PY'
- import pymllm
- import mllm_kernel
-
- print("pymllm import ok")
- print("mllm_kernel import ok")
- PY
-
-CUTLASS 头文件
+W8A8 首次运行的 JIT 编译
----------------------------------------
-W8A8 的高性能 GEMM 路径依赖 CUTLASS 头文件。当前查找顺序为:
-
-1. ``CUTLASS_HOME/include``
-2. ``flashinfer`` 内置的 ``data/cutlass/include``
-3. ``/usr/local/include``、``/usr/include``、``/usr/local/cuda/include``
+W8A8 的 INT8 GEMM 走 CUTLASS,依赖 CUTLASS 头文件。默认情况下不需要额外配置——
+``flashinfer`` 自带了一份 bundled CUTLASS,可以直接用;如果想换成自己的版本,设置
+``CUTLASS_HOME`` 即可。
-首次调用 CUTLASS W8A8 kernel 会触发 JIT 编译,编译产物会复用:
+第一次调用 W8A8 kernel 会触发一次 JIT 编译,编译产物缓存在:
.. code-block:: text
~/.cache/mllm_kernel/cutlass_int8_scaled_mm/
-如果需要重新验证首次编译行为,可以删除该目录后再次运行。
+之后复用缓存,不会再编译。想重新验证首次编译行为时,删掉这个目录再跑一次就行。
启动服务
----------------------------------------
-``pymllm`` 的服务入口是 ``pymllm.server.launch``。服务启动后会提供
-``/health``、``/v1/models``、``/v1/completions``、``/v1/chat/completions``、
-``/generate`` 等接口。
+服务入口是 ``pymllm.server.launch``,启动后提供 ``/health``、``/v1/models``、
+``/v1/completions``、``/v1/chat/completions``、``/generate`` 等接口。
-W4A16 / W8A8 量化模型
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-``compressed-tensors`` 量化模型使用同一个启动入口。运行时会根据模型
-``config.json`` 中的量化配置识别 W4A16 或 W8A8 路径。
+W4A16 / W8A8 量化模型和 BF16 原生模型共用同一个入口,运行时会读 ``config.json``
+里的量化配置,自动走 W4A16 或 W8A8 路径。一条典型的量化模型启动命令:
.. code-block:: bash
@@ -123,45 +100,17 @@ W4A16 / W8A8 量化模型
python3 -m pymllm.server.launch \
--server.model_path \
- --server.tokenizer_path \
- --server.load_format safetensors \
--server.dtype float16 \
--quantization.method compressed-tensors \
--server.host 0.0.0.0 \
--server.port 30000 \
- --server.attention_backend auto \
- --server.gdn_decode_backend pytorch \
- --server.mem_fraction_static 0.05 \
+ --server.mem_fraction_static 0.8 \
--server.max_running_requests 1 \
- --server.max_total_tokens 256 \
- --server.max_prefill_tokens 128 \
+ --server.max_total_tokens 4096 \
--server.disable_radix_cache \
- --server.disable_cuda_graph \
--server.log_level debug
-BF16 原生模型
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-BF16 或 FP16 原生模型不需要设置 ``--quantization.method``:
-
-.. code-block:: bash
-
- cd
-
- python3 -m pymllm.server.launch \
- --server.model_path \
- --server.tokenizer_path \
- --server.load_format safetensors \
- --server.dtype bfloat16 \
- --server.host 0.0.0.0 \
- --server.port 30000 \
- --server.attention_backend auto \
- --server.mem_fraction_static 0.05 \
- --server.max_running_requests 1 \
- --server.max_total_tokens 256 \
- --server.max_prefill_tokens 128 \
- --server.disable_radix_cache \
- --server.log_level info
+BF16 / FP16 原生模型用同一条命令,去掉 ``--quantization.method`` 即可。
常用参数
----------------------------------------
@@ -174,21 +123,22 @@ BF16 或 FP16 原生模型不需要设置 ``--quantization.method``:
* - ``--server.model_path``
- 模型权重目录,通常是 HuggingFace 或 ModelScope 格式。
* - ``--server.tokenizer_path``
- - tokenizer 目录;不设置时默认等于 ``model_path``。
+ - tokenizer 目录;不设置时默认等于 ``model_path``,一般不用单独传。
* - ``--server.dtype``
- - 模型运行 dtype,可选 ``auto``、``float16``、``bfloat16``、``float32``。
+ - 模型运行 dtype,可选 ``auto``、``float16``、``bfloat16``。
* - ``--quantization.method compressed-tensors``
- - 启用 ``compressed-tensors`` 权重加载与线性层执行路径。
+ - 启用 ``compressed-tensors`` 权重加载和量化线性层执行路径。
+ * - ``--server.mem_fraction_static``
+ - ``模型权重 + KV cache pool`` 占 GPU 总显存的静态预算比例。设太小,KV pool 预算
+ 不足会导致启动报错;设太大,留给 activation 和 CUDA Graph 的动态空间不够。
+ Jetson 上 Qwen3-VL-2B 量化模型一般在 ``0.5``–``0.8`` 之间起调。
* - ``--server.max_running_requests``
- - 同时运行的请求数。Jetson 小显存环境下通常从 ``1`` 开始调试。
+ - 同时运行的请求数。Jetson 小显存环境一般从 ``1`` 开始调。
* - ``--server.max_total_tokens``
- - KV cache token pool 的总容量上限。
- * - ``--server.max_prefill_tokens``
- - 单轮 prefill 可处理的 token 上限。
+ - KV cache token pool 的容量上限,是整个 worker 全局共享的池子(不是单请求上限)。
+ 实际容量取 ``min(profile 可承载 token 数, max_total_tokens)``,不会绕过显存 profile。
* - ``--server.disable_radix_cache``
- 关闭 Radix Cache,改用 ``ChunkCache``。
- * - ``--server.disable_cuda_graph``
- - 关闭 decode CUDA Graph,便于调试动态路径。
OpenAI-compatible 请求
----------------------------------------
@@ -213,7 +163,7 @@ OpenAI-compatible 请求
"stream": false
}' ; echo
-图文请求中,图片路径需要是容器内可访问的绝对路径,不要带 ``file://`` 前缀:
+图文请求里的图片路径要用服务进程可访问的绝对路径,不要带 ``file://`` 前缀:
.. code-block:: bash
@@ -239,121 +189,106 @@ OpenAI-compatible 请求
-H "Content-Type: application/json" \
--data @/tmp/mm_req_path.json ; echo
-开发与测试
+Benchmark
----------------------------------------
-常用单元测试:
-
-.. code-block:: bash
-
- pytest pymllm/tests/test_compressed_tensors_config.py -q
- pytest pymllm/tests/test_compressed_tensors_runtime.py -q
- pytest pymllm/tests/test_qwen3_model_registry.py -q
- pytest pymllm/tests/test_qwen3_weight_loading.py -q
- pytest pymllm/tests/test_qwen3_forward_timing.py -q
- pytest mllm-kernel/tests/test_int8_scaled_mm_cutlass.py -q
+``bench_one_batch`` 是一个低层离线 benchmark。它直接初始化
+``pymllm.executor.model_runner.ModelRunner``,绕过 HTTP server、tokenizer、scheduler、
+detokenizer 这些进程,只测模型本身一次静态 prefill 加逐 token decode 的开销,因此适合
+分析模型 forward、KV cache、attention、CUDA Graph 和量化 kernel 的模型级表现,也方便
+验证 fused projection、residual-carry 这类模型图优化。它测不到在线服务的 TTFT / ITL /
+E2E,这两个口径不要混用。
-模型级 benchmark:
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+目前 ``bench_one_batch`` 支持三种测速口径:
-``bench_one_batch`` 是对齐 SGLang 口径的低层离线 benchmark。它直接初始化
-``pymllm.executor.model_runner.ModelRunner``,绕过 HTTP server、tokenizer 进程、
-scheduler 进程和 detokenizer 进程,用 synthetic text-only token ids 测一次静态
-prefill,再测逐 token decode。该工具适合分析模型 forward、KV cache、attention、
-CUDA Graph 与量化 kernel 的模型级开销,不代表在线服务的 TTFT / ITL / E2E 指标。
+- **纯文本**:用 synthetic token ids 测纯文本的 prefill / decode;
+- **视觉编码(vit_prefill)**:同步墙钟只包住视觉 encoder(``self.visual(...)``),
+ 反映纯视觉编码速度;
+- **多模态 prefill(multimodal_prefill)**:覆盖“视觉编码 + 图像/文本 token 的 LLM
+ prefill”,反映完整多模态 prefill 速度。
-典型用法:
+纯文本用法:
.. code-block:: bash
PYTHONPATH="$PWD:$PWD/mllm-kernel" python3 -m pymllm.bench_one_batch \
--server.model_path \
- --server.tokenizer_path \
- --server.load_format safetensors \
--server.dtype float16 \
--quantization.method compressed-tensors \
- --server.mem_fraction_static 0.1 \
+ --server.mem_fraction_static 0.8 \
--server.max_running_requests 1 \
--server.max_total_tokens 2048 \
- --server.disable_radix_cache \
--server.log_level info \
--run-name qwen3vl_w8a8_bench_one_batch \
--batch-size 1 \
--input-len 256 512 1024 \
--output-len 128 \
- --result-filename /tmp/pymllm_bench_one_batch.jsonl
+ --result-filename
+
+``--batch-size``、``--input-len``、``--output-len`` 都支持多个值,脚本会遍历所有组合
+并把结果追加到 JSONL 文件。``output_len`` 用的是总输出 token 语义:prefill 之后已经
+拿到第一个 next token,后续 decode loop 再跑 ``output_len - 1`` 步。
+
+多模态 prefill 用法。给 ``--image`` 传一张真实图片,再显式传 ``--input-len`` 时,长度
+口径是 ``image placeholder tokens + text prompt tokens`` 的目标总长——脚本只在文本
+token 上做补齐或截断,绝不动 image token,因此可以用同一张图 sweep
+``314/512/1024/2048`` 等不同总长,测包含视觉编码的完整多模态 prefill 速度:
+
+.. code-block:: bash
+
+ PYTHONPATH="$PWD:$PWD/mllm-kernel" python3 -m pymllm.bench_one_batch \
+ --server.model_path \
+ --server.trust_remote_code true \
+ --server.dtype float16 \
+ --quantization.method compressed-tensors \
+ --server.mem_fraction_static 0.8 \
+ --server.max_running_requests 1 \
+ --server.disable_cuda_graph \
+ --batch-size 1 \
+ --input-len 314 512 1024 2048 \
+ --output-len 1 \
+ --image \
+ --prompt "Describe this image." \
+ --run-name qwen3vl_w8a8_multimodal_prefill \
+ --result-filename
-其中 ``--batch-size``、``--input-len`` 和 ``--output-len`` 都支持多个值,脚本会遍历
-所有组合并向 JSONL 文件追加结果。``output_len`` 采用 SGLang 的总输出 token 语义:
-prefill 后已得到第一个 next token,后续 decode loop 执行 ``output_len - 1`` 步。
+JSONL 里 ``vit_prefill_ms`` 只包住 ``self.visual(...)``,``multimodal_prefill_*``
+则是完整 VIT + LLM prefill 的别名字段,两者口径不同。在 AGX Orin 32GB 上的实测中,
+W8A8 在长 prefill 上明显领先 FP16 / W4A16。
-执行结构:
+脚本的整体执行流程大致是:
.. code-block:: text
pymllm.bench_one_batch CLI
|
- |-- parse GlobalConfig args and BenchArgs
- |-- load HuggingFace AutoConfig into cfg.model.hf_config
+ |-- 解析 GlobalConfig 参数和 BenchArgs
+ |-- 加载 HuggingFace AutoConfig 到 cfg.model.hf_config
|
|-- ModelRunner.initialize()
- | |-- load model and quantization config
- | |-- initialize KV pools and attention backend
- | |-- optionally capture decode CUDA Graph
+ | |-- 加载模型和量化配置
+ | |-- 初始化 KV pool 和 attention backend
+ | |-- 按需 capture decode CUDA Graph
|
- |-- warmup once
+ |-- warmup 一次
|
- |-- for each (batch_size, input_len, output_len):
- |
- |-- clear req_to_token_pool and token_to_kv_pool_allocator
- |-- build synthetic input_ids
- |-- prefill:
- | allocate request slots and KV slots
- | write prompt KV mapping
- | prepare ForwardBatch(EXTEND)
- | synchronize, run forward + sampling, synchronize
- |
- |-- decode loop:
- allocate one KV slot per request
- write current token mapping
- prepare ForwardBatch(DECODE)
- synchronize, run forward + sampling, synchronize
- update seq_lens and next token ids
+ |-- 遍历每个 (batch_size, input_len, output_len):
+ | |-- 清空 req_to_token_pool 和 token_to_kv_pool_allocator
+ | |-- 构造 synthetic input_ids
+ | |-- prefill:分配 request/KV slot,写 KV 映射,跑 forward + sampling
+ | |-- decode loop:逐步分配 KV slot,跑 forward + sampling,更新 seq_lens
|
- |-- append JSONL result rows
+ |-- 追加 JSONL 结果行
-Profile 辅助入口:
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-``bench_one_batch`` 保留了基于 ``torch.profiler`` 的 profile 参数,主要用于本地
-kernel timeline 分析。当前公开 benchmark 记录没有使用 profile 结果,因此它不作为标准
-性能数据口径的一部分。使用前建议先用较小的 ``input_len`` / ``output_len`` 做一次
-trace 生成验证,再扩大到正式 case。
-
-.. code-block:: bash
-
- PYMLLM_TORCH_PROFILER_DIR=/tmp \
- PYTHONPATH="$PWD:$PWD/mllm-kernel" python3 -m pymllm.bench_one_batch \
- --server.model_path \
- --server.tokenizer_path \
- --server.load_format safetensors \
- --server.dtype bfloat16 \
- --server.mem_fraction_static 0.1 \
- --server.max_running_requests 1 \
- --server.max_total_tokens 2048 \
- --server.log_level info \
- --batch-size 1 \
- --input-len 256 \
- --output-len 128 \
- --profile \
- --profile-stage decode \
- --profile-steps 1
+Profile
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-已知限制
-----------------------------------------
+``bench_one_batch`` 内置了 profile 入口,方便在本地看 kernel timeline。目前有两条路径:
-- W8A8 CUTLASS 当前通过 JIT 编译,首次启动有明显编译开销。
-- W8A8 激活量化使用 Triton kernel;decode 小 batch 下固定量化开销仍是后续优化点。
-- Qwen3-VL 的 ViT、``lm_head``、embedding 和 LayerNorm 不在当前 W8A8 量化范围内。
-- 当前文档中的 Jetson 性能与稳定性结论主要来自 Orin NX / SM87,需要在其他 GPU 上重新验证。
-- OpenAI-compatible API 的服务级指标和 ``bench_one_batch`` 的模型级指标口径不同,不应直接混用。
+- **torch.profiler(已支持)**:``--profile-activities CPU GPU``(默认),输出
+ ``.trace.json.gz`` timeline,可以直接在 Perfetto / chrome://tracing 里看。输出目录
+ 由 ``PYMLLM_TORCH_PROFILER_DIR`` 指定,默认 ``/tmp``。
+- **Nsight Systems / nsys(实验性)**:``--profile-activities CUDA_PROFILER`` 通过
+ ``cudaProfilerStart/Stop`` 驱动 nsys,需要外层用
+ ``nsys --capture-range=cudaProfilerApi`` 包住命令。这条路径还在打磨中,部分场景下
+ 可能不够顺手,仅作为可选的深入分析手段。
diff --git a/pymllm/README-ZH.md b/pymllm/README-ZH.md
index a32c6580c..fcb77fd2e 100644
--- a/pymllm/README-ZH.md
+++ b/pymllm/README-ZH.md
@@ -2,179 +2,127 @@

-`pymllm` 是 `mllm` 的 Python 推理服务入口。本目录当前重点覆盖
-Jetson Orin 上的 Qwen3 / Qwen3-VL 推理、OpenAI-compatible server、
-`compressed-tensors` 量化加载,以及 W8A8 INT8 kernel 路径。
+## 总览
-本文档按 2026-04-27 的开发状态整理,适用于当前集成分支:
+`pymllm` 是 mllm 面向 Python / CUDA 生态的推理服务运行时,主要跑在 NVIDIA Jetson
+Orin 系列边缘 GPU(Orin NX / AGX Orin)上。它针对 Orin Ampere Tensor Core 的 INT8
+算力做了系统级适配,支持 BF16 原生推理以及 W4A16、W8A8_INT8 两种量化方案,兼顾推理
+速度与模型精度,目前已完成对 Qwen3、Qwen3-VL、Qwen3.5 的支持,并对外提供一套
+OpenAI-compatible 的 HTTP API。
-```text
-feature/jetson-qwen3-family-bf16-w4a16-w8a8
-```
-
-## 当前状态
-
-已验证路径:
+## 环境要求
-- `Qwen3-VL-2B-Instruct`:BF16 原生模型服务可用。
-- `Qwen3-VL-2B-Instruct-AWQ-4bit`:`compressed-tensors`
- W4A16 / AWQ Marlin 路径可用。
-- `Qwen3-VL-2B-Instruct-quantized.w8a8`:`compressed-tensors`
- W8A8 `int-quantized` 路径端到端可用。
+下面是当前已经跑通的一组版本:
-已实现并纳入单元测试的模型/组件:
+| 组件 | 版本或说明 |
+| --- | --- |
+| JetPack / Jetson Linux | JetPack `6.2.1` / Jetson Linux `36.4.4` (L4T `R36.4.4`) |
+| Python | `3.10.12` |
+| PyTorch | `2.4.0` |
+| torchvision | `0.19.0a0+48b1edf` |
+| transformers | `5.3.0` |
+| safetensors | `0.7.0` |
+| flashinfer | `0.6.7` |
+| Triton Language | `triton==3.6.0` aarch64 wheel |
+| CUDA | `12.6` |
+| GPU | Jetson Orin NX,SM87 |
-- `Qwen3VLForConditionalGeneration`:图文模型服务主路径。
-- `Qwen3ForCausalLM`:文本模型骨架、权重加载与 timing 字段测试。
-- `compressed-tensors`:
- - `pack-quantized` 4-bit 权重路径,使用 GPTQ Marlin。
- - `int-quantized` W8A8 路径,使用 Triton 激活量化 + CUTLASS
- `int8_scaled_mm`。
+## 安装依赖
-W8A8 当前前向链路:
+克隆仓库后,进入根目录安装 `pymllm` 和 `mllm-kernel`:
-```text
-x(fp16/bf16)
- -> per_token_quant_int8 [Triton, dynamic per-token activation quant]
- -> int8_scaled_mm [CUTLASS, INT8 Tensor Core, fused scales]
- -> output(fp16/bf16)
+```bash
+cd
+SKBUILD_WHEEL_CMAKE=false python3 -m pip install -e .
+python3 -m pip install -e /mllm-kernel --no-deps --no-build-isolation
```
-## 已验证环境
+`triton` 和 `flashinfer` 有两个来源,任选其一:
-以下命令基于 Jetson Orin 环境整理:
+```bash
+# 方式一:从 Jetson AI Lab 装 Jetson wheel。
+python3 -m pip install --extra-index-url https://pypi.jetson-ai-lab.io/ triton flashinfer
-- JetPack / L4T:`R36.4.4`(来自 `/etc/nv_tegra_release`)
-- Python:`3.10.12`
-- PyTorch:`2.4.0`
-- torchvision:`0.19.0a0+48b1edf`
-- transformers:`5.3.0`
-- safetensors:`0.7.0`
-- flashinfer:`0.6.7`
-- Triton Language:官方 PyPI `triton==3.6.0` manylinux aarch64 wheel
-- CUDA:`12.6`
-- GPU:Jetson Orin NX,SM87
+# 方式二:从官方 PyPI 固定 Triton 版本,FlashInfer 仍从 Jetson AI Lab 装。
+python3 -m pip install --index-url https://pypi.org/simple triton==3.6.0
+python3 -m pip install --extra-index-url https://pypi.jetson-ai-lab.io/ flashinfer
+```
-这里的 Triton 指 GPU kernel DSL,不是 Triton Inference Server。Jetson-AI-Lab
-源也提供 `3.4.0`、`3.5.1`、`3.6.0`,但实测中可能需要额外设置
-`TRITON_PTXAS_PATH` 和 `CPATH`。当前建议优先使用官方 PyPI 的
-`triton==3.6.0`,并用最小 CUDA kernel 或 `per_token_quant_int8` 做 smoke test。
+在 aarch64 上,Triton wheel 能不能开箱即用,主要取决于 wheel 来源以及
+`ptxas` / `cuda.h` 的查找路径。在上面这组已验证环境里,官方 PyPI 的
+`triton==3.6.0` manylinux aarch64 wheel 更接近开箱即用;如果用 Jetson AI Lab
+的 wheel 碰到 `ptxas` 或 CUDA 头文件找不到的问题,显式设置 `TRITON_PTXAS_PATH`
+和 `CPATH` 再重试通常能解决。装完后建议用 `per_token_quant_int8` 之类的最小
+kernel 跑一次 smoke test,确认 Triton 真的能编译。
-W8A8 CUTLASS JIT 需要能找到 CUTLASS 头文件。当前查找顺序为:
+## W8A8 首次运行的 JIT 编译
-1. `CUTLASS_HOME/include`
-2. `flashinfer` 内置的 `data/cutlass/include`
-3. `/usr/local/include`、`/usr/include`、`/usr/local/cuda/include`
+W8A8 的 INT8 GEMM 走 CUTLASS,依赖 CUTLASS 头文件。默认情况下不需要额外配置——
+`flashinfer` 自带了一份 bundled CUTLASS,可以直接用;如果想换成自己的版本,设置
+`CUTLASS_HOME` 即可。
-首次调用 CUTLASS kernel 会触发 JIT 编译,耗时约 100 秒;后续会复用:
+第一次调用 W8A8 kernel 会触发一次 JIT 编译,编译产物缓存在:
```text
~/.cache/mllm_kernel/cutlass_int8_scaled_mm/
```
-## 安装开发环境
-
-在仓库根目录执行:
-
-```bash
-cd
-SKBUILD_WHEEL_CMAKE=false python3 -m pip install -e .
-python3 -m pip install -e /mllm-kernel --no-deps --no-build-isolation
-```
-
-最小导入检查:
-
-```bash
-python3 - <<'PY'
-import pymllm
-import mllm_kernel
-
-print("pymllm import ok")
-print("mllm_kernel import ok")
-PY
-```
+之后复用缓存,不会再编译。想重新验证首次编译行为时,删掉这个目录再跑一次就行。
## 启动服务
-### 量化模型(W4A16 / W8A8)
+服务入口是 `pymllm.server.launch`,启动后提供 `/health`、`/v1/models`、
+`/v1/completions`、`/v1/chat/completions`、`/generate` 等接口。
+
+W4A16 / W8A8 量化模型和 BF16 原生模型共用同一个入口,运行时会读 `config.json`
+里的量化配置,自动走 W4A16 或 W8A8 路径。一条典型的量化模型启动命令:
```bash
cd
python3 -m pymllm.server.launch \
--server.model_path \
- --server.tokenizer_path \
- --server.load_format safetensors \
--server.dtype float16 \
--quantization.method compressed-tensors \
--server.host 0.0.0.0 \
--server.port 30000 \
- --server.attention_backend auto \
- --server.gdn_decode_backend pytorch \
- --server.mem_fraction_static 0.05 \
+ --server.mem_fraction_static 0.8 \
--server.max_running_requests 1 \
- --server.max_total_tokens 256 \
- --server.max_prefill_tokens 128 \
- --server.chunked_prefill_size 128 \
+ --server.max_total_tokens 4096 \
--server.disable_radix_cache \
- --server.disable_cuda_graph \
--server.log_level debug
```
-说明:
+BF16 / FP16 原生模型用同一条命令,去掉 `--quantization.method` 即可。
-- `--quantization.method compressed-tensors` 会按模型 `config.json`
- 自动识别 W4A16 或 W8A8 签名。
-- W8A8 路径要求 GPU capability 不低于 SM80。
-- `--server.disable_radix_cache` 会使用 `ChunkCache`,当前已修复该模式下的
- KV slot 泄漏问题。
-- 若 `30000` 已被占用,可改成其他空闲端口。
+## 常用参数
-### BF16 原生模型
+| 参数 | 说明 |
+| --- | --- |
+| `--server.model_path` | 模型权重目录,通常是 HuggingFace 或 ModelScope 格式。 |
+| `--server.tokenizer_path` | tokenizer 目录;不设置时默认等于 `model_path`,一般不用单独传。 |
+| `--server.dtype` | 模型运行 dtype,可选 `auto`、`float16`、`bfloat16`。 |
+| `--quantization.method compressed-tensors` | 启用 `compressed-tensors` 权重加载和量化线性层执行路径。 |
+| `--server.mem_fraction_static` | `模型权重 + KV cache pool` 占 GPU 总显存的静态预算比例。设太小,KV pool 预算不足会导致启动报错;设太大,留给 activation 和 CUDA Graph 的动态空间不够。Jetson 上 Qwen3-VL-2B 量化模型一般在 `0.5`–`0.8` 之间起调。 |
+| `--server.max_running_requests` | 同时运行的请求数。Jetson 小显存环境一般从 `1` 开始调。 |
+| `--server.max_total_tokens` | KV cache token pool 的容量上限,是整个 worker 全局共享的池子(不是单请求上限)。实际容量取 `min(profile 可承载 token 数, max_total_tokens)`,不会绕过显存 profile。 |
+| `--server.disable_radix_cache` | 关闭 Radix Cache,改用 `ChunkCache`。 |
-```bash
-cd
+## OpenAI-compatible 请求
-python3 -m pymllm.server.launch \
- --server.model_path \
- --server.tokenizer_path \
- --server.load_format safetensors \
- --server.dtype float16 \
- --server.host 0.0.0.0 \
- --server.port 30000 \
- --server.attention_backend auto \
- --server.gdn_decode_backend pytorch \
- --server.mem_fraction_static 0.05 \
- --server.max_running_requests 1 \
- --server.max_total_tokens 256 \
- --server.max_prefill_tokens 128 \
- --server.chunked_prefill_size 128 \
- --server.disable_radix_cache \
- --server.disable_cuda_graph \
- --server.log_level debug
-```
-
-## 调用示例
-
-### 健康检查
+健康检查:
```bash
curl -s --noproxy '*' http://127.0.0.1:30000/v1/models ; echo
```
-期望返回中包含:
-
-```text
-"owned_by":"pymllm"
-```
-
-### 文本请求
+文本请求:
```bash
curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
- "model": "None",
+ "model": "default",
"messages": [{"role": "user", "content": "你好,只回复:ok"}],
"max_tokens": 8,
"temperature": 0.0,
@@ -182,73 +130,100 @@ curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \
}' ; echo
```
-### 图文请求
-
-图片路径请使用容器内可访问的绝对路径,不要使用 `file://...` 前缀。
+图文请求里的图片路径要用服务进程可访问的绝对路径,不要带 `file://` 前缀:
```bash
-python3 - <<'PY'
-import json
-
-payload = {
- "model": "None",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "请详细描述这张图片。"},
- {"type": "image_url", "image_url": {"url": "/workspace/xcd_mllm/test.png"}},
- ],
- }
- ],
- "max_tokens": 128,
- "temperature": 0.0,
- "stream": False,
+cat > /tmp/mm_req_path.json <<'JSON'
+{
+ "model": "default",
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "请描述这张图片。"},
+ {"type": "image_url", "image_url": {"url": "/workspace/test.png"}}
+ ]
+ }
+ ],
+ "max_tokens": 128,
+ "temperature": 0.0,
+ "stream": false
}
-
-with open("/tmp/mm_req_path.json", "w", encoding="utf-8") as f:
- json.dump(payload, f, ensure_ascii=False)
-
-print("saved /tmp/mm_req_path.json")
-PY
+JSON
curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \
-H "Content-Type: application/json" \
--data @/tmp/mm_req_path.json ; echo
```
-## 开发与测试
+## Benchmark
-常用单元测试:
+`bench_one_batch` 是一个低层离线 benchmark。它直接初始化
+`pymllm.executor.model_runner.ModelRunner`,绕过 HTTP server、tokenizer、scheduler、
+detokenizer 这些进程,只测模型本身一次静态 prefill 加逐 token decode 的开销,因此适合
+分析模型 forward、KV cache、attention、CUDA Graph 和量化 kernel 的模型级表现,也方便
+验证 fused projection、residual-carry 这类模型图优化。它测不到在线服务的 TTFT / ITL /
+E2E,这两个口径不要混用。
-```bash
-pytest pymllm/tests/test_compressed_tensors_config.py -q
-pytest pymllm/tests/test_compressed_tensors_runtime.py -q
-pytest pymllm/tests/test_qwen3_model_registry.py -q
-pytest pymllm/tests/test_qwen3_weight_loading.py -q
-pytest pymllm/tests/test_qwen3_forward_timing.py -q
-pytest mllm-kernel/tests/test_int8_scaled_mm_cutlass.py -q
-```
+目前 `bench_one_batch` 支持三种测速口径:
+
+- **纯文本**:用 synthetic token ids 测纯文本的 prefill / decode;
+- **视觉编码(vit_prefill)**:同步墙钟只包住视觉 encoder(`self.visual(...)`),反映纯视觉编码速度;
+- **多模态 prefill(multimodal_prefill)**:覆盖“视觉编码 + 图像/文本 token 的 LLM prefill”,反映完整多模态 prefill 速度。
-常用 microbench:
+纯文本用法:
```bash
-python3 pymllm/tests/bench_w8a8_activation_quant.py
-python3 mllm-kernel/benchmarks/bench_int8_scaled_mm.py
-python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
+PYTHONPATH="$PWD:$PWD/mllm-kernel" python3 -m pymllm.bench_one_batch \
+ --server.model_path \
+ --server.dtype float16 \
+ --quantization.method compressed-tensors \
+ --server.mem_fraction_static 0.8 \
+ --server.max_running_requests 1 \
+ --server.max_total_tokens 2048 \
+ --batch-size 1 \
+ --input-len 256 512 1024 \
+ --output-len 128 \
+ --result-filename
```
-如果需要重新测 CUTLASS 首次编译,可先清理 JIT 缓存:
+`--batch-size`、`--input-len`、`--output-len` 都支持多个值,脚本会遍历所有组合并把结果
+追加到 JSONL 文件。`output_len` 用的是总输出 token 语义:prefill 之后已经拿到第一个
+next token,后续 decode loop 再跑 `output_len - 1` 步。
+
+多模态 prefill 用法。给 `--image` 传一张真实图片,再显式传 `--input-len` 时,长度口径是
+`image placeholder tokens + text prompt tokens` 的目标总长——脚本只在文本 token 上做补齐
+或截断,绝不动 image token,因此可以用同一张图 sweep `314/512/1024/2048` 等不同总长,测
+包含视觉编码的完整多模态 prefill 速度:
```bash
-rm -rf ~/.cache/mllm_kernel/cutlass_int8_scaled_mm/
+PYTHONPATH="$PWD:$PWD/mllm-kernel" python3 -m pymllm.bench_one_batch \
+ --server.model_path \
+ --server.trust_remote_code true \
+ --server.dtype float16 \
+ --quantization.method compressed-tensors \
+ --server.mem_fraction_static 0.8 \
+ --server.max_running_requests 1 \
+ --server.disable_cuda_graph \
+ --batch-size 1 \
+ --input-len 314 512 1024 2048 \
+ --output-len 1 \
+ --image \
+ --prompt "Describe this image." \
+ --result-filename
```
-## 已知限制
+JSONL 里 `vit_prefill_ms` 只包住 `self.visual(...)`,`multimodal_prefill_*` 则是完整
+VIT + LLM prefill 的别名字段,两者口径不同。在 AGX Orin 32GB 上的实测中,W8A8 在长
+prefill 上明显领先 FP16 / W4A16。
+
+### Profile
+
+`bench_one_batch` 内置了 profile 入口,方便在本地看 kernel timeline。目前有两条路径:
-- W8A8 CUTLASS 当前通过 JIT 编译,首次启动存在约 100 秒编译开销。
-- W8A8 激活量化使用 Triton kernel;decode 下固定量化开销仍是后续优化点。
-- Qwen3-VL 的 ViT、`lm_head`、embedding 和 LayerNorm 不在当前 W8A8 量化范围内。
-- 其他 GPU 需要重新验证 tile dispatch、JIT 编译和性能。
-- 为对齐 SGLang/OpenAI 兼容响应,OpenAI API 默认不返回 debug timing。
- 仅在本地诊断时使用 `--server.enable_debug_timing`;严格模型级计时应使用专用 benchmark。
+- **torch.profiler(已支持)**:`--profile-activities CPU GPU`(默认),输出
+ `.trace.json.gz` timeline,可以直接在 Perfetto / chrome://tracing 里看。输出目录由
+ `PYMLLM_TORCH_PROFILER_DIR` 指定,默认 `/tmp`。
+- **Nsight Systems / nsys(实验性)**:`--profile-activities CUDA_PROFILER` 通过
+ `cudaProfilerStart/Stop` 驱动 nsys,需要外层用 `nsys --capture-range=cudaProfilerApi`
+ 包住命令。这条路径还在打磨中,部分场景下可能不够顺手,仅作为可选的深入分析手段。
diff --git a/pymllm/README.md b/pymllm/README.md
index 439f74bc7..0aed552eb 100644
--- a/pymllm/README.md
+++ b/pymllm/README.md
@@ -2,184 +2,131 @@

-`pymllm` is the Python inference and serving entry point for `mllm`. This
-directory currently focuses on Qwen3 / Qwen3-VL serving on Jetson Orin,
-OpenAI-compatible APIs, `compressed-tensors` quantized loading, and the W8A8
-INT8 kernel path.
+## Overview
-This README reflects the development state as of 2026-04-27 for the integration
-branch:
+`pymllm` is mllm's Python / CUDA inference and serving runtime, running mainly
+on NVIDIA Jetson Orin edge GPUs (Orin NX / AGX Orin). It is adapted for the INT8
+throughput of the Orin Ampere Tensor Cores, supports BF16 native inference plus
+two quantization schemes (W4A16 and W8A8_INT8), and currently covers Qwen3,
+Qwen3-VL, and Qwen3.5, exposing an OpenAI-compatible HTTP API.
-```text
-feature/jetson-qwen3-family-bf16-w4a16-w8a8
-```
-
-## Current status
-
-Validated paths:
+## Environment
-- `Qwen3-VL-2B-Instruct`: BF16 base-model serving.
-- `Qwen3-VL-2B-Instruct-AWQ-4bit`: `compressed-tensors` W4A16 / AWQ Marlin
- serving.
-- `Qwen3-VL-2B-Instruct-quantized.w8a8`: `compressed-tensors` W8A8
- `int-quantized` end-to-end serving.
+A known-good set of versions:
-Implemented and unit-tested models/components:
+| Component | Version / notes |
+| --- | --- |
+| JetPack / Jetson Linux | JetPack `6.2.1` / Jetson Linux `36.4.4` (L4T `R36.4.4`) |
+| Python | `3.10.12` |
+| PyTorch | `2.4.0` |
+| torchvision | `0.19.0a0+48b1edf` |
+| transformers | `5.3.0` |
+| safetensors | `0.7.0` |
+| flashinfer | `0.6.7` |
+| Triton Language | `triton==3.6.0` aarch64 wheel |
+| CUDA | `12.6` |
+| GPU | Jetson Orin NX, SM87 |
-- `Qwen3VLForConditionalGeneration`: the main multimodal serving path.
-- `Qwen3ForCausalLM`: text-only model skeleton, weight loading, and timing
- tests.
-- `compressed-tensors`:
- - `pack-quantized` 4-bit weight path via GPTQ Marlin.
- - `int-quantized` W8A8 path via Triton activation quantization and CUTLASS
- `int8_scaled_mm`.
+## Install
-The current W8A8 forward path is:
+Clone the repo, then install `pymllm` and `mllm-kernel` from the repo root:
-```text
-x(fp16/bf16)
- -> per_token_quant_int8 [Triton, dynamic per-token activation quant]
- -> int8_scaled_mm [CUTLASS, INT8 Tensor Core, fused scales]
- -> output(fp16/bf16)
+```bash
+cd
+SKBUILD_WHEEL_CMAKE=false python3 -m pip install -e .
+python3 -m pip install -e /mllm-kernel --no-deps --no-build-isolation
```
-## Validated environment
+`triton` and `flashinfer` have two sources; pick either one:
-The commands below were validated on Jetson Orin with:
+```bash
+# Option 1: Jetson wheels from Jetson AI Lab.
+python3 -m pip install --extra-index-url https://pypi.jetson-ai-lab.io/ triton flashinfer
-- JetPack / L4T: `R36.4.4` (`/etc/nv_tegra_release`)
-- Python: `3.10.12`
-- PyTorch: `2.4.0`
-- torchvision: `0.19.0a0+48b1edf`
-- transformers: `5.3.0`
-- safetensors: `0.7.0`
-- flashinfer: `0.6.7`
-- Triton Language: official PyPI `triton==3.6.0` manylinux aarch64 wheel
-- CUDA: `12.6`
-- GPU: Jetson Orin NX, SM87
+# Option 2: pin Triton from official PyPI, still get FlashInfer from Jetson AI Lab.
+python3 -m pip install --index-url https://pypi.org/simple triton==3.6.0
+python3 -m pip install --extra-index-url https://pypi.jetson-ai-lab.io/ flashinfer
+```
-Triton here means the GPU kernel DSL, not Triton Inference Server. The
-Jetson-AI-Lab index also provides `3.4.0`, `3.5.1`, and `3.6.0`, but the tested
-environment may require extra `TRITON_PTXAS_PATH` and `CPATH` settings with
-those wheels. For this project, prefer the official PyPI `triton==3.6.0` wheel
-and verify it with a minimal CUDA kernel or `per_token_quant_int8` smoke test.
+On aarch64, whether the Triton wheel works out of the box mostly depends on the
+wheel source and the `ptxas` / `cuda.h` lookup paths. In the validated
+environment above, the official PyPI `triton==3.6.0` manylinux aarch64 wheel is
+closest to working out of the box; if a Jetson AI Lab wheel hits `ptxas` or CUDA
+header lookup issues, setting `TRITON_PTXAS_PATH` and `CPATH` explicitly usually
+fixes it. After installing, run a smoke test with a minimal kernel such as
+`per_token_quant_int8` to confirm Triton actually compiles.
-The W8A8 CUTLASS JIT path requires CUTLASS headers. The lookup order is:
+## W8A8 first-run JIT compilation
-1. `CUTLASS_HOME/include`
-2. `flashinfer` bundled `data/cutlass/include`
-3. `/usr/local/include`, `/usr/include`, `/usr/local/cuda/include`
+The W8A8 INT8 GEMM goes through CUTLASS and needs CUTLASS headers. No extra
+setup is required by default — `flashinfer` ships a bundled CUTLASS; set
+`CUTLASS_HOME` if you want to point at your own copy.
-The first CUTLASS kernel call triggers JIT compilation and may take about
-100 seconds. Later runs reuse:
+The first W8A8 kernel call triggers a one-time JIT compile, cached at:
```text
~/.cache/mllm_kernel/cutlass_int8_scaled_mm/
```
-## Install the development environment
-
-Run from the repository root:
-
-```bash
-cd
-SKBUILD_WHEEL_CMAKE=false python3 -m pip install -e .
-python3 -m pip install -e /mllm-kernel --no-deps --no-build-isolation
-```
-
-Run a minimal import check:
-
-```bash
-python3 - <<'PY'
-import pymllm
-import mllm_kernel
-
-print("pymllm import ok")
-print("mllm_kernel import ok")
-PY
-```
+Later runs reuse the cache. To re-check the first-compile behavior, delete this
+directory and run again.
## Launch the server
-### Quantized models (W4A16 / W8A8)
+The entry point is `pymllm.server.launch`. Once up, it serves `/health`,
+`/v1/models`, `/v1/completions`, `/v1/chat/completions`, and `/generate`.
+
+W4A16 / W8A8 quantized models and BF16 base models share the same entry point;
+the runtime reads the quantization config in `config.json` and picks the W4A16
+or W8A8 path automatically. A typical quantized-model launch:
```bash
cd
python3 -m pymllm.server.launch \
--server.model_path \
- --server.tokenizer_path \
- --server.load_format safetensors \
--server.dtype float16 \
--quantization.method compressed-tensors \
--server.host 0.0.0.0 \
--server.port 30000 \
- --server.attention_backend auto \
- --server.gdn_decode_backend pytorch \
- --server.mem_fraction_static 0.05 \
+ --server.mem_fraction_static 0.8 \
--server.max_running_requests 1 \
- --server.max_total_tokens 256 \
- --server.max_prefill_tokens 128 \
- --server.chunked_prefill_size 128 \
+ --server.max_total_tokens 4096 \
--server.disable_radix_cache \
- --server.disable_cuda_graph \
--server.log_level debug
```
-Notes:
+For BF16 / FP16 base models, use the same command and drop
+`--quantization.method`.
-- `--quantization.method compressed-tensors` reads the model `config.json` and
- selects the W4A16 or W8A8 signature automatically.
-- W8A8 requires SM80 or newer GPUs.
-- `--server.disable_radix_cache` uses `ChunkCache`; the KV slot leak in this
- mode has been fixed.
-- If port `30000` is already in use, switch to another free port.
+## Common parameters
-### BF16 base models
-
-```bash
-cd
-
-python3 -m pymllm.server.launch \
- --server.model_path \
- --server.tokenizer_path \
- --server.load_format safetensors \
- --server.dtype float16 \
- --server.host 0.0.0.0 \
- --server.port 30000 \
- --server.attention_backend auto \
- --server.gdn_decode_backend pytorch \
- --server.mem_fraction_static 0.05 \
- --server.max_running_requests 1 \
- --server.max_total_tokens 256 \
- --server.max_prefill_tokens 128 \
- --server.chunked_prefill_size 128 \
- --server.disable_radix_cache \
- --server.disable_cuda_graph \
- --server.log_level debug
-```
+| Parameter | Description |
+| --- | --- |
+| `--server.model_path` | Model weight directory, usually HuggingFace or ModelScope format. |
+| `--server.tokenizer_path` | Tokenizer directory; defaults to `model_path` when unset, so you rarely pass it. |
+| `--server.dtype` | Runtime dtype: `auto`, `float16`, or `bfloat16`. |
+| `--quantization.method compressed-tensors` | Enables `compressed-tensors` weight loading and the quantized linear path. |
+| `--server.mem_fraction_static` | Static budget for `model weights + KV cache pool` as a fraction of total GPU memory. Too small fails to allocate the KV pool at startup; too large leaves no room for activations and CUDA Graph. On Jetson, Qwen3-VL-2B usually starts around `0.5`–`0.8`. |
+| `--server.max_running_requests` | Concurrent requests. On small-VRAM Jetson, start from `1`. |
+| `--server.max_total_tokens` | Upper bound on the KV cache token pool, shared globally across the worker (not a per-request limit). Actual capacity is `min(profiled tokens, max_total_tokens)` and does not bypass the memory profile. |
+| `--server.disable_radix_cache` | Disables Radix Cache, uses `ChunkCache` instead. |
-## Request examples
+## OpenAI-compatible requests
-### Health check
+Health check:
```bash
curl -s --noproxy '*' http://127.0.0.1:30000/v1/models ; echo
```
-Expected response contains:
-
-```text
-"owned_by":"pymllm"
-```
-
-### Text request
+Text request:
```bash
curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
- "model": "None",
+ "model": "default",
"messages": [{"role": "user", "content": "Reply with: ok"}],
"max_tokens": 8,
"temperature": 0.0,
@@ -187,79 +134,111 @@ curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \
}' ; echo
```
-### Image request
-
-Use a container-visible absolute image path. Do not use the `file://...`
-prefix.
+For image requests, use an absolute path the server process can read; do not use
+the `file://` prefix:
```bash
-python3 - <<'PY'
-import json
-
-payload = {
- "model": "None",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Please describe this image in detail."},
- {"type": "image_url", "image_url": {"url": "/workspace/xcd_mllm/test.png"}},
- ],
- }
- ],
- "max_tokens": 128,
- "temperature": 0.0,
- "stream": False,
+cat > /tmp/mm_req_path.json <<'JSON'
+{
+ "model": "default",
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Please describe this image."},
+ {"type": "image_url", "image_url": {"url": "/workspace/test.png"}}
+ ]
+ }
+ ],
+ "max_tokens": 128,
+ "temperature": 0.0,
+ "stream": false
}
-
-with open("/tmp/mm_req_path.json", "w", encoding="utf-8") as f:
- json.dump(payload, f, ensure_ascii=False)
-
-print("saved /tmp/mm_req_path.json")
-PY
+JSON
curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \
-H "Content-Type: application/json" \
--data @/tmp/mm_req_path.json ; echo
```
-## Development and tests
+## Benchmark
-Common unit tests:
+`bench_one_batch` is a low-level offline benchmark. It initializes
+`pymllm.executor.model_runner.ModelRunner` directly, bypassing the HTTP server,
+tokenizer, scheduler, and detokenizer, and only measures one static prefill plus
+per-token decode of the model itself. It is good for analyzing model forward, KV
+cache, attention, CUDA Graph, and quantized kernels at the model level, and for
+checking graph optimizations such as fused projection and residual-carry. It
+does not measure online-serving TTFT / ITL / E2E — don't mix the two.
-```bash
-pytest pymllm/tests/test_compressed_tensors_config.py -q
-pytest pymllm/tests/test_compressed_tensors_runtime.py -q
-pytest pymllm/tests/test_qwen3_model_registry.py -q
-pytest pymllm/tests/test_qwen3_weight_loading.py -q
-pytest pymllm/tests/test_qwen3_forward_timing.py -q
-pytest mllm-kernel/tests/test_int8_scaled_mm_cutlass.py -q
-```
+`bench_one_batch` supports three measurement modes:
+
+- **Text-only**: prefill / decode with synthetic token ids.
+- **Vision encoding (`vit_prefill`)**: a synchronized wall clock around the
+ vision encoder (`self.visual(...)`) only, reflecting pure vision-encode speed.
+- **Multimodal prefill (`multimodal_prefill`)**: covers "vision encoding + LLM
+ prefill over image/text tokens", reflecting full multimodal prefill speed.
-Common microbenchmarks:
+Text-only:
```bash
-python3 pymllm/tests/bench_w8a8_activation_quant.py
-python3 mllm-kernel/benchmarks/bench_int8_scaled_mm.py
-python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
+PYTHONPATH="$PWD:$PWD/mllm-kernel" python3 -m pymllm.bench_one_batch \
+ --server.model_path \
+ --server.dtype float16 \
+ --quantization.method compressed-tensors \
+ --server.mem_fraction_static 0.8 \
+ --server.max_running_requests 1 \
+ --server.max_total_tokens 2048 \
+ --batch-size 1 \
+ --input-len 256 512 1024 \
+ --output-len 128 \
+ --result-filename
```
-To measure first-use CUTLASS compilation again, clear the JIT cache:
+`--batch-size`, `--input-len`, and `--output-len` all accept multiple values;
+the script sweeps every combination and appends results to the JSONL file.
+`output_len` uses total-output-token semantics: the first next token is already
+produced after prefill, so the decode loop runs `output_len - 1` more steps.
+
+Multimodal prefill: pass a real image to `--image`, and when you also pass
+`--input-len` explicitly, the length means the target total of
+`image placeholder tokens + text prompt tokens` — the script only pads or
+truncates text tokens, never image tokens. So you can sweep different totals such
+as `314/512/1024/2048` on the same image to measure full multimodal prefill
+including vision encoding:
```bash
-rm -rf ~/.cache/mllm_kernel/cutlass_int8_scaled_mm/
+PYTHONPATH="$PWD:$PWD/mllm-kernel" python3 -m pymllm.bench_one_batch \
+ --server.model_path \
+ --server.trust_remote_code true \
+ --server.dtype float16 \
+ --quantization.method compressed-tensors \
+ --server.mem_fraction_static 0.8 \
+ --server.max_running_requests 1 \
+ --server.disable_cuda_graph \
+ --batch-size 1 \
+ --input-len 314 512 1024 2048 \
+ --output-len 1 \
+ --image \
+ --prompt "Describe this image." \
+ --result-filename
```
-## Known limitations
-
-- The W8A8 CUTLASS path is JIT-compiled, so first startup includes about
- 100 seconds of compilation overhead.
-- W8A8 activation quantization uses a Triton kernel; its fixed decode-time
- cost remains a future optimization target.
-- Qwen3-VL ViT, `lm_head`, embeddings, and LayerNorm are outside the current
- W8A8 quantized scope.
-- Other GPUs need separate validation for tile dispatch, JIT compilation, and
- performance.
-- OpenAI-compatible responses hide debug timing by default for SGLang/OpenAI
- compatibility. Use `--server.enable_debug_timing` only for local diagnostics;
- strict model-level timing should use dedicated benchmarks.
+In the JSONL, `vit_prefill_ms` wraps only `self.visual(...)`, while
+`multimodal_prefill_*` are alias fields for the full VIT + LLM prefill — the two
+have different scopes. In measurements on AGX Orin 32GB, W8A8 clearly leads FP16
+/ W4A16 on long prefill.
+
+### Profile
+
+`bench_one_batch` has a built-in profile entry for inspecting kernel timelines
+locally. There are two paths:
+
+- **torch.profiler (supported)**: `--profile-activities CPU GPU` (default),
+ emits a `.trace.json.gz` timeline you can open in Perfetto / chrome://tracing.
+ The output directory is set by `PYMLLM_TORCH_PROFILER_DIR`, defaulting to
+ `/tmp`.
+- **Nsight Systems / nsys (experimental)**: `--profile-activities CUDA_PROFILER`
+ drives nsys via `cudaProfilerStart/Stop`, and needs an outer wrapper like
+ `nsys --capture-range=cudaProfilerApi`. This path is still being polished and
+ may be rough in places; treat it as an optional deep-dive tool.
diff --git a/pymllm/bench_one_batch.py b/pymllm/bench_one_batch.py
index a62be2bb2..190f92434 100644
--- a/pymllm/bench_one_batch.py
+++ b/pymllm/bench_one_batch.py
@@ -13,8 +13,8 @@
import os
import re
import statistics
+import sys
import time
-from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Iterator, Optional, Sequence
@@ -51,6 +51,10 @@ class BenchArgs:
profile_start_step: Optional[int] = None
profile_steps: int = 1
skip_warmup: bool = False
+ image_path: Optional[Path] = None
+ prompt: str = "Describe this image."
+ input_len_was_provided: bool = False
+ correctness_test: bool = False
@dataclass
@@ -60,6 +64,34 @@ class DecodeState:
mrope_position_deltas: Optional[torch.Tensor] = None
+@dataclass
+class ExtendResult:
+ next_token_ids: torch.Tensor
+ state: DecodeState
+ vit_prefill_ms: Optional[float] = None
+ vit_prefill_tokens: Optional[int] = None
+ vit_prefill_tps: Optional[float] = None
+
+ def __iter__(self) -> Iterator[Any]:
+ # Preserve the old ``next_token_ids, state = extend(...)`` call pattern.
+ yield self.next_token_ids
+ yield self.state
+
+
+@dataclass
+class MultimodalBenchInput:
+ input_ids: torch.Tensor
+ pixel_values: torch.Tensor
+ image_grid_thw: torch.Tensor
+ vit_prefill_tokens: int
+
+
+@dataclass
+class MultimodalProcessorBundle:
+ processor_output: Any
+ pad_token_id: int
+
+
def _positive_int(value: str) -> int:
parsed = int(value)
if parsed <= 0:
@@ -121,8 +153,12 @@ def add_bench_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
group.add_argument(
"--profile-activities",
nargs="+",
- choices=["CPU", "GPU"],
+ choices=["CPU", "GPU", "CUDA_PROFILER"],
default=["CPU", "GPU"],
+ help=(
+ "CPU/GPU use the torch profiler; CUDA_PROFILER drives nsys via "
+ "cudaProfilerStart/Stop (use with nsys --capture-range=cudaProfilerApi)."
+ ),
)
group.add_argument(
"--profile-stage",
@@ -150,6 +186,35 @@ def add_bench_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
action="store_true",
help="Skip the initial non-recorded warmup run.",
)
+ group.add_argument(
+ "--image",
+ "--image-path",
+ dest="image_path",
+ type=Path,
+ default=None,
+ help=(
+ "Optional image path for multimodal benchmark mode. When set, "
+ "omitted --input-len uses the processed prompt length; explicit "
+ "--input-len sweeps target total multimodal prefill length "
+ "(image placeholder tokens + text tokens)."
+ ),
+ )
+ group.add_argument(
+ "--prompt",
+ default=BenchArgs.prompt,
+ help="Prompt text used with --image in multimodal benchmark mode.",
+ )
+ group.add_argument(
+ "--correct",
+ "--correctness-test",
+ dest="correctness_test",
+ action="store_true",
+ help=(
+ "Run a single-stage smoke correctness check (encode real prompts, "
+ "prefill, greedy decode, print decoded text) instead of the "
+ "latency benchmark."
+ ),
+ )
return parser
@@ -163,7 +228,15 @@ def make_parser() -> argparse.ArgumentParser:
return parser
-def _bench_args_from_namespace(namespace: argparse.Namespace) -> BenchArgs:
+def _argv_has_option(argv: Sequence[str], option: str) -> bool:
+ return any(arg == option or arg.startswith(f"{option}=") for arg in argv)
+
+
+def _bench_args_from_namespace(
+ namespace: argparse.Namespace,
+ *,
+ input_len_was_provided: bool = False,
+) -> BenchArgs:
return BenchArgs(
run_name=namespace.run_name,
batch_size=list(namespace.batch_size),
@@ -180,6 +253,10 @@ def _bench_args_from_namespace(namespace: argparse.Namespace) -> BenchArgs:
profile_start_step=namespace.profile_start_step,
profile_steps=namespace.profile_steps,
skip_warmup=namespace.skip_warmup,
+ image_path=namespace.image_path,
+ prompt=namespace.prompt,
+ input_len_was_provided=input_len_was_provided,
+ correctness_test=namespace.correctness_test,
)
@@ -187,16 +264,25 @@ def parse_args(
argv: Optional[Sequence[str]] = None,
) -> tuple[GlobalConfig, BenchArgs]:
parser = make_parser()
- cfg = read_args(argv=argv, parser=parser)
- namespace = parser.parse_args(argv)
- return cfg, _bench_args_from_namespace(namespace)
+ cli_argv = list(sys.argv[1:] if argv is None else argv)
+ cfg = read_args(argv=cli_argv, parser=parser)
+ namespace = parser.parse_args(cli_argv)
+ return cfg, _bench_args_from_namespace(
+ namespace,
+ input_len_was_provided=_argv_has_option(cli_argv, "--input-len"),
+ )
def generate_settings(args: BenchArgs) -> list[BenchSetting]:
+ input_lens = (
+ [0]
+ if args.image_path is not None and not args.input_len_was_provided
+ else args.input_len
+ )
return [
BenchSetting(batch_size=batch_size, input_len=input_len, output_len=output_len)
for batch_size in args.batch_size
- for input_len in args.input_len
+ for input_len in input_lens
for output_len in args.output_len
]
@@ -267,6 +353,158 @@ def summarize_latencies(
return result
+def make_vit_prefill_metrics(
+ *,
+ vit_prefill_ms: float,
+ vit_prefill_tokens: int,
+) -> dict[str, Any]:
+ latency = float(vit_prefill_ms) / 1000.0
+ throughput = _safe_div(float(vit_prefill_tokens), latency)
+ return {
+ "vit_prefill_latency": latency,
+ "vit_prefill_ms": float(vit_prefill_ms),
+ "vit_prefill_tokens": int(vit_prefill_tokens),
+ "vit_prefill_throughput": throughput,
+ "vit_prefill_tps": throughput,
+ }
+
+
+def make_multimodal_prefill_metrics(
+ *,
+ prefill_latency: float,
+ batch_size: int,
+ input_len: int,
+) -> dict[str, Any]:
+ tokens = int(batch_size) * int(input_len)
+ throughput = _safe_div(float(tokens), float(prefill_latency))
+ return {
+ "multimodal_prefill_latency": float(prefill_latency),
+ "multimodal_prefill_ms": float(prefill_latency) * 1000.0,
+ "multimodal_prefill_tokens": tokens,
+ "multimodal_prefill_throughput": throughput,
+ "multimodal_prefill_tps": throughput,
+ }
+
+
+def _get_processor_value(processor_output: Any, key: str) -> Any:
+ if hasattr(processor_output, "get"):
+ return processor_output.get(key)
+ return getattr(processor_output, key, None)
+
+
+def _resize_multimodal_input_ids(
+ input_ids: torch.Tensor,
+ *,
+ target_input_len: int,
+ image_token_id: int,
+ pad_token_id: int,
+) -> torch.Tensor:
+ if int(pad_token_id) == int(image_token_id):
+ pad_token_id = 0 if int(image_token_id) != 0 else 1
+ if target_input_len <= 0:
+ raise ValueError(
+ f"target_input_len must be positive, got {target_input_len}"
+ )
+ if input_ids.dim() != 2 or input_ids.shape[0] != 1:
+ raise ValueError(
+ "bench_one_batch multimodal resize expects input_ids shape [1, seq_len], "
+ f"got {tuple(input_ids.shape)}"
+ )
+
+ seq = input_ids[0]
+ image_mask = seq == image_token_id
+ image_token_count = int(image_mask.sum().item())
+ if target_input_len < image_token_count:
+ raise ValueError(
+ "target_input_len must be at least the number of image tokens "
+ f"({image_token_count}), got {target_input_len}"
+ )
+ if int(seq.numel()) == target_input_len:
+ return input_ids
+
+ text_budget = target_input_len - image_token_count
+ resized_tokens: list[int] = []
+ kept_text = 0
+ for token in seq.tolist():
+ token_id = int(token)
+ if token_id == image_token_id:
+ resized_tokens.append(token_id)
+ elif kept_text < text_budget:
+ resized_tokens.append(token_id)
+ kept_text += 1
+ if len(resized_tokens) == target_input_len:
+ break
+
+ if len(resized_tokens) < target_input_len:
+ resized_tokens.extend([int(pad_token_id)] * (target_input_len - len(resized_tokens)))
+
+ return torch.tensor(
+ [resized_tokens],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+
+
+def make_multimodal_bench_input_from_processor_output(
+ processor_output: Any,
+ *,
+ batch_size: int,
+ image_token_id: int,
+ device: str | torch.device,
+ target_input_len: Optional[int] = None,
+ pad_token_id: int = 0,
+) -> MultimodalBenchInput:
+ input_ids = _get_processor_value(processor_output, "input_ids")
+ pixel_values = _get_processor_value(processor_output, "pixel_values")
+ image_grid_thw = _get_processor_value(processor_output, "image_grid_thw")
+
+ if input_ids is None:
+ raise ValueError("Multimodal processor output does not contain input_ids")
+ if pixel_values is None:
+ raise ValueError("Multimodal processor output does not contain pixel_values")
+ if image_grid_thw is None:
+ raise ValueError(
+ "Multimodal processor output does not contain image_grid_thw"
+ )
+
+ input_ids_t = torch.as_tensor(input_ids)
+ if input_ids_t.dim() == 1:
+ input_ids_t = input_ids_t.unsqueeze(0)
+ if input_ids_t.shape[0] != 1:
+ raise ValueError(
+ "bench_one_batch multimodal mode expects one processed prompt before "
+ f"batch repetition, got batch dimension {input_ids_t.shape[0]}"
+ )
+ if target_input_len is not None:
+ input_ids_t = _resize_multimodal_input_ids(
+ input_ids_t,
+ target_input_len=target_input_len,
+ image_token_id=image_token_id,
+ pad_token_id=pad_token_id,
+ )
+
+ input_ids_t = input_ids_t.repeat(batch_size, 1).to(
+ device=device, dtype=torch.int32
+ )
+ pixel_values_t = torch.as_tensor(pixel_values)
+ pixel_values_t = pixel_values_t.repeat(
+ (batch_size,) + (1,) * (pixel_values_t.dim() - 1)
+ )
+ image_grid_thw_t = torch.as_tensor(image_grid_thw)
+ if image_grid_thw_t.dim() == 1:
+ image_grid_thw_t = image_grid_thw_t.unsqueeze(0)
+ image_grid_thw_t = image_grid_thw_t.repeat(batch_size, 1).to(
+ device=device, dtype=torch.int64
+ )
+
+ return MultimodalBenchInput(
+ input_ids=input_ids_t,
+ pixel_values=pixel_values_t.to(device=device),
+ image_grid_thw=image_grid_thw_t,
+ vit_prefill_tokens=int((input_ids_t == image_token_id).sum().item()),
+ )
+
+
def make_profile_trace_path(
*,
output_dir: Path,
@@ -282,7 +520,7 @@ def make_profile_trace_path(
filename = (
f"{safe_prefix}_{safe_run_name}_bs{setting.batch_size}"
f"_in{setting.input_len}_out{setting.output_len}_{stage}"
- f"{step_part}.trace.json"
+ f"{step_part}.trace.json.gz"
)
return output_dir / filename
@@ -298,6 +536,20 @@ def _safe_div(numerator: float, denominator: float) -> float:
return float(numerator / denominator)
+def _max_batch_size_for(runner: Any, input_len: int, output_len: int) -> int:
+ """SGLang-style capacity bound on the static batch.
+
+ Mirrors ``ModelRunner.max_batch_size`` in SGLang's bench_one_batch:
+ ``max_total_num_tokens // (input_len + output_len)``. Used to skip
+ settings the KV pool cannot hold instead of failing mid-run on alloc.
+ """
+ total = int(getattr(runner, "max_total_num_tokens", 0) or 0)
+ denom = int(input_len) + int(output_len)
+ if denom <= 0:
+ return 0
+ return total // denom
+
+
def _sync_device(device: str | torch.device) -> None:
torch_device = torch.device(device)
if torch_device.type == "cuda":
@@ -330,6 +582,83 @@ def _load_hf_config(cfg: GlobalConfig) -> None:
logger.info("Loaded model config: %s", cfg.model.hf_config.__class__.__name__)
+def _extract_image_token_id(hf_config: Any) -> int:
+ image_token_id = getattr(hf_config, "image_token_id", None)
+ if image_token_id is None:
+ raise ValueError("Model config does not define image_token_id")
+ return int(image_token_id)
+
+
+def _render_multimodal_prompt(
+ processor: Any,
+ *,
+ prompt: str,
+ image_path: Path,
+) -> str:
+ if not hasattr(processor, "apply_chat_template"):
+ return prompt
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": str(image_path)},
+ {"type": "text", "text": prompt},
+ ],
+ }
+ ]
+ try:
+ rendered = processor.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ except Exception as exc:
+ logger.warning("Processor chat template failed, using raw prompt: %s", exc)
+ return prompt
+ if isinstance(rendered, list):
+ if not rendered:
+ return prompt
+ return str(rendered[0])
+ return str(rendered)
+
+
+def _make_multimodal_processor_output(
+ *,
+ cfg: GlobalConfig,
+ prompt: str,
+ image_path: Path,
+) -> MultimodalProcessorBundle:
+ if cfg.server.tokenizer_path is None:
+ raise ValueError("--server.tokenizer_path or --server.model_path is required")
+ if not image_path.exists():
+ raise FileNotFoundError(f"Image path does not exist: {image_path}")
+
+ from PIL import Image
+ from transformers import AutoProcessor
+
+ processor = AutoProcessor.from_pretrained(
+ str(cfg.server.tokenizer_path),
+ trust_remote_code=cfg.server.trust_remote_code,
+ )
+ tokenizer = getattr(processor, "tokenizer", None)
+ pad_token_id = getattr(tokenizer, "pad_token_id", None)
+ if pad_token_id is None:
+ pad_token_id = getattr(tokenizer, "eos_token_id", None)
+ if pad_token_id is None:
+ pad_token_id = 0
+ image = Image.open(image_path).convert("RGB")
+ text = _render_multimodal_prompt(
+ processor,
+ prompt=prompt,
+ image_path=image_path,
+ )
+ return MultimodalProcessorBundle(
+ processor_output=processor(images=[image], text=[text], return_tensors="pt"),
+ pad_token_id=int(pad_token_id),
+ )
+
+
def _append_jsonl(path: Path, row: dict[str, Any]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("a", encoding="utf-8") as fp:
@@ -354,45 +683,79 @@ def _profiler_activities(args: BenchArgs) -> list[Any]:
return activities
-@contextmanager
-def _maybe_profile(
- *,
- args: BenchArgs,
- setting: BenchSetting,
- stage: str,
- step: Optional[int] = None,
-) -> Iterator[None]:
- if not _profile_stage_enabled(args, stage):
- with nullcontext():
- yield
- return
+def _resolve_profile_output_dir() -> Path:
+ output_dir = Path(os.environ.get("PYMLLM_TORCH_PROFILER_DIR", "/tmp"))
+ output_dir.mkdir(parents=True, exist_ok=True)
+ return output_dir
+
+
+def _start_profile(args: BenchArgs, trace_path: Path) -> Any:
+ """Start profiling and return a handle.
+
+ Mirrors SGLang's ``start_profile``: ``CUDA_PROFILER`` drives nsys via
+ ``cudaProfilerStart``; otherwise a torch profiler with ``with_stack=True``
+ is started so kernels can be mapped back to Python source. Returns
+ ``"cuda"`` for the nsys path, the profiler object for the torch path, or
+ ``None`` when no activity is available. ``trace_path`` is accepted for
+ symmetry with ``_stop_profile`` (the torch path saves it on stop).
+ """
+ if "CUDA_PROFILER" in args.profile_activities:
+ try:
+ torch.cuda.cudart().cudaProfilerStart()
+ logger.info("CUDA profiler started (nsys will begin capturing).")
+ except Exception as exc: # pragma: no cover - depends on nsys runtime
+ logger.warning("Failed to start CUDA profiler: %s", exc)
+ return "cuda"
activities = _profiler_activities(args)
if not activities:
- with nullcontext():
- yield
- return
+ return None
from torch.profiler import profile
- output_dir = Path(os.environ.get("PYMLLM_TORCH_PROFILER_DIR", "/tmp"))
- output_dir.mkdir(parents=True, exist_ok=True)
- trace_path = make_profile_trace_path(
- output_dir=output_dir,
- prefix=args.profile_filename_prefix,
- run_name=args.run_name,
- setting=setting,
- stage=stage,
- step=step,
- )
- with profile(
+ profiler = profile(
activities=activities,
+ with_stack=True,
record_shapes=args.profile_record_shapes,
- ) as profiler:
- yield
- profiler.step()
- profiler.export_chrome_trace(str(trace_path))
- logger.info("Wrote torch profiler trace: %s", trace_path)
+ )
+ profiler.start()
+ return profiler
+
+
+def _stop_profile(handle: Any, args: BenchArgs, trace_path: Path, stage: str) -> None:
+ """Stop profiling and, for the torch path, save the chrome trace.
+
+ Mirrors SGLang's ``stop_profile``, including printing the key_averages
+ table. The trace is written as ``.trace.json.gz`` (torch gzips when the
+ filename ends with ``.gz``).
+ """
+ if handle is None:
+ return
+ if handle == "cuda":
+ try:
+ torch.cuda.cudart().cudaProfilerStop()
+ logger.info("CUDA profiler stopped for %s (nsys dumps traces).", stage)
+ except Exception as exc: # pragma: no cover - depends on nsys runtime
+ logger.warning("Failed to stop CUDA profiler: %s", exc)
+ return
+
+ handle.stop()
+ trace_path.parent.mkdir(parents=True, exist_ok=True)
+ handle.export_chrome_trace(str(trace_path))
+ try:
+ sort_key = (
+ "self_cuda_time_total"
+ if torch.cuda.is_available()
+ else "self_cpu_time_total"
+ )
+ print(
+ handle.key_averages(
+ group_by_input_shape=args.profile_record_shapes
+ ).table(sort_by=sort_key)
+ )
+ except Exception as exc:
+ logger.warning("Failed to print profiler key_averages: %s", exc)
+ logger.info("Wrote torch profiler trace for %s: %s", stage, trace_path)
class PymllmBenchRunner:
@@ -420,7 +783,14 @@ def clear(self) -> None:
self.runner.req_to_token_pool.clear()
self.runner.token_to_kv_pool_allocator.clear()
- def extend(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, DecodeState]:
+ def extend(
+ self,
+ input_ids: torch.Tensor,
+ *,
+ pixel_values: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.Tensor] = None,
+ benchmark_vision_timing: bool = False,
+ ) -> ExtendResult:
if input_ids.dim() != 2:
raise ValueError("input_ids must have shape [batch_size, input_len]")
@@ -468,6 +838,12 @@ def extend(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, DecodeState]:
extend_prefix_lens=extend_prefix_lens,
out_cache_loc=out_cache_loc.to(torch.int64),
)
+ if pixel_values is not None:
+ forward_batch.pixel_values = pixel_values.to(device=self.device)
+ if image_grid_thw is not None:
+ forward_batch.image_grid_thw = image_grid_thw.to(device=self.device)
+ forward_batch.benchmark_vision_timing = benchmark_vision_timing
+
logits_output = self.runner.forward(forward_batch)
next_token_ids = self._sample_greedy(logits_output, forward_batch)
state = DecodeState(
@@ -477,7 +853,22 @@ def extend(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, DecodeState]:
forward_batch, "mrope_position_deltas", None
),
)
- return next_token_ids, state
+ vit_prefill_ms = getattr(forward_batch, "vit_prefill_ms", None)
+ vit_prefill_tokens = getattr(forward_batch, "vit_prefill_tokens", None)
+ vit_prefill_tps = getattr(forward_batch, "vit_prefill_tps", None)
+ return ExtendResult(
+ next_token_ids=next_token_ids,
+ state=state,
+ vit_prefill_ms=(
+ float(vit_prefill_ms) if vit_prefill_ms is not None else None
+ ),
+ vit_prefill_tokens=(
+ int(vit_prefill_tokens) if vit_prefill_tokens is not None else None
+ ),
+ vit_prefill_tps=(
+ float(vit_prefill_tps) if vit_prefill_tps is not None else None
+ ),
+ )
def decode(
self,
@@ -496,13 +887,18 @@ def decode(
raise RuntimeError(f"Failed to allocate {batch_size} decode KV slots")
seq_lens = state.seq_lens + 1
- for i in range(batch_size):
- slot = int(state.req_pool_indices[i].item())
- write_pos = int(seq_lens[i].item()) - 1
- self.runner.req_to_token_pool.write(
- (slot, slice(write_pos, write_pos + 1)),
- out_cache_loc[i : i + 1],
- )
+ # Tensorized KV-mapping write. The production decode path
+ # (orchestrator/model_runner_process.py) keeps slot/write_pos as plain
+ # CPU bookkeeping and never does a per-request CUDA ``.item()`` sync.
+ # Doing per-request ``.item()`` here would add 2*batch_size CPU-GPU
+ # syncs inside the timed decode region that SGLang does not have,
+ # biasing decode latency once batch_size > 1. Write all rows at once:
+ # req_to_token[req_pool_indices, seq_lens - 1] = out_cache_loc.
+ write_positions = (seq_lens - 1).to(torch.int64)
+ self.runner.req_to_token_pool.write(
+ (state.req_pool_indices, write_positions),
+ out_cache_loc.to(torch.int32),
+ )
forward_batch = self.runner.prepare_forward_batch_decode(
input_ids=input_ids.to(device=self.device, dtype=torch.int32),
@@ -532,6 +928,7 @@ def _sample_greedy(self, logits_output: Any, forward_batch: Any) -> torch.Tensor
logits_output,
forward_batch,
temperatures=temperatures,
+ is_all_greedy=True,
).to(torch.int32)
def _require_initialized(self) -> None:
@@ -561,46 +958,120 @@ def run_single_setting(
setting: BenchSetting,
seed: int,
record_result: bool,
+ multimodal_processor_bundle: Optional[MultimodalProcessorBundle] = None,
+ allow_profile: bool = True,
) -> Optional[dict[str, Any]]:
bench_runner.clear()
vocab_size = getattr(bench_runner.runner, "vocab_size", 10000)
- input_ids = make_synthetic_input_ids(
- batch_size=setting.batch_size,
- input_len=setting.input_len,
- vocab_size=vocab_size,
- seed=seed,
- device=bench_runner.device,
+ mm_input = None
+ effective_setting = setting
+ if multimodal_processor_bundle is None:
+ input_ids = make_synthetic_input_ids(
+ batch_size=setting.batch_size,
+ input_len=setting.input_len,
+ vocab_size=vocab_size,
+ seed=seed,
+ device=bench_runner.device,
+ )
+ else:
+ hf_config = bench_runner.runner.model_config.hf_config
+ image_token_id = _extract_image_token_id(hf_config)
+ mm_input = make_multimodal_bench_input_from_processor_output(
+ multimodal_processor_bundle.processor_output,
+ batch_size=setting.batch_size,
+ image_token_id=image_token_id,
+ device=bench_runner.device,
+ target_input_len=(
+ setting.input_len
+ if args.input_len_was_provided
+ else None
+ ),
+ pad_token_id=multimodal_processor_bundle.pad_token_id,
+ )
+ input_ids = mm_input.input_ids
+ effective_setting = BenchSetting(
+ batch_size=setting.batch_size,
+ input_len=int(input_ids.shape[1]),
+ output_len=setting.output_len,
+ )
+
+ max_bs = _max_batch_size_for(
+ bench_runner.runner,
+ effective_setting.input_len,
+ effective_setting.output_len,
)
+ if effective_setting.batch_size > max_bs:
+ logger.info(
+ "skipping (batch_size=%d, input_len=%d, output_len=%d): exceeds max "
+ "batch size %d (max_total_num_tokens=%d). SGLang-style skip.",
+ effective_setting.batch_size,
+ effective_setting.input_len,
+ effective_setting.output_len,
+ max_bs,
+ int(getattr(bench_runner.runner, "max_total_num_tokens", 0) or 0),
+ )
+ return None
- with _maybe_profile(args=args, setting=setting, stage="prefill"):
- prefill_latency, extend_result = _timed_call(
- bench_runner.device,
- lambda: bench_runner.extend(input_ids),
+ prefill_profile = allow_profile and _profile_stage_enabled(args, "prefill")
+ prefill_trace: Optional[Path] = None
+ prefill_handle: Any = None
+ if prefill_profile:
+ prefill_trace = make_profile_trace_path(
+ output_dir=_resolve_profile_output_dir(),
+ prefix=args.profile_filename_prefix,
+ run_name=args.run_name,
+ setting=effective_setting,
+ stage="prefill",
)
+ prefill_handle = _start_profile(args, prefill_trace)
+ prefill_latency, extend_result = _timed_call(
+ bench_runner.device,
+ lambda: bench_runner.extend(
+ input_ids,
+ pixel_values=mm_input.pixel_values if mm_input is not None else None,
+ image_grid_thw=(
+ mm_input.image_grid_thw if mm_input is not None else None
+ ),
+ benchmark_vision_timing=mm_input is not None,
+ ),
+ )
+ if prefill_profile:
+ _stop_profile(prefill_handle, args, prefill_trace, "prefill")
next_token_ids, state = extend_result
decode_latencies: list[float] = []
decode_steps = max(0, setting.output_len - 1)
+ decode_profile = allow_profile and _profile_stage_enabled(args, "decode")
profile_start_step = args.profile_start_step
if profile_start_step is None:
- profile_start_step = decode_steps // 2 if decode_steps else 0
+ # Align SGLang: default to output_len // 2.
+ profile_start_step = effective_setting.output_len // 2
profile_stop_step = profile_start_step + args.profile_steps
+ decode_trace: Optional[Path] = None
+ decode_handle: Any = None
+ # One continuous profiler spans [profile_start_step, profile_stop_step),
+ # producing a single decode trace, matching SGLang (not one file per step).
for step in range(decode_steps):
- should_profile_decode = (
- _profile_stage_enabled(args, "decode")
- and profile_start_step <= step < profile_stop_step
- )
- profile_context = (
- _maybe_profile(args=args, setting=setting, stage="decode", step=step)
- if should_profile_decode
- else nullcontext()
- )
- with profile_context:
- decode_latency, decode_result = _timed_call(
- bench_runner.device,
- lambda: bench_runner.decode(next_token_ids, state),
+ if decode_profile and step == profile_start_step:
+ decode_trace = make_profile_trace_path(
+ output_dir=_resolve_profile_output_dir(),
+ prefix=args.profile_filename_prefix,
+ run_name=args.run_name,
+ setting=effective_setting,
+ stage="decode",
)
+ decode_handle = _start_profile(args, decode_trace)
+
+ decode_latency, decode_result = _timed_call(
+ bench_runner.device,
+ lambda: bench_runner.decode(next_token_ids, state),
+ )
+
+ if decode_handle is not None and step >= profile_stop_step - 1:
+ _stop_profile(decode_handle, args, decode_trace, "decode")
+ decode_handle = None
+
next_token_ids, state = decode_result
decode_latencies.append(decode_latency)
@@ -612,20 +1083,70 @@ def run_single_setting(
decode_latency,
)
+ # Save if the requested profile window ran past the final decode step.
+ if decode_handle is not None:
+ _stop_profile(decode_handle, args, decode_trace, "decode")
+ decode_handle = None
+
if not record_result:
return None
+ extra_metrics = None
+ if mm_input is not None:
+ extra_metrics = make_multimodal_prefill_metrics(
+ prefill_latency=prefill_latency,
+ batch_size=effective_setting.batch_size,
+ input_len=effective_setting.input_len,
+ )
+ if (
+ extend_result.vit_prefill_ms is not None
+ and extend_result.vit_prefill_tokens is not None
+ ):
+ extra_metrics.update(
+ make_vit_prefill_metrics(
+ vit_prefill_ms=extend_result.vit_prefill_ms,
+ vit_prefill_tokens=extend_result.vit_prefill_tokens,
+ )
+ )
+
return summarize_latencies(
- setting=setting,
+ setting=effective_setting,
prefill_latency=prefill_latency,
decode_latencies=decode_latencies,
run_name=args.run_name,
device=bench_runner.device,
dtype=str(bench_runner.runner.dtype),
cuda_graph=bench_runner.runner.graph_runner is not None,
+ extra=extra_metrics,
)
+def _align_runner_capacity_with_batch_sizes(
+ cfg: GlobalConfig, batch_sizes: Sequence[int]
+) -> None:
+ """Ensure the runner can hold and CUDA-graph-capture the largest batch.
+
+ Mirrors SGLang ``main()`` which sets
+ ``server_args.cuda_graph_max_bs = max(bench_args.batch_size)``. In pymllm
+ the CUDA graph capture batch sizes are derived from
+ ``ModelRunner.max_running_requests`` (see ``CudaGraphRunner``), which also
+ sizes ``req_to_token_pool``. Without this, sweeping a batch size larger
+ than the configured capture set makes decode silently fall off the graph
+ path and run eager, biasing decode latency versus SGLang.
+ """
+ if not batch_sizes:
+ return
+ requested = max(batch_sizes)
+ configured = cfg.server.max_running_requests
+ if configured is None or configured < requested:
+ cfg.server.max_running_requests = requested
+ logger.info(
+ "Raised max_running_requests to %d to cover bench batch sizes "
+ "(SGLang cuda_graph_max_bs alignment).",
+ requested,
+ )
+
+
def run_benchmark(cfg: GlobalConfig, args: BenchArgs) -> list[dict[str, Any]]:
_load_hf_config(cfg)
logger.info(
@@ -633,9 +1154,17 @@ def run_benchmark(cfg: GlobalConfig, args: BenchArgs) -> list[dict[str, Any]]:
"do not chunk this benchmark."
)
+ _align_runner_capacity_with_batch_sizes(cfg, args.batch_size)
bench_runner = PymllmBenchRunner.create(cfg)
try:
settings = generate_settings(args)
+ multimodal_processor_bundle = None
+ if args.image_path is not None:
+ multimodal_processor_bundle = _make_multimodal_processor_output(
+ cfg=cfg,
+ prompt=args.prompt,
+ image_path=args.image_path,
+ )
if not args.skip_warmup and settings:
first = settings[0]
warmup_setting = BenchSetting(
@@ -655,6 +1184,8 @@ def run_benchmark(cfg: GlobalConfig, args: BenchArgs) -> list[dict[str, Any]]:
setting=warmup_setting,
seed=args.seed,
record_result=False,
+ multimodal_processor_bundle=multimodal_processor_bundle,
+ allow_profile=False,
)
results: list[dict[str, Any]] = []
@@ -671,8 +1202,11 @@ def run_benchmark(cfg: GlobalConfig, args: BenchArgs) -> list[dict[str, Any]]:
setting=setting,
seed=args.seed + index,
record_result=True,
+ multimodal_processor_bundle=multimodal_processor_bundle,
)
- assert result is not None
+ if result is None:
+ # Setting skipped (e.g. exceeds KV pool capacity); do not record.
+ continue
_append_jsonl(args.result_filename, result)
logger.info("Result: %s", json.dumps(result, sort_keys=True))
results.append(result)
@@ -681,10 +1215,81 @@ def run_benchmark(cfg: GlobalConfig, args: BenchArgs) -> list[dict[str, Any]]:
bench_runner.shutdown()
+DEFAULT_CORRECTNESS_PROMPTS = (
+ "The capital of France is",
+ "The capital of the United Kingdom is",
+ "Today is a sunny day and I like",
+)
+
+
+def _load_tokenizer(cfg: GlobalConfig) -> Any:
+ if cfg.server.tokenizer_path is None:
+ raise ValueError("--server.tokenizer_path or --server.model_path is required")
+
+ from transformers import AutoTokenizer
+
+ return AutoTokenizer.from_pretrained(
+ str(cfg.server.tokenizer_path),
+ trust_remote_code=cfg.server.trust_remote_code,
+ )
+
+
+def correctness_test(
+ bench_runner: PymllmBenchRunner,
+ cfg: GlobalConfig,
+ args: BenchArgs,
+) -> None:
+ """Single-stage smoke correctness check.
+
+ Encode a real prompt, run one full prefill at batch_size=1, greedy-decode
+ ``output_len`` tokens, and print the decoded text. Unlike SGLang's
+ ``--correct`` (which exercises a cut_len two-stage prefill to test prefix-KV
+ reuse), this runs each prompt as a single full prefill. Greedy decoding
+ makes the per-prompt output identical to SGLang's batched path. The cut_len
+ two-stage variant can be layered on later: ``prepare_forward_batch_extend``
+ already accepts ``extend_prefix_lens > 0`` and ``req_to_token_pool.write``
+ can pre-populate prefix KV indices.
+ """
+ tokenizer = _load_tokenizer(cfg)
+ output_len = args.output_len[0]
+ prompts = list(DEFAULT_CORRECTNESS_PROMPTS)
+
+ for idx, prompt in enumerate(prompts):
+ token_ids = list(tokenizer.encode(prompt))
+ if not token_ids:
+ logger.warning("Prompt %d encoded to an empty token list, skipping.", idx)
+ continue
+ input_ids = torch.tensor(
+ [token_ids], dtype=torch.int32, device=bench_runner.device
+ )
+
+ bench_runner.clear()
+ next_token_ids, state = bench_runner.extend(input_ids)
+ output_ids = token_ids + [int(next_token_ids[0].item())]
+ for _ in range(max(0, output_len - 1)):
+ next_token_ids, state = bench_runner.decode(next_token_ids, state)
+ output_ids.append(int(next_token_ids[0].item()))
+
+ print(f"========== Prompt {idx} ==========")
+ print(tokenizer.decode(output_ids), "\n")
+
+
+def run_correctness(cfg: GlobalConfig, args: BenchArgs) -> None:
+ _load_hf_config(cfg)
+ bench_runner = PymllmBenchRunner.create(cfg)
+ try:
+ correctness_test(bench_runner, cfg, args)
+ finally:
+ bench_runner.shutdown()
+
+
def main(argv: Optional[Sequence[str]] = None) -> None:
cfg, args = parse_args(argv)
_configure_logging(cfg.server.log_level)
- run_benchmark(cfg, args)
+ if args.correctness_test:
+ run_correctness(cfg, args)
+ else:
+ run_benchmark(cfg, args)
if __name__ == "__main__":
diff --git a/pymllm/executor/model_runner.py b/pymllm/executor/model_runner.py
index a50baa13e..7e10bf818 100644
--- a/pymllm/executor/model_runner.py
+++ b/pymllm/executor/model_runner.py
@@ -37,6 +37,7 @@
import gc
import logging
+import os
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
@@ -79,6 +80,9 @@ def get_available_gpu_memory(device: str = "cuda", gpu_id: int = 0) -> float:
if device != "cuda" or not torch.cuda.is_available():
return 0.0
torch.cuda.set_device(gpu_id)
+ props = torch.cuda.get_device_properties(gpu_id)
+ if getattr(props, "is_integrated", False):
+ return _get_system_available_memory_gb()
free, _ = torch.cuda.mem_get_info(gpu_id)
return free / (1 << 30)
@@ -92,6 +96,26 @@ def get_total_gpu_memory(device: str = "cuda", gpu_id: int = 0) -> float:
return total / (1 << 30)
+def _get_system_available_memory_gb() -> float:
+ try:
+ import psutil
+
+ return psutil.virtual_memory().available / (1 << 30)
+ except Exception:
+ pass
+
+ if hasattr(os, "sysconf"):
+ try:
+ page_size = os.sysconf("SC_PAGE_SIZE")
+ avail_pages = os.sysconf("SC_AVPHYS_PAGES")
+ return (page_size * avail_pages) / (1 << 30)
+ except (OSError, ValueError):
+ pass
+
+ free, _ = torch.cuda.mem_get_info()
+ return free / (1 << 30)
+
+
# ---------------------------------------------------------------------------
# LogitsProcessorOutput
# ---------------------------------------------------------------------------
@@ -115,6 +139,19 @@ class LogitsProcessorOutput:
hidden_states: Optional[torch.Tensor] = None
+@dataclass
+class MemoryProfileResult:
+ pre_model_available_gb: float
+ available_gb: float
+ mem_fraction: float
+ static_kv_budget_gb: float
+ cell_size_bytes: int
+ profiled_max_tokens: int
+ requested_max_total_tokens: Optional[int]
+ effective_max_tokens: int
+ gdn_pool_gb: float = 0.0
+
+
# ---------------------------------------------------------------------------
# Penalty helpers
# ---------------------------------------------------------------------------
@@ -274,6 +311,12 @@ def __init__(
# Forward pass counter (monotonically increasing).
self.forward_pass_id: int = 0
+ # GPU memory available before model weights are loaded. This is used to
+ # match SGLang's mem_fraction_static semantics: static memory includes
+ # both model weights and the KV cache pool.
+ self._pre_model_load_available_gb: float = 0.0
+ self._last_memory_profile: Optional[MemoryProfileResult] = None
+
# ------------------------------------------------------------------
# Initialisation
# ------------------------------------------------------------------
@@ -301,6 +344,11 @@ def initialize(self) -> None:
torch.set_default_dtype(self.dtype)
# Load the model
+ if self.device == "cuda":
+ self._pre_model_load_available_gb = get_available_gpu_memory(
+ self.device,
+ self.gpu_id,
+ )
self.load_model()
# Extract model metadata from hf_config
@@ -703,10 +751,7 @@ def init_memory_pool(self) -> None:
self.max_running_requests = max_reqs
if self.max_total_num_tokens <= 0:
- raise RuntimeError(
- "Not enough memory for KV cache. "
- "Try reducing context_length or using a smaller model."
- )
+ raise RuntimeError(self._format_kv_cache_memory_error())
# Create ReqToTokenPool
self.req_to_token_pool = make_req_to_token_pool(
@@ -774,24 +819,26 @@ def init_memory_pool(self) -> None:
def _profile_max_num_tokens(self) -> int:
"""Profile available memory to determine maximum KV-cache tokens.
- If ``server_config.max_total_tokens`` is explicitly set that value
- is used directly. Otherwise a memory-fraction-based heuristic
- similar to sglang's ``profile_max_num_token`` is applied.
+ ``mem_fraction_static`` follows SGLang's semantics: it is the fraction
+ of total static memory budget used by model weights plus KV cache pool.
+ ``server_config.max_total_tokens`` is an upper bound on the profiled
+ capacity, not a replacement for profiling.
"""
- # If user explicitly set max_total_tokens, use that.
- if self.server_config.max_total_tokens is not None:
- return self.server_config.max_total_tokens
-
if self.device != "cuda":
# For CPU, use a conservative default.
+ if self.server_config.max_total_tokens is not None:
+ return self.server_config.max_total_tokens
return 4096
available_gb = get_available_gpu_memory(self.device, self.gpu_id)
+ pre_model_available_gb = getattr(self, "_pre_model_load_available_gb", 0.0)
+ if pre_model_available_gb <= 0:
+ pre_model_available_gb = available_gb
- # Determine memory fraction for static allocation (KV cache).
+ # Determine memory fraction for static allocation (model weights + KV cache).
mem_fraction = self.server_config.mem_fraction_static
if mem_fraction is None:
- mem_fraction = 0.85 # default: use 85% of remaining memory
+ mem_fraction = 0.85
# Calculate per-token KV cache size in bytes.
kv_element_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
@@ -809,7 +856,9 @@ def _profile_max_num_tokens(self) -> int:
)
return 4096
- rest_memory_bytes = int(available_gb * mem_fraction * (1 << 30))
+ rest_memory_gb = available_gb - pre_model_available_gb * (1 - mem_fraction)
+ rest_memory_bytes = int(rest_memory_gb * (1 << 30))
+ gdn_pool_gb = 0.0
# Reserve memory for GDN pool if hybrid model
if self.num_gdn_layers > 0:
@@ -856,23 +905,106 @@ def _profile_max_num_tokens(self) -> int:
)
gdn_pool_bytes = recurrent_bytes + conv_bytes
rest_memory_bytes -= gdn_pool_bytes
+ gdn_pool_gb = gdn_pool_bytes / (1 << 30)
logger.info(
"GDN pool memory reservation: %.2f GB",
- gdn_pool_bytes / (1 << 30),
+ gdn_pool_gb,
)
- max_num_tokens = rest_memory_bytes // cell_size
+ profiled_max_tokens = max(rest_memory_bytes // cell_size, 0)
+ max_num_tokens = profiled_max_tokens
+
+ if self.server_config.max_total_tokens is not None:
+ if self.server_config.max_total_tokens > max_num_tokens:
+ logger.warning(
+ "max_total_tokens=%d is larger than the profiled value %d. "
+ "Use the profiled value instead.",
+ self.server_config.max_total_tokens,
+ max_num_tokens,
+ )
+ max_num_tokens = min(max_num_tokens, self.server_config.max_total_tokens)
+
+ self._last_memory_profile = MemoryProfileResult(
+ pre_model_available_gb=pre_model_available_gb,
+ available_gb=available_gb,
+ mem_fraction=mem_fraction,
+ static_kv_budget_gb=rest_memory_gb,
+ cell_size_bytes=cell_size,
+ profiled_max_tokens=profiled_max_tokens,
+ requested_max_total_tokens=self.server_config.max_total_tokens,
+ effective_max_tokens=max_num_tokens,
+ gdn_pool_gb=gdn_pool_gb,
+ )
logger.info(
- "Memory profiling: avail=%.2f GB, fraction=%.2f, "
- "cell_size=%d bytes, max_tokens=%d",
+ "Memory profiling: pre_model_avail=%.2f GB, avail=%.2f GB, "
+ "fraction=%.2f, static_kv_budget=%.2f GB, cell_size=%d bytes, "
+ "max_tokens=%d",
+ pre_model_available_gb,
available_gb,
mem_fraction,
+ rest_memory_gb,
cell_size,
max_num_tokens,
)
- return max(max_num_tokens, 1) # at least 1
+ return max_num_tokens
+
+ def _format_kv_cache_memory_error(self) -> str:
+ profile = getattr(self, "_last_memory_profile", None)
+ if profile is None:
+ return (
+ "Not enough memory for KV cache. Try increasing "
+ "--server.mem_fraction_static, reducing --server.max_total_tokens, "
+ "lowering --server.max_running_requests, or using a "
+ "smaller/quantized model."
+ )
+
+ requested = (
+ "unset"
+ if profile.requested_max_total_tokens is None
+ else str(profile.requested_max_total_tokens)
+ )
+ message = [
+ "Not enough memory for KV cache.",
+ (
+ "Memory profile: "
+ f"pre_model_avail={profile.pre_model_available_gb:.2f} GB, "
+ f"avail_after_model={profile.available_gb:.2f} GB, "
+ f"mem_fraction_static={profile.mem_fraction:.2f}, "
+ f"static_kv_budget={profile.static_kv_budget_gb:.2f} GB, "
+ f"cell_size={profile.cell_size_bytes} bytes, "
+ f"profiled max_tokens={profile.profiled_max_tokens}, "
+ f"requested max_total_tokens={requested}."
+ ),
+ ]
+ if profile.static_kv_budget_gb <= 0:
+ message.append(
+ "The static KV budget is non-positive, so model weights and "
+ "other static allocations already exceed the requested static "
+ "memory fraction."
+ )
+ if profile.gdn_pool_gb > 0:
+ message.append(f"GDN pool reservation={profile.gdn_pool_gb:.2f} GB.")
+ if (
+ profile.requested_max_total_tokens is not None
+ and profile.profiled_max_tokens < profile.requested_max_total_tokens
+ ):
+ message.append(
+ "The requested max_total_tokens is above the profiled KV "
+ "capacity for this launch."
+ )
+ message.append(
+ "Try increasing --server.mem_fraction_static, reducing "
+ "--server.max_total_tokens, lowering --server.max_running_requests, "
+ "or using a smaller/quantized model."
+ )
+ message.append(
+ "Note: server mode may need a higher mem_fraction_static than "
+ "bench_one_batch because tokenizer/scheduler/detokenizer/HTTP "
+ "processes and IPC buffers consume additional Jetson unified memory."
+ )
+ return " ".join(message)
# ------------------------------------------------------------------
# Attention backend
@@ -1287,6 +1419,7 @@ def sample(
top_ps: Optional[torch.Tensor] = None,
top_ks: Optional[torch.Tensor] = None,
penalty_params: Optional[Dict[str, Any]] = None,
+ is_all_greedy: Optional[bool] = None,
) -> torch.Tensor:
"""Sample next-token IDs from logits.
@@ -1310,6 +1443,10 @@ def sample(
``frequency_penalties``, ``presence_penalties`` (tensors of
shape ``[batch_size]``), and ``token_histories`` (list of
list of int).
+ is_all_greedy
+ CPU-side metadata indicating that every request should use greedy
+ sampling. Supplying this avoids a CUDA tensor reduction and
+ synchronization in the decode hot path.
Returns
-------
@@ -1338,12 +1475,14 @@ def sample(
)
# Greedy path: temperature=0 (or all zeros) → argmax, no sampling.
- if temperatures is not None:
- all_greedy = bool((temperatures < 1e-6).all())
- else:
- all_greedy = False
+ if is_all_greedy is None:
+ is_all_greedy = (
+ bool((temperatures < 1e-6).all())
+ if temperatures is not None
+ else False
+ )
- if all_greedy:
+ if is_all_greedy:
return logits.argmax(dim=-1).to(torch.int32)
# Stochastic path: apply temperature then sample.
diff --git a/pymllm/layers/__init__.py b/pymllm/layers/__init__.py
index d328ca7ef..18e5a1dcf 100644
--- a/pymllm/layers/__init__.py
+++ b/pymllm/layers/__init__.py
@@ -17,6 +17,7 @@
apply_llama31_rope,
apply_llama31_rope_pos_ids,
apply_mrope,
+ apply_mrope_fused_,
apply_rope,
apply_rope_pos_ids,
apply_rope_with_cos_sin_cache,
@@ -51,6 +52,7 @@
"RMSNorm",
"GemmaRMSNorm",
"apply_mrope",
+ "apply_mrope_fused_",
"apply_rope",
"apply_llama31_rope",
"apply_rope_pos_ids",
diff --git a/pymllm/layers/rms_norm.py b/pymllm/layers/rms_norm.py
index e9a4c6ed0..21163b4c2 100644
--- a/pymllm/layers/rms_norm.py
+++ b/pymllm/layers/rms_norm.py
@@ -10,6 +10,57 @@
from pymllm.layers.utils import set_weight_attrs
+_PATCHED_CUDA_DEVICE_PROPERTIES = False
+
+
+class _CudaDevicePropertiesProxy:
+ def __init__(self, props):
+ self._props = props
+
+ def __getattr__(self, name: str):
+ if name == "shared_memory_per_block_optin":
+ return _infer_shared_memory_per_block_optin(self._props)
+ return getattr(self._props, name)
+
+
+def _infer_shared_memory_per_block_optin(props) -> int:
+ """Infer opt-in shared memory for older PyTorch device properties."""
+ if hasattr(props, "shared_memory_per_block_optin"):
+ return int(props.shared_memory_per_block_optin)
+ if hasattr(props, "shared_memory_per_multiprocessor"):
+ return int(props.shared_memory_per_multiprocessor)
+ return int(getattr(props, "shared_memory_per_block", 0))
+
+
+def _patch_cuda_device_properties_for_flashinfer_norm() -> None:
+ """Provide a missing PyTorch device property required by FlashInfer norm.
+
+ Some Jetson PyTorch builds expose neither ``shared_memory_per_block_optin``
+ nor ``shared_memory_per_multiprocessor`` on ``torch.cuda.DeviceProperties``.
+ FlashInfer norm kernels query that attribute while choosing their CUTE
+ kernel config. Wrap the properties object so FlashInfer can still choose
+ a valid shared-memory limit instead of falling back to slow PyTorch RMSNorm.
+ """
+ global _PATCHED_CUDA_DEVICE_PROPERTIES
+ if _PATCHED_CUDA_DEVICE_PROPERTIES or not torch.cuda.is_available():
+ return
+
+ original_get_device_properties = torch.cuda.get_device_properties
+ props = original_get_device_properties(0)
+ if hasattr(props, "shared_memory_per_block_optin"):
+ _PATCHED_CUDA_DEVICE_PROPERTIES = True
+ return
+
+ def patched_get_device_properties(*args, **kwargs):
+ props = original_get_device_properties(*args, **kwargs)
+ if hasattr(props, "shared_memory_per_block_optin"):
+ return props
+ return _CudaDevicePropertiesProxy(props)
+
+ torch.cuda.get_device_properties = patched_get_device_properties
+ _PATCHED_CUDA_DEVICE_PROPERTIES = True
+
+
def _torch_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -45,6 +96,7 @@ def forward(
if residual is not None:
try:
+ _patch_cuda_device_properties_for_flashinfer_norm()
flashinfer.norm.fused_add_rmsnorm(
x, residual, self.weight.data, self.eps
)
@@ -54,6 +106,7 @@ def forward(
return _torch_rmsnorm(residual, self.weight, self.eps), residual
try:
+ _patch_cuda_device_properties_for_flashinfer_norm()
# FlashInfer rmsnorm accepts 2D/3D input; flatten higher-rank tensors to 2D.
if x.dim() in (2, 3):
return flashinfer.norm.rmsnorm(x, self.weight, self.eps)
@@ -83,6 +136,7 @@ def forward(
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
+ _patch_cuda_device_properties_for_flashinfer_norm()
flashinfer.norm.gemma_fused_add_rmsnorm(
x, residual, self.weight.data, self.eps
)
@@ -95,6 +149,7 @@ def forward(
)
# gemma_rmsnorm is defined on 2D input; flatten other ranks to 2D.
+ _patch_cuda_device_properties_for_flashinfer_norm()
if x.dim() == 2:
return flashinfer.norm.gemma_rmsnorm(x, self.weight, self.eps)
diff --git a/pymllm/layers/rope.py b/pymllm/layers/rope.py
index 94f89b20d..17fdb39ee 100644
--- a/pymllm/layers/rope.py
+++ b/pymllm/layers/rope.py
@@ -5,6 +5,13 @@
import torch
import flashinfer
+try:
+ import triton
+ import triton.language as tl
+except Exception: # pragma: no cover - Triton is optional outside CUDA builds.
+ triton = None
+ tl = None
+
def apply_rope(
q: torch.Tensor,
@@ -399,3 +406,203 @@ def apply_mrope(
else k_rot
)
return q_out, k_out
+
+
+if triton is not None:
+
+ @triton.jit
+ def _triton_mrope_forward_fused(
+ q_ptr,
+ k_ptr,
+ cos_sin_cache_ptr,
+ positions_ptr,
+ q_stride,
+ k_stride,
+ positions_stride,
+ n_qh: tl.constexpr,
+ n_kh: tl.constexpr,
+ hd: tl.constexpr,
+ rd: tl.constexpr,
+ pad_n_qh: tl.constexpr,
+ pad_n_kh: tl.constexpr,
+ pad_hd: tl.constexpr,
+ mrope_section_t: tl.constexpr,
+ mrope_section_h: tl.constexpr,
+ mrope_section_w: tl.constexpr,
+ is_interleaved: tl.constexpr,
+ is_neox_style: tl.constexpr,
+ ):
+ pid = tl.program_id(0)
+ q_ptr = q_ptr + pid * q_stride
+ k_ptr = k_ptr + pid * k_stride
+ half_rd = rd // 2
+
+ t = tl.load(positions_ptr + pid)
+ h = tl.load(positions_ptr + positions_stride + pid)
+ w = tl.load(positions_ptr + 2 * positions_stride + pid)
+
+ t_cos = cos_sin_cache_ptr + t * rd
+ h_cos = cos_sin_cache_ptr + h * rd
+ w_cos = cos_sin_cache_ptr + w * rd
+ t_sin = t_cos + half_rd
+ h_sin = h_cos + half_rd
+ w_sin = w_cos + half_rd
+
+ cos_offsets = tl.arange(0, pad_hd // 2)
+ if is_interleaved:
+ h_mask = ((cos_offsets % 3) == 1) & (
+ cos_offsets <= 3 * mrope_section_h
+ )
+ w_mask = ((cos_offsets % 3) == 2) & (
+ cos_offsets <= 3 * mrope_section_w
+ )
+ t_mask = ~(h_mask | w_mask)
+ else:
+ t_end = mrope_section_t
+ h_end = t_end + mrope_section_h
+ t_mask = cos_offsets < mrope_section_t
+ h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
+ w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
+
+ t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0.0)
+ t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0.0)
+ h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0.0)
+ h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0.0)
+ w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0.0)
+ w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0.0)
+ cos_row = t_cos_row + h_cos_row + w_cos_row
+ sin_row = t_sin_row + h_sin_row + w_sin_row
+
+ if is_neox_style:
+ qh = tl.arange(0, pad_n_qh)[:, None]
+ kh = tl.arange(0, pad_n_kh)[:, None]
+ dim = tl.arange(0, pad_hd // 2)[None, :]
+ q_mask = (qh < n_qh) & (dim < half_rd)
+ k_mask = (kh < n_kh) & (dim < half_rd)
+ q_first = qh * hd + dim
+ k_first = kh * hd + dim
+ q_second = q_first + half_rd
+ k_second = k_first + half_rd
+
+ q1 = tl.load(q_ptr + q_first, mask=q_mask, other=0.0).to(cos_row.dtype)
+ q2 = tl.load(q_ptr + q_second, mask=q_mask, other=0.0).to(cos_row.dtype)
+ k1 = tl.load(k_ptr + k_first, mask=k_mask, other=0.0).to(cos_row.dtype)
+ k2 = tl.load(k_ptr + k_second, mask=k_mask, other=0.0).to(cos_row.dtype)
+
+ tl.store(q_ptr + q_first, q1 * cos_row - q2 * sin_row, mask=q_mask)
+ tl.store(q_ptr + q_second, q2 * cos_row + q1 * sin_row, mask=q_mask)
+ tl.store(k_ptr + k_first, k1 * cos_row - k2 * sin_row, mask=k_mask)
+ tl.store(k_ptr + k_second, k2 * cos_row + k1 * sin_row, mask=k_mask)
+ else:
+ qh = tl.arange(0, pad_n_qh)[:, None]
+ kh = tl.arange(0, pad_n_kh)[:, None]
+ pair = tl.arange(0, pad_hd // 2)[None, :]
+ q_mask = (qh < n_qh) & (pair < half_rd)
+ k_mask = (kh < n_kh) & (pair < half_rd)
+ even = 2 * pair
+ odd = even + 1
+
+ q_even = tl.load(q_ptr + qh * hd + even, mask=q_mask, other=0.0).to(
+ cos_row.dtype
+ )
+ q_odd = tl.load(q_ptr + qh * hd + odd, mask=q_mask, other=0.0).to(
+ cos_row.dtype
+ )
+ k_even = tl.load(k_ptr + kh * hd + even, mask=k_mask, other=0.0).to(
+ cos_row.dtype
+ )
+ k_odd = tl.load(k_ptr + kh * hd + odd, mask=k_mask, other=0.0).to(
+ cos_row.dtype
+ )
+
+ tl.store(q_ptr + qh * hd + even, q_even * cos_row - q_odd * sin_row, mask=q_mask)
+ tl.store(q_ptr + qh * hd + odd, q_odd * cos_row + q_even * sin_row, mask=q_mask)
+ tl.store(k_ptr + kh * hd + even, k_even * cos_row - k_odd * sin_row, mask=k_mask)
+ tl.store(k_ptr + kh * hd + odd, k_odd * cos_row + k_even * sin_row, mask=k_mask)
+
+else:
+ _triton_mrope_forward_fused = None
+
+
+def _can_use_mrope_fused(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ positions: torch.Tensor,
+ cos_sin_cache: torch.Tensor,
+) -> bool:
+ return (
+ triton is not None
+ and q.is_cuda
+ and k.is_cuda
+ and positions.is_cuda
+ and cos_sin_cache.is_cuda
+ and q.dim() == 3
+ and k.dim() == 3
+ and positions.dim() == 2
+ and q.is_contiguous()
+ and k.is_contiguous()
+ and q.shape[-1] == k.shape[-1]
+ and cos_sin_cache.shape[-1] <= q.shape[-1]
+ )
+
+
+def apply_mrope_fused_(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ positions: torch.Tensor,
+ cos_sin_cache: torch.Tensor,
+ mrope_section: List[int],
+ mrope_interleaved: bool = True,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Apply M-RoPE in place with a Triton fused kernel when available.
+
+ The fallback returns the reference PyTorch outputs, preserving functional
+ correctness on CPU or non-contiguous inputs.
+ """
+ if positions.ndim == 1:
+ positions = positions.unsqueeze(0).expand(3, -1)
+ if positions.stride(-1) != 1:
+ positions = positions.contiguous()
+
+ if not _can_use_mrope_fused(q, k, positions, cos_sin_cache):
+ return apply_mrope(
+ q,
+ k,
+ positions,
+ cos_sin_cache,
+ mrope_section,
+ mrope_interleaved,
+ )
+
+ num_tokens, num_q_heads, head_size = q.shape
+ num_kv_heads = k.shape[1]
+ rotary_dim = cos_sin_cache.shape[-1]
+ q_2d = q.reshape(num_tokens, num_q_heads * head_size)
+ k_2d = k.reshape(num_tokens, num_kv_heads * head_size)
+
+ pad_n_qh = triton.next_power_of_2(num_q_heads)
+ pad_n_kh = triton.next_power_of_2(num_kv_heads)
+ pad_hd = triton.next_power_of_2(head_size)
+
+ _triton_mrope_forward_fused[(num_tokens,)](
+ q_2d,
+ k_2d,
+ cos_sin_cache,
+ positions,
+ q_2d.stride(0),
+ k_2d.stride(0),
+ positions.stride(0),
+ num_q_heads,
+ num_kv_heads,
+ head_size,
+ rotary_dim,
+ pad_n_qh,
+ pad_n_kh,
+ pad_hd,
+ int(mrope_section[0]),
+ int(mrope_section[1]),
+ int(mrope_section[2]),
+ bool(mrope_interleaved),
+ True,
+ )
+ return q, k
diff --git a/pymllm/models/qwen3_vl.py b/pymllm/models/qwen3_vl.py
index 2b945fc18..273a291a4 100644
--- a/pymllm/models/qwen3_vl.py
+++ b/pymllm/models/qwen3_vl.py
@@ -35,7 +35,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from pymllm.layers import RMSNorm, apply_mrope
+from pymllm.layers import RMSNorm, apply_mrope_fused_
from pymllm.layers.attention.radix_attention import RadixAttention
from pymllm.layers.linear import Linear, MergedLinear
from pymllm.layers.mlp import MLP
@@ -577,6 +577,21 @@ def _build_cos_sin_cache(
return torch.cat([torch.cos(freqs), torch.sin(freqs)], dim=-1).to(dtype)
+def _run_with_synchronized_wall_timing(
+ fn,
+ *,
+ device: torch.device,
+ enabled: bool,
+) -> Tuple[torch.Tensor, float]:
+ if enabled and device.type == "cuda":
+ torch.cuda.synchronize(device)
+ tic = time.perf_counter()
+ result = fn()
+ if enabled and device.type == "cuda":
+ torch.cuda.synchronize(device)
+ return result, (time.perf_counter() - tic) * 1000.0
+
+
def get_rope_index(
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor],
@@ -776,6 +791,21 @@ def _get_qm(suffix):
layer_id=layer_id,
)
+ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
+ """Keep the RoPE cache on the same device/dtype as the query.
+
+ The cache is built in FP32 for stability, but repeated ``to()``
+ conversions inside the hot path are expensive. Match SGLang by
+ materialising the converted cache once and reusing it afterwards.
+ """
+ if (
+ self.cos_sin_cache.device != query.device
+ or self.cos_sin_cache.dtype != query.dtype
+ ):
+ self.cos_sin_cache = self.cos_sin_cache.to(
+ device=query.device, dtype=query.dtype
+ )
+
def forward(
self,
positions: torch.Tensor,
@@ -798,11 +828,12 @@ def forward(
# as [T] for purely text-only batches; expand to [3, T] in that case.
if positions.ndim == 1:
positions = positions.unsqueeze(0).expand(3, -1)
- q, k = apply_mrope(
+ self._match_cos_sin_cache_dtype(q)
+ q, k = apply_mrope_fused_(
q,
k,
positions,
- self.cos_sin_cache.to(q.dtype),
+ self.cos_sin_cache,
self.mrope_section,
self.mrope_interleaved,
)
@@ -1183,6 +1214,7 @@ def forward(
input_deepstack_embeds = None
vit_prefill_ms = None
vit_prefill_tokens = None
+ vit_prefill_tps = None
llm_prefill_ms = None
llm_decode_ms = None
@@ -1193,11 +1225,14 @@ def forward(
and not forward_batch.forward_mode.is_decode()
):
# Run vision encoder
- _vit_t0 = time.perf_counter()
- vision_features = (
- self.visual(pixel_values, grid_thw=image_grid_thw)
+ benchmark_vision_timing = bool(
+ getattr(forward_batch, "benchmark_vision_timing", False)
+ )
+ vision_features, vit_prefill_ms = _run_with_synchronized_wall_timing(
+ lambda: self.visual(pixel_values, grid_thw=image_grid_thw),
+ device=pixel_values.device,
+ enabled=benchmark_vision_timing,
)
- vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0
# Separate main embeddings and deepstack embeddings
if self.num_deepstack_embeddings > 0:
@@ -1211,6 +1246,10 @@ def forward(
input_embeds = self.model.embed_tokens(input_ids)
image_mask = input_ids == self.image_token_id
vit_prefill_tokens = int(image_mask.sum().item())
+ if vit_prefill_ms > 0:
+ vit_prefill_tps = vit_prefill_tokens / (vit_prefill_ms / 1000.0)
+ else:
+ vit_prefill_tps = 0.0
if vit_prefill_tokens != int(vision_embeds.shape[0]):
raise ValueError(
"Image features and image tokens do not match, "
@@ -1249,6 +1288,7 @@ def forward(
llm_prefill_ms = _llm_ms
forward_batch.vit_prefill_ms = vit_prefill_ms
forward_batch.vit_prefill_tokens = vit_prefill_tokens
+ forward_batch.vit_prefill_tps = vit_prefill_tps
forward_batch.llm_prefill_ms = llm_prefill_ms
forward_batch.llm_decode_ms = None
else:
diff --git a/pymllm/orchestrator/model_runner_process.py b/pymllm/orchestrator/model_runner_process.py
index 383fa2da5..9a0e5cf92 100644
--- a/pymllm/orchestrator/model_runner_process.py
+++ b/pymllm/orchestrator/model_runner_process.py
@@ -272,6 +272,7 @@ def _forward_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
temps_tensor = torch.tensor(temperatures, dtype=torch.float32, device=device)
top_ps_tensor = torch.tensor(top_ps, dtype=torch.float32, device=device)
top_ks_tensor = torch.tensor(top_ks, dtype=torch.int32, device=device)
+ is_all_greedy = all(t < 1e-6 for t in temperatures)
# Collect token histories for penalty computation.
# Each entry is (input_ids + output_ids_so_far) for the request.
@@ -408,6 +409,7 @@ def _forward_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
penalty_params=penalty_params,
+ is_all_greedy=is_all_greedy,
)
# ==============================================================
diff --git a/pymllm/tests/test_bench_one_batch.py b/pymllm/tests/test_bench_one_batch.py
index cc2a87ae7..36e9afd71 100644
--- a/pymllm/tests/test_bench_one_batch.py
+++ b/pymllm/tests/test_bench_one_batch.py
@@ -2,14 +2,19 @@
import pytest
import torch
+from types import SimpleNamespace
from pymllm.configs.global_config import GlobalConfig
from pymllm.bench_one_batch import (
BenchArgs,
BenchSetting,
+ DecodeState,
+ PymllmBenchRunner,
generate_settings,
+ make_multimodal_bench_input_from_processor_output,
make_profile_trace_path,
make_synthetic_input_ids,
+ make_vit_prefill_metrics,
parse_args,
summarize_latencies,
)
@@ -53,6 +58,10 @@ def test_parse_args_accepts_server_config_and_list_bench_args(tmp_path):
"--profile-activities",
"CPU",
"GPU",
+ "--image",
+ str(tmp_path / "image.jpg"),
+ "--prompt",
+ "What is in this image?",
]
)
@@ -67,6 +76,26 @@ def test_parse_args_accepts_server_config_and_list_bench_args(tmp_path):
assert bench_args.result_filename == result_file
assert bench_args.profile_stage == "decode"
assert bench_args.profile_activities == ["CPU", "GPU"]
+ assert bench_args.image_path == tmp_path / "image.jpg"
+ assert bench_args.prompt == "What is in this image?"
+ assert bench_args.input_len_was_provided is True
+
+
+def test_parse_args_detects_default_input_len_when_omitted_in_image_mode(tmp_path):
+ model_dir = tmp_path / "model"
+ model_dir.mkdir()
+
+ _, bench_args = parse_args(
+ [
+ "--server.model_path",
+ str(model_dir),
+ "--image",
+ str(tmp_path / "image.jpg"),
+ ]
+ )
+
+ assert bench_args.input_len == [256, 512, 1024]
+ assert bench_args.input_len_was_provided is False
def test_generate_settings_has_stable_batch_input_output_order(tmp_path):
@@ -85,6 +114,39 @@ def test_generate_settings_has_stable_batch_input_output_order(tmp_path):
]
+def test_generate_settings_uses_processed_prompt_length_for_image_mode(tmp_path):
+ args = BenchArgs(
+ batch_size=[1, 2],
+ input_len=[256, 512],
+ output_len=[8],
+ result_filename=tmp_path / "out.jsonl",
+ image_path=tmp_path / "image.jpg",
+ )
+
+ assert generate_settings(args) == [
+ BenchSetting(batch_size=1, input_len=0, output_len=8),
+ BenchSetting(batch_size=2, input_len=0, output_len=8),
+ ]
+
+
+def test_generate_settings_sweeps_input_len_when_explicit_in_image_mode(tmp_path):
+ args = BenchArgs(
+ batch_size=[1, 2],
+ input_len=[512, 1024],
+ output_len=[8],
+ result_filename=tmp_path / "out.jsonl",
+ image_path=tmp_path / "image.jpg",
+ input_len_was_provided=True,
+ )
+
+ assert generate_settings(args) == [
+ BenchSetting(batch_size=1, input_len=512, output_len=8),
+ BenchSetting(batch_size=1, input_len=1024, output_len=8),
+ BenchSetting(batch_size=2, input_len=512, output_len=8),
+ BenchSetting(batch_size=2, input_len=1024, output_len=8),
+ ]
+
+
def test_make_synthetic_input_ids_is_seeded_int32_and_vocab_capped():
first = make_synthetic_input_ids(
batch_size=2,
@@ -136,6 +198,202 @@ def test_summarize_latencies_matches_sglang_style_metrics():
assert result["cuda_graph"] is True
+def test_make_vit_prefill_metrics_reports_seconds_and_tps():
+ result = make_vit_prefill_metrics(vit_prefill_ms=12.5, vit_prefill_tokens=25)
+
+ assert result == {
+ "vit_prefill_latency": pytest.approx(0.0125),
+ "vit_prefill_ms": pytest.approx(12.5),
+ "vit_prefill_tokens": 25,
+ "vit_prefill_throughput": pytest.approx(2000.0),
+ "vit_prefill_tps": pytest.approx(2000.0),
+ }
+
+
+def test_make_multimodal_prefill_metrics_aliases_full_prefill_latency():
+ from pymllm.bench_one_batch import make_multimodal_prefill_metrics
+
+ result = make_multimodal_prefill_metrics(
+ prefill_latency=0.25,
+ batch_size=2,
+ input_len=128,
+ )
+
+ assert result == {
+ "multimodal_prefill_latency": pytest.approx(0.25),
+ "multimodal_prefill_ms": pytest.approx(250.0),
+ "multimodal_prefill_tokens": 256,
+ "multimodal_prefill_throughput": pytest.approx(1024.0),
+ "multimodal_prefill_tps": pytest.approx(1024.0),
+ }
+
+
+def test_make_multimodal_bench_input_repeats_processor_output_per_batch():
+ bench_input = make_multimodal_bench_input_from_processor_output(
+ {
+ "input_ids": torch.tensor([[1, 5, 5, 2]], dtype=torch.int64),
+ "pixel_values": torch.arange(6, dtype=torch.float32).reshape(2, 3),
+ "image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.int64),
+ },
+ batch_size=3,
+ image_token_id=5,
+ device="cpu",
+ )
+
+ assert bench_input.input_ids.shape == (3, 4)
+ assert bench_input.input_ids.dtype == torch.int32
+ assert bench_input.vit_prefill_tokens == 6
+ torch.testing.assert_close(
+ bench_input.pixel_values,
+ torch.arange(6, dtype=torch.float32).reshape(2, 3).repeat(3, 1),
+ )
+ torch.testing.assert_close(
+ bench_input.image_grid_thw,
+ torch.tensor([[1, 2, 2], [1, 2, 2], [1, 2, 2]], dtype=torch.int64),
+ )
+
+
+def test_make_multimodal_bench_input_pads_text_tokens_to_target_len():
+ bench_input = make_multimodal_bench_input_from_processor_output(
+ {
+ "input_ids": torch.tensor([[101, 5, 5, 102]], dtype=torch.int64),
+ "pixel_values": torch.arange(6, dtype=torch.float32).reshape(2, 3),
+ "image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.int64),
+ },
+ batch_size=1,
+ image_token_id=5,
+ device="cpu",
+ target_input_len=7,
+ pad_token_id=0,
+ )
+
+ torch.testing.assert_close(
+ bench_input.input_ids,
+ torch.tensor([[101, 5, 5, 102, 0, 0, 0]], dtype=torch.int32),
+ )
+ assert bench_input.vit_prefill_tokens == 2
+
+
+def test_make_multimodal_bench_input_does_not_pad_with_image_token_id():
+ bench_input = make_multimodal_bench_input_from_processor_output(
+ {
+ "input_ids": torch.tensor([[101, 5, 5, 102]], dtype=torch.int64),
+ "pixel_values": torch.arange(6, dtype=torch.float32).reshape(2, 3),
+ "image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.int64),
+ },
+ batch_size=1,
+ image_token_id=5,
+ device="cpu",
+ target_input_len=6,
+ pad_token_id=5,
+ )
+
+ torch.testing.assert_close(
+ bench_input.input_ids,
+ torch.tensor([[101, 5, 5, 102, 0, 0]], dtype=torch.int32),
+ )
+ assert bench_input.vit_prefill_tokens == 2
+
+
+def test_make_multimodal_bench_input_truncates_text_tokens_not_image_tokens():
+ bench_input = make_multimodal_bench_input_from_processor_output(
+ {
+ "input_ids": torch.tensor([[101, 7, 5, 5, 102, 103]], dtype=torch.int64),
+ "pixel_values": torch.arange(6, dtype=torch.float32).reshape(2, 3),
+ "image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.int64),
+ },
+ batch_size=1,
+ image_token_id=5,
+ device="cpu",
+ target_input_len=4,
+ pad_token_id=0,
+ )
+
+ torch.testing.assert_close(
+ bench_input.input_ids,
+ torch.tensor([[101, 7, 5, 5]], dtype=torch.int32),
+ )
+ assert bench_input.vit_prefill_tokens == 2
+
+
+def test_make_multimodal_bench_input_rejects_target_shorter_than_image_tokens():
+ with pytest.raises(ValueError, match="target_input_len"):
+ make_multimodal_bench_input_from_processor_output(
+ {
+ "input_ids": torch.tensor([[5, 5, 5, 102]], dtype=torch.int64),
+ "pixel_values": torch.arange(9, dtype=torch.float32).reshape(3, 3),
+ "image_grid_thw": torch.tensor([[1, 3, 1]], dtype=torch.int64),
+ },
+ batch_size=1,
+ image_token_id=5,
+ device="cpu",
+ target_input_len=2,
+ pad_token_id=0,
+ )
+
+
+def test_extend_attaches_multimodal_inputs_and_returns_vit_metrics():
+ class _ReqPool:
+ def alloc(self, batch_size):
+ return list(range(batch_size))
+
+ def write(self, index, value):
+ del index, value
+
+ def clear(self):
+ pass
+
+ class _KvPool:
+ def alloc(self, count):
+ return torch.arange(count, dtype=torch.int64)
+
+ def clear(self):
+ pass
+
+ class _Runner:
+ def __init__(self):
+ self.device = "cpu"
+ self.dtype = torch.float32
+ self.req_to_token_pool = _ReqPool()
+ self.token_to_kv_pool_allocator = _KvPool()
+ self.gdn_pool = None
+ self.last_forward_batch = None
+
+ def prepare_forward_batch_extend(self, **kwargs):
+ return SimpleNamespace(batch_size=kwargs["req_pool_indices"].shape[0])
+
+ def forward(self, forward_batch):
+ self.last_forward_batch = forward_batch
+ forward_batch.vit_prefill_ms = 4.0
+ forward_batch.vit_prefill_tokens = 8
+ return object()
+
+ def sample(self, logits_output, forward_batch, **kwargs):
+ del logits_output, forward_batch, kwargs
+ return torch.tensor([7], dtype=torch.int32)
+
+ fake_runner = _Runner()
+ bench_runner = PymllmBenchRunner(fake_runner)
+ pixel_values = torch.ones((2, 3), dtype=torch.float32)
+ image_grid_thw = torch.tensor([[1, 2, 2]], dtype=torch.int64)
+
+ result = bench_runner.extend(
+ torch.tensor([[1, 5, 5, 2]], dtype=torch.int32),
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ benchmark_vision_timing=True,
+ )
+
+ assert result.vit_prefill_ms == pytest.approx(4.0)
+ assert result.vit_prefill_tokens == 8
+ torch.testing.assert_close(fake_runner.last_forward_batch.pixel_values, pixel_values)
+ torch.testing.assert_close(
+ fake_runner.last_forward_batch.image_grid_thw,
+ image_grid_thw,
+ )
+ assert fake_runner.last_forward_batch.benchmark_vision_timing is True
+
+
def test_make_profile_trace_path_is_deterministic_and_sanitized(tmp_path):
path = make_profile_trace_path(
output_dir=tmp_path,
@@ -146,4 +404,70 @@ def test_make_profile_trace_path_is_deterministic_and_sanitized(tmp_path):
)
assert path.parent == tmp_path
- assert path.name == "pymllm_profile_qwen3_vl_w8a8_bs1_in256_out8_decode.trace.json"
+ assert (
+ path.name
+ == "pymllm_profile_qwen3_vl_w8a8_bs1_in256_out8_decode.trace.json.gz"
+ )
+
+
+def test_decode_writes_batch_kv_mapping_with_tensor_indices():
+ from pymllm.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+
+ class _Runner:
+ def __init__(self):
+ self.device = "cpu"
+ self.dtype = torch.float32
+ self.req_to_token_pool = ReqToTokenPool(
+ max_reqs=6,
+ max_context_len=8,
+ device="cpu",
+ )
+ self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+ size=32,
+ device="cpu",
+ )
+ self.last_decode_kwargs = None
+
+ def prepare_forward_batch_decode(self, **kwargs):
+ self.last_decode_kwargs = kwargs
+ return SimpleNamespace(batch_size=kwargs["req_pool_indices"].shape[0])
+
+ def forward(self, forward_batch):
+ del forward_batch
+ return object()
+
+ def sample(self, logits_output, forward_batch, **kwargs):
+ del logits_output, forward_batch, kwargs
+ return torch.tensor([11, 12, 13], dtype=torch.int32)
+
+ fake_runner = _Runner()
+ bench_runner = PymllmBenchRunner(fake_runner)
+ state = DecodeState(
+ req_pool_indices=torch.tensor([2, 0, 4], dtype=torch.int64),
+ seq_lens=torch.tensor([4, 2, 6], dtype=torch.int32),
+ )
+
+ next_token_ids, next_state = bench_runner.decode(
+ torch.tensor([1, 2, 3], dtype=torch.int32),
+ state,
+ )
+
+ torch.testing.assert_close(
+ fake_runner.req_to_token_pool.req_to_token[
+ torch.tensor([2, 0, 4], dtype=torch.int64),
+ torch.tensor([4, 2, 6], dtype=torch.int64),
+ ],
+ torch.tensor([1, 2, 3], dtype=torch.int32),
+ )
+ torch.testing.assert_close(
+ next_state.seq_lens,
+ torch.tensor([5, 3, 7], dtype=torch.int32),
+ )
+ torch.testing.assert_close(
+ fake_runner.last_decode_kwargs["out_cache_loc"],
+ torch.tensor([1, 2, 3], dtype=torch.int64),
+ )
+ torch.testing.assert_close(
+ next_token_ids,
+ torch.tensor([11, 12, 13], dtype=torch.int32),
+ )
diff --git a/pymllm/tests/test_model_runner_memory_pool.py b/pymllm/tests/test_model_runner_memory_pool.py
new file mode 100644
index 000000000..72c36c1a1
--- /dev/null
+++ b/pymllm/tests/test_model_runner_memory_pool.py
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+
+from pymllm.executor import model_runner
+from pymllm.executor.model_runner import ModelRunner
+
+
+def _make_runner(
+ *,
+ mem_fraction_static: float,
+ max_total_tokens: int | None = None,
+ free_gb: float = 5.0,
+ total_gb: float = 10.0,
+) -> SimpleNamespace:
+ return SimpleNamespace(
+ server_config=SimpleNamespace(
+ mem_fraction_static=mem_fraction_static,
+ max_total_tokens=max_total_tokens,
+ ),
+ device="cuda",
+ gpu_id=0,
+ kv_cache_dtype=torch.float16,
+ num_kv_heads=1,
+ head_dim=1,
+ v_head_dim=1,
+ num_hidden_layers=1,
+ num_gdn_layers=0,
+ _pre_model_load_available_gb=total_gb,
+ _current_available_gb=free_gb,
+ )
+
+
+@pytest.fixture
+def _cuda_available(monkeypatch):
+ monkeypatch.setattr(model_runner.torch.cuda, "is_available", lambda: True)
+
+
+def test_profile_max_num_tokens_treats_mem_fraction_static_as_static_pool_fraction(
+ _cuda_available,
+ monkeypatch,
+):
+ runner = _make_runner(mem_fraction_static=0.4, free_gb=5.0, total_gb=10.0)
+ monkeypatch.setattr(
+ model_runner,
+ "get_available_gpu_memory",
+ lambda *args, **kwargs: runner._current_available_gb,
+ )
+
+ # SGLang-compatible formula:
+ # free_after_model - pre_model_free * (1 - mem_fraction_static)
+ # = 5 GiB - 10 GiB * 0.6 = -1 GiB, so the pool is not allocatable.
+ assert ModelRunner._profile_max_num_tokens(runner) <= 0
+
+
+def test_profile_max_num_tokens_clamps_negative_capacity_and_records_diagnostics(
+ _cuda_available,
+ monkeypatch,
+):
+ runner = _make_runner(
+ mem_fraction_static=0.4,
+ max_total_tokens=4096,
+ free_gb=5.0,
+ total_gb=10.0,
+ )
+ monkeypatch.setattr(
+ model_runner,
+ "get_available_gpu_memory",
+ lambda *args, **kwargs: runner._current_available_gb,
+ )
+
+ assert ModelRunner._profile_max_num_tokens(runner) == 0
+ profile = runner._last_memory_profile
+ assert profile.static_kv_budget_gb == pytest.approx(-1.0)
+ assert profile.profiled_max_tokens == 0
+ assert profile.requested_max_total_tokens == 4096
+
+
+def test_kv_cache_memory_error_includes_profile_details_and_server_hint():
+ runner = _make_runner(
+ mem_fraction_static=0.75,
+ max_total_tokens=4096,
+ free_gb=1.99,
+ total_gb=8.42,
+ )
+ runner._last_memory_profile = model_runner.MemoryProfileResult(
+ pre_model_available_gb=8.42,
+ available_gb=1.99,
+ mem_fraction=0.75,
+ static_kv_budget_gb=-0.11,
+ cell_size_bytes=114688,
+ profiled_max_tokens=0,
+ requested_max_total_tokens=4096,
+ effective_max_tokens=0,
+ gdn_pool_gb=0.0,
+ )
+
+ message = ModelRunner._format_kv_cache_memory_error(runner)
+
+ assert "profiled max_tokens=0" in message
+ assert "requested max_total_tokens=4096" in message
+ assert "static_kv_budget=-0.11 GB" in message
+ assert "mem_fraction_static=0.75" in message
+ assert "server mode may need a higher mem_fraction_static than bench_one_batch" in message
+ assert "--server.mem_fraction_static" in message
+ assert "--server.max_total_tokens" in message
+
+
+def test_profile_max_num_tokens_caps_user_max_total_tokens_to_profiled_capacity(
+ _cuda_available,
+ monkeypatch,
+):
+ runner = _make_runner(
+ mem_fraction_static=0.8,
+ max_total_tokens=8_000_000_000,
+ free_gb=5.0,
+ total_gb=10.0,
+ )
+ monkeypatch.setattr(
+ model_runner,
+ "get_available_gpu_memory",
+ lambda *args, **kwargs: runner._current_available_gb,
+ )
+
+ # Profiled capacity is 3 GiB / 4 bytes per token. A larger user limit should
+ # not bypass profiling or force an oversized KV pool allocation.
+ assert ModelRunner._profile_max_num_tokens(runner) == 805_306_368
+
+
+def test_profile_max_num_tokens_uses_user_limit_when_below_profiled_capacity(
+ _cuda_available,
+ monkeypatch,
+):
+ runner = _make_runner(
+ mem_fraction_static=0.8,
+ max_total_tokens=4096,
+ free_gb=5.0,
+ total_gb=10.0,
+ )
+ monkeypatch.setattr(
+ model_runner,
+ "get_available_gpu_memory",
+ lambda *args, **kwargs: runner._current_available_gb,
+ )
+
+ assert ModelRunner._profile_max_num_tokens(runner) == 4096
+
+
+def test_available_gpu_memory_uses_system_memory_for_integrated_gpu(monkeypatch):
+ monkeypatch.setattr(model_runner.torch.cuda, "is_available", lambda: True)
+ monkeypatch.setattr(model_runner.torch.cuda, "set_device", lambda gpu_id: None)
+ monkeypatch.setattr(
+ model_runner.torch.cuda,
+ "get_device_properties",
+ lambda gpu_id: SimpleNamespace(is_integrated=True),
+ )
+ monkeypatch.setattr(model_runner, "_get_system_available_memory_gb", lambda: 7.5)
+
+ def fail_mem_get_info(*args, **kwargs):
+ raise AssertionError("integrated GPU memory should use system available memory")
+
+ monkeypatch.setattr(model_runner.torch.cuda, "mem_get_info", fail_mem_get_info)
+
+ assert model_runner.get_available_gpu_memory("cuda", 0) == 7.5
diff --git a/pymllm/tests/test_model_runner_sampling.py b/pymllm/tests/test_model_runner_sampling.py
new file mode 100644
index 000000000..09d74cbe0
--- /dev/null
+++ b/pymllm/tests/test_model_runner_sampling.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import torch
+
+from pymllm.executor.model_runner import LogitsProcessorOutput, ModelRunner
+
+
+def test_sample_uses_cpu_greedy_flag_without_tensor_reduction(monkeypatch):
+ runner = SimpleNamespace(device="cpu")
+ logits_output = LogitsProcessorOutput(
+ next_token_logits=torch.tensor([[1.0, 3.0], [5.0, 2.0]])
+ )
+ forward_batch = SimpleNamespace(batch_size=2)
+
+ def fail_count_nonzero(*args, **kwargs):
+ raise AssertionError("greedy sampling should not reduce temperature tensors")
+
+ monkeypatch.setattr(torch, "count_nonzero", fail_count_nonzero)
+
+ next_token_ids = ModelRunner.sample(
+ runner,
+ logits_output,
+ forward_batch,
+ temperatures=torch.zeros((2,), dtype=torch.float32),
+ is_all_greedy=True,
+ )
+
+ assert next_token_ids.tolist() == [1, 0]
+ assert next_token_ids.dtype == torch.int32
diff --git a/pymllm/tests/test_mrope_triton.py b/pymllm/tests/test_mrope_triton.py
new file mode 100644
index 000000000..93f5e25c4
--- /dev/null
+++ b/pymllm/tests/test_mrope_triton.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+import pytest
+import torch
+
+from pymllm.layers.rope import apply_mrope, apply_mrope_fused_
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
+def test_apply_mrope_fused_matches_reference_and_keeps_inputs_in_place():
+ torch.manual_seed(0)
+ num_tokens = 17
+ num_q_heads = 4
+ num_kv_heads = 2
+ head_dim = 8
+ mrope_section = [2, 1, 1]
+
+ q = torch.randn(
+ (num_tokens, num_q_heads, head_dim),
+ device="cuda",
+ dtype=torch.float16,
+ )
+ k = torch.randn(
+ (num_tokens, num_kv_heads, head_dim),
+ device="cuda",
+ dtype=torch.float16,
+ )
+ positions = torch.randint(
+ 0,
+ 32,
+ (3, num_tokens),
+ device="cuda",
+ dtype=torch.long,
+ )
+ cos_sin_cache = torch.randn(
+ (32, head_dim),
+ device="cuda",
+ dtype=torch.float16,
+ )
+
+ expected_q, expected_k = apply_mrope(
+ q,
+ k,
+ positions,
+ cos_sin_cache,
+ mrope_section,
+ mrope_interleaved=True,
+ )
+
+ q_actual = q.clone()
+ k_actual = k.clone()
+ q_ptr = q_actual.data_ptr()
+ k_ptr = k_actual.data_ptr()
+
+ out_q, out_k = apply_mrope_fused_(
+ q_actual,
+ k_actual,
+ positions,
+ cos_sin_cache,
+ mrope_section,
+ mrope_interleaved=True,
+ )
+ torch.cuda.synchronize()
+
+ assert out_q.data_ptr() == q_ptr
+ assert out_k.data_ptr() == k_ptr
+ assert q_actual.data_ptr() == q_ptr
+ assert k_actual.data_ptr() == k_ptr
+ torch.testing.assert_close(q_actual, expected_q, atol=1e-2, rtol=1e-2)
+ torch.testing.assert_close(k_actual, expected_k, atol=1e-2, rtol=1e-2)
diff --git a/pymllm/tests/test_qwen3_vl_deepstack.py b/pymllm/tests/test_qwen3_vl_deepstack.py
index c36bacf39..694bd37c8 100644
--- a/pymllm/tests/test_qwen3_vl_deepstack.py
+++ b/pymllm/tests/test_qwen3_vl_deepstack.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from types import SimpleNamespace
+import time
import numpy as np
import pytest
@@ -12,6 +13,7 @@
Qwen3VLTextModel,
Qwen3VLVisionModel,
_compute_cu_seqlens_from_grid,
+ _run_with_synchronized_wall_timing,
)
@@ -63,6 +65,25 @@ def forward(self, pixel_values, grid_thw):
return torch.ones((1, 2), dtype=torch.float32)
+class _FakeTextModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.embed_tokens = nn.Embedding(8, 2)
+
+ def forward(
+ self,
+ input_ids,
+ positions,
+ forward_batch,
+ input_embeds=None,
+ input_deepstack_embeds=None,
+ ):
+ del positions, forward_batch, input_deepstack_embeds
+ if input_embeds is not None:
+ return input_embeds
+ return self.embed_tokens(input_ids)
+
+
def _make_vl_config() -> SimpleNamespace:
text_config = SimpleNamespace(
hidden_size=2,
@@ -229,6 +250,95 @@ def test_forward_rejects_mismatched_image_token_and_feature_counts():
)
+def test_forward_records_vit_prefill_tps_when_benchmark_timing_enabled():
+ model = Qwen3VLForConditionalGeneration(_make_vl_config())
+ model.visual = _FakeVisual()
+ model.model = _FakeTextModel()
+
+ forward_batch = SimpleNamespace(
+ forward_mode=_Mode(),
+ batch_size=1,
+ extend_start_loc=torch.tensor([0], dtype=torch.int64),
+ extend_seq_lens=torch.tensor([4], dtype=torch.int64),
+ pixel_values=torch.zeros((1, 3), dtype=torch.float32),
+ image_grid_thw=torch.tensor([[1, 1, 1]], dtype=torch.int64),
+ benchmark_vision_timing=True,
+ )
+
+ model(
+ input_ids=torch.tensor([1, 4, 5, 2], dtype=torch.int64),
+ positions=torch.arange(4, dtype=torch.int64),
+ forward_batch=forward_batch,
+ )
+
+ assert forward_batch.vit_prefill_ms is not None
+ assert forward_batch.vit_prefill_ms >= 0.0
+ assert forward_batch.vit_prefill_tokens == 1
+ assert forward_batch.vit_prefill_tps >= 0.0
+
+
+def test_vision_timing_includes_host_side_work_when_benchmark_enabled():
+ def host_heavy_fn():
+ time.sleep(0.02)
+ return torch.tensor([1.0])
+
+ result, elapsed_ms = _run_with_synchronized_wall_timing(
+ host_heavy_fn,
+ device=torch.device("cpu"),
+ enabled=True,
+ )
+
+ torch.testing.assert_close(result, torch.tensor([1.0]))
+ assert elapsed_ms >= 15.0
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
+def test_cuda_vision_timing_includes_host_side_work_when_benchmark_enabled():
+ device = torch.device("cuda")
+
+ def host_heavy_cuda_fn():
+ time.sleep(0.02)
+ return torch.ones((1,), device=device)
+
+ result, elapsed_ms = _run_with_synchronized_wall_timing(
+ host_heavy_cuda_fn,
+ device=device,
+ enabled=True,
+ )
+
+ assert result.device.type == "cuda"
+ assert elapsed_ms >= 15.0
+
+
+def test_cuda_vision_timing_uses_wall_clock_not_event_elapsed(monkeypatch):
+ class _FakeCudaEvent:
+ def __init__(self, *args, **kwargs):
+ del args, kwargs
+
+ def record(self):
+ pass
+
+ def elapsed_time(self, other):
+ del other
+ return 0.0
+
+ monkeypatch.setattr(torch.cuda, "Event", _FakeCudaEvent)
+ monkeypatch.setattr(torch.cuda, "synchronize", lambda *args, **kwargs: None)
+
+ def host_heavy_fn():
+ time.sleep(0.02)
+ return torch.tensor([1.0])
+
+ result, elapsed_ms = _run_with_synchronized_wall_timing(
+ host_heavy_fn,
+ device=torch.device("cuda"),
+ enabled=True,
+ )
+
+ torch.testing.assert_close(result, torch.tensor([1.0]))
+ assert elapsed_ms >= 15.0
+
+
def test_vision_interpolation_indices_match_sglang_hf():
model = Qwen3VLVisionModel(
depth=0,
diff --git a/pymllm/tests/test_rms_norm.py b/pymllm/tests/test_rms_norm.py
index 9663f5444..16756a8aa 100644
--- a/pymllm/tests/test_rms_norm.py
+++ b/pymllm/tests/test_rms_norm.py
@@ -25,3 +25,23 @@ def fail_fused_add_rmsnorm(*args, **kwargs):
_, residual_out = norm(x, residual)
torch.testing.assert_close(residual_out, x + residual)
+
+
+def test_flashinfer_norm_device_properties_patch_adds_optin(monkeypatch):
+ class FakeProps:
+ shared_memory_per_block = 49152
+ shared_memory_per_multiprocessor = 167936
+
+ monkeypatch.setattr(rms_norm_module, "_PATCHED_CUDA_DEVICE_PROPERTIES", False)
+ monkeypatch.setattr(rms_norm_module.torch.cuda, "is_available", lambda: True)
+ monkeypatch.setattr(
+ rms_norm_module.torch.cuda,
+ "get_device_properties",
+ lambda device=0: FakeProps(),
+ )
+
+ rms_norm_module._patch_cuda_device_properties_for_flashinfer_norm()
+
+ props = rms_norm_module.torch.cuda.get_device_properties(0)
+ assert props.shared_memory_per_block_optin == 167936
+ assert props.shared_memory_per_block == 49152