LLM · Vision · 3D · Sparse · Custom Metal Kernel · Signal — 一站式 CUDA / PyTorch → MLX port toolkit,覆蓋所有模型類型.
MLX 是 Apple 2023 年底開源的 ML 陣列框架,專為 Apple Silicon(M1/M2/M3/M4)設計。可以想成「Apple 自家的 PyTorch + JAX 混血」。
| 特色 | 意義 |
|---|---|
| Unified memory | M-chip 上 CPU/GPU 共用記憶體,省掉 PyTorch tensor.to('cuda') 的 host↔device 拷貝 |
| Lazy evaluation | 像 JAX 那樣構圖再執行,可融合 ops |
| NumPy-like API | mx.array、mx.fft.rfft2、mx.random.normal,PyTorch 熟手幾乎 0 學習成本 |
| Metal-backed | 走 Apple 自家 GPU API,比 PyTorch MPS backend 通常更快、更穩 |
| CUDA backend (2026-) | MLX 程式碼可直接跑 NVIDIA Linux server,寫一次跨平台 |
mx.fast.metal_kernel |
內聯 Metal C++,等同 CUDA Triton 那種 fused kernel 能力 |
Port CUDA/PyTorch → MLX 的實際好處:
- M-series Mac 跑 inference 不用買 NVIDIA GPU(本機 dev / 小團隊省 5000-30000 USD)
- 省 host↔device 拷貝,small-batch latency 體感快很多
- 能用 ANE(Apple Neural Engine 走 Core ML 路徑,更省電)
- 同一份 MLX 程式碼也能在 Linux NVIDIA 跑(透過
mlx[cuda]backend)
但是 — 每個模型 port 一次很痛(NCHW↔NHWC、無 complex、無 spconv、custom kernel 沒對應…)。所以才有 cuda2mlx。
每個 PyTorch model 都要重新 port 到 MLX 一次。LLM / vision / 3D / 影像修復 / 音訊 / 自訂 kernel 都在重複造輪子。
HuggingFace
transformers-to-mlxSkill 只解 transformers LLM;mflux只解 FLUX;其他 80% 沒人做框架。
cuda2mlx 是一站式:四層轉換 Tier 從規則式 layout 改寫到 mx.fast.metal_kernel 翻譯,全 model 類型涵蓋 — LLM、vision foundation、3D 生成、Mamba SSM、Gaussian Splatting、特殊函數。
基於 25 個既有 MLX port 累積的 paired snippet 抽取而成。
PT/MLX TinyViT model:
arch : dim=48 depth=4 heads=4 patches=16
params : 123,802
input shape : PT=(2, 3, 32, 32) NCHW, MLX=(2, 32, 32, 3) NHWC
max_abs_diff : 8.49e-07
status : PASS ✓
PT/MLX TinyLLaMA block:
arch : dim=64 n_heads=8 n_kv_heads=2 hidden=128 seq_len=16
params : 34,944
components : RMSNorm + 1D RoPE + GQA(8/2) + SwiGLU + causal mask
max_abs_diff : 2.38e-07
status : PASS ✓
| Test | Op | max_abs_diff | 狀態 |
|---|---|---|---|
| 1 | SDPA | 4.77e-07 | ✅ |
| 2 | reflect_pad | 0.00 (完全一致) | ✅ |
| 3 | rfft2 round-trip | 9.54e-07 | ✅ |
| 4 | GroupNorm | 4.77e-07 | ✅ |
| 5 | axial RoPE (norm-preserving) | 0.00 | ✅ |
| 6 | grid_sample (bilinear) | 1.19e-07 | ✅ |
| 7 | variable-length SDPA | 4.77e-07 | ✅ |
| 8 | scatter_add | 0.00 | ✅ |
| 9 | masked_set | 0.00 | ✅ |
我們實測了 cuda2mlx port 後的 LLaMA-style block 在 Apple Silicon 上的 forward latency(這台機器,非 benchmark 引用):
| Block 大小 | params | PT (CPU) | MLX (Metal) | 加速 |
|---|---|---|---|---|
| dim=256, hidden=512, seq=64, B=4 | 558K | 2.12 ms | 2.07 ms | 1.03× |
| dim=512, hidden=1024, seq=128, B=4 | 2.2M | 5.11 ms | 3.22 ms | 1.59× |
| dim=1024, hidden=2752, seq=256, B=2 | 11M(≈ LLaMA-1B 單 block 級) | 14.74 ms | 6.84 ms | 2.16× |
| dim=2048, hidden=5504, seq=512, B=1 | 44M(≈ LLaMA-7B 單 block 級) | 45.85 ms | 42.93 ms | 1.07× |
觀察:MLX 在 2M–11M params 範圍 1.5–2× 快;極小模型有 kernel launch overhead;超大模型 fp32 上 memory-bound。 用
python -m cuda2mlx.tests.bench_pt_vs_mlx在自己 Mac 重跑。 Caveat:這是 PT-CPU vs MLX-Metal 對比;本機沒 NVIDIA 無法直接跟 CUDA 比。MLX 跟 NVIDIA H100 的對比可參考 Apple 官方 LLaMA benchmark。
| Repo | 檔數 | 總命中 | Tier 1/2/3/4 | 關鍵 markers |
|---|---|---|---|---|
| TRELLIS | 117 | 126 | 37/29/30/30 | flash_attn ×30, spconv ×27 |
| ProPainter | 98 | 199 | 183/14/0/2 | deform_conv2d ✓ |
| LaMa | 98 | 170 | 169/1/0/0 | Conv-heavy |
| InstantMesh | 44 | 22 | 17/5/0/0 | AdaLN ×4 |
| LGM | 16 | 44 | 32/11/0/1 | diff_gaussian ✓ |
cuda2mlx 是設計給 LLM agent(Claude / Cursor / Claude Code) 當工具庫用的。最簡單的使用方式:
複製這個 repo URL,貼到 Claude,告訴它你想 port 哪個模型。
範例 prompt:
我想把這個模型 port 到 Apple Silicon MLX:
https://github.com/<作者>/<某個 CUDA/PyTorch 模型>
請參考 cuda2mlx 的轉換規則跟 cookbook:
https://github.com/akaiHuang/cuda2mlx
幫我:
1. 先跑 cuda2mlx.analyze 看這個 model 需要哪些 Tier 工作
2. 套 Tier 1 規則轉 state_dict + weight reshape
3. 用 cookbook 對應的 op 取代 PyTorch 寫法
4. 跑 cuda2mlx.tests.parity 確認數值對齊(atol < 1e-4)
Claude 會自動:
- 抓 28+8 個 marker 出 coverage 報告(告訴你 ~80% 工作可自動化)
- 套
rules/處理 NCHW→NHWC、Conv weight transpose、Sequential→layers.N、刪 dropout - 從
cookbook/llm跟cookbook/套對應 pattern(RMSNorm / RoPE / SDPA / FFT / GroupNorm / scatter…) - Tier 3 / Tier 4 卡關時參考
hard/5 個 paired snippet 當 few-shot - 跑
tests/parity.py給你數值差報告
不用看完所有文件、不用懂 Metal kernel,讓 agent 用框架幫你 port。
如果想自己跑,不靠 agent:
git clone https://github.com/akaiHuang/cuda2mlx
cd cuda2mlx
pip install mlx numpy# 分析任何 PyTorch repo,3 秒看完它需要做哪些轉換工作
python -m cuda2mlx.analyze.analyze /path/to/pytorch/repo# 把 PyTorch state_dict 轉成 MLX 格式(自動處理 Conv weight reshape + key rename)
from cuda2mlx.rules import StateDictRenamer, auto_reshape
renamer = StateDictRenamer.default()
mlx_state = {
renamer.apply(k): auto_reshape(k, v.numpy())
for k, v in torch_state_dict.items()
}# 用 cookbook 的 SDPA / RoPE / GroupNorm 取代 PyTorch 對應 op
from cuda2mlx.cookbook.attention.sdpa import scaled_dot_product_attention
from cuda2mlx.cookbook.attention.axial_rope import compute_axial_cis, apply_rotary_enc
from cuda2mlx.cookbook.conv.reflect_pad import reflect_pad_nhwc# 跑 numerical parity 確認 port 正確
from cuda2mlx.tests.parity import compare_modules
result = compare_modules(pt_model, mlx_model, input_shape=(1, 3, 224, 224),
nchw_to_nhwc=True)
print(result) # ParityResult(PASS | max_abs=8.49e-07 ...)
實測:5 個真實 PyTorch repo 各自需要的 Tier 工作量分布(561 markers 總命中,跑
python -m cuda2mlx.analyze.analyze <repo> 可重現)
flowchart LR
A[PyTorch / CUDA<br/>source repo] --> B[analyze<br/>36-marker scan]
B --> C{Tier 分派}
C -->|45%| T1[Tier 1<br/>Rules<br/>auto rewrite]
C -->|30%| T2[Tier 2<br/>Cookbook<br/>pattern templates]
C -->|15%| T3[Tier 3<br/>Metal kernels<br/>mx.fast.metal_kernel]
C -->|10%| T4[Tier 4<br/>Hard manual<br/>LLM few-shot]
T1 --> M[MLX model]
T2 --> M
T3 --> M
T4 --> M
M --> P[parity test<br/>vs PyTorch]
P -->|max_abs_diff < 1e-4| S[✅ shipped]
P -->|fail| C
style A fill:#FF6B6B,color:#fff,stroke-width:0px
style M fill:#0A84FF,color:#fff,stroke-width:0px
style S fill:#30D158,color:#fff,stroke-width:0px
style P fill:#FF9F0A,color:#fff,stroke-width:0px
style T1 fill:#30D158,color:#fff,stroke-width:0px
style T2 fill:#0A84FF,color:#fff,stroke-width:0px
style T3 fill:#BF5AF2,color:#fff,stroke-width:0px
style T4 fill:#FF9F0A,color:#fff,stroke-width:0px
| Tier | 內容 | 自動化 | 程式碼比例 |
|---|---|---|---|
| 1 — Rules | 純規則:layout NCHW↔NHWC、Conv weight reshape、state_dict regex rename、dropout 移除、Sequential→layers | ≥90% | ~45% |
| 2 — Cookbook | Pattern templates:LLM(RMSNorm / 1D RoPE / GQA / SwiGLU / KV-cache / causal mask / sampling)+ Vision(SDPA / axial RoPE / reflect/grid_sample / FFT NCHW↔NHWC / GroupNorm / AdaLN / scatter) | 60–80% | ~30% |
| 3 — Metal kernels | 業界空白區 — CUDA C++ → mx.fast.metal_kernel:fused element-wise、Mamba selective scan、Flow Matching sampler |
40–60% | ~15% |
| 4 — Hard manual | Paired (PT, MLX) snippet library 給 LLM few-shot:DCNv2、FFC、Morton sparse attention、submanifold 3D conv、scipy.special | 0–20% | ~10% |
cuda2mlx/
├── analyze/ # 36-marker grep → Tier 分類報告
├── rules/ # 5 規則模組(layout / weight / state_dict / dropout / sequential)
├── cookbook/ # 18 pattern entry
│ ├── llm/ # RMSNorm / 1D RoPE / SwiGLU / GQA / KV-cache / causal mask / sampling
│ ├── attention/ # SDPA / axial RoPE / variable-length
│ ├── conv/ # reflect_pad / grid_sample / dilation
│ ├── fft/ # rfft2 NCHW↔NHWC
│ ├── norm/ # GroupNorm / AdaLN
│ └── indexing/ # scatter / bool_mask
├── metal_kernels/ # CUDA→Metal 翻譯範例 + cheatsheet
├── hard/ # 5 paired snippet(LLM few-shot pool)
├── tests/ # parity harness + 11 個 E2E test(含 TinyViT + TinyLLaMA)
└── docs/
6382 行 Python · 52 檔案 · 6 個 Markdown · 0 runtime 依賴(除 mlx + numpy)
| 類別 | 入口 | 對應 PyTorch / CUDA |
|---|---|---|
| 🔤 LLM | cookbook/llm/rmsnorm.py |
LlamaRMSNorm, T5LayerNorm |
cookbook/llm/rope_1d.py |
LLaMA LlamaRotaryEmbedding (1D, torch.polar) |
|
cookbook/llm/swiglu.py |
LLaMA LlamaMLP (gate/up/down_proj) |
|
cookbook/llm/gqa.py |
HF repeat_kv + grouped attention |
|
cookbook/llm/kv_cache.py |
HF DynamicCache / StaticCache |
|
cookbook/llm/causal_mask.py |
torch.triu(... -inf), is_causal=True |
|
cookbook/llm/sampling.py |
LogitsProcessor, top_p/top_k/temperature |
|
| 🎯 Attention | cookbook/attention/sdpa.py |
F.scaled_dot_product_attention |
cookbook/attention/axial_rope.py |
SAM 3 2D axial RoPE (torch.polar, view_as_complex) |
|
cookbook/attention/variable_length.py |
nested tensor / per-sample loop | |
| 🖼️ Conv | cookbook/conv/reflect_pad.py |
F.pad(mode='reflect') |
cookbook/conv/grid_sample.py |
F.grid_sample(align_corners=True) |
|
cookbook/conv/dilation.py |
Conv2d(dilation=d) effective padding |
|
| 🌊 FFT | cookbook/fft/rfft2_nchw.py |
torch.fft.rfft2 NCHW↔NHWC |
| 📏 Norm | cookbook/norm/groupnorm.py |
nn.GroupNorm (PT-compatible 模式) |
cookbook/norm/adaln.py |
DiT AdaLN modulation | |
| 🔢 Indexing | cookbook/indexing/scatter.py |
torch.scatter_add |
cookbook/indexing/bool_mask.py |
tensor[mask] = value |
| 工具 | 範圍 | 自動化深度 |
|---|---|---|
cuda2mlx(本框架) |
LLM · Vision · 3D · Sparse · Metal kernel · Signal(全) | Tier 1 + 2 + 3 + 4 |
HF transformers-to-mlx Skill |
只 transformers LLM | LLM-assisted(單一範圍) |
mflux |
只 FLUX 圖像生成 | 單模型族 |
torch2mlx (SynapticSage) |
36 個標準 NLP 架構 | 只 Tier 1 |
Xforge |
PT→MLX/CoreML GUI | 只標準層 |
mlx[cuda] backend |
反方向(MLX→CUDA 跑) | n/a |
唯一覆蓋所有模型類型的 port 框架 — LLM、vision、3D、signal、custom kernel 同一套工具一次解決。
# 全模組 self-test(37/37 過)
for f in $(find cuda2mlx -name "*.py" -not -name "__init__.py"); do python3 $f; done
# E2E cookbook parity(9 個對 PyTorch 比對)
python -m cuda2mlx.tests.test_e2e_cookbook_parity
# E2E TinyViT 完整 vision model port parity
python -m cuda2mlx.tests.test_e2e_tiny_vit
# E2E TinyLLaMA transformer block port parity(RMSNorm+RoPE+GQA+SwiGLU)
python -m cuda2mlx.tests.test_e2e_tiny_llama- ✅ v0.1 — rules + cookbook (vision) + metal kernel + hard examples + analyze CLI + parity harness
- ✅ v0.2(本版) —
cookbook/llm/全 7 個 LLM ops + TinyLLaMA E2E parity - 🔜 v0.3 — LLM-assisted port runner(吃 PyTorch 檔,自動產 MLX 草稿 + iterative parity loop)
- 🔜 v0.4 — VLM 支援、PyPI 發布、Apple Silicon CI
- 📅 v0.5 — 訓練側組件(autocast、gradient checkpointing、fused optimizer)
完整清單見 INVENTORY.md。下面是已抽進 framework 的代表:
- 3D 生成 — Trellis(含
trellis2-apple/metal_kernels.py業界少見mx.fast.metal_kernel實戰)、Hunyuan3D - Vision foundation — DINOv3、SAM 3(2D axial RoPE 無 complex)
- 影像修復 — LaMa(FFC FFT conv)、ProPainter(DCNv2、RAFT optical flow)
- SSM — Mamba selective scan MLX 實作(學界稀缺)
- Signal —
mlx-stft、mlx-special全家桶(Bessel、Gamma、Hyp2F1、Airy、Wigner) - GS — Gaussian Splatting Q4 (
sharp-mlx-q4)
歡迎開 Issue / PR。重點需要的方向:
- 補更多 Tier 2 cookbook entry(每加一個 pattern,受益的 model 數量是非線性的)
- Tier 3 CUDA→Metal kernel 翻譯(看
metal_kernels/templates/cuda_to_metal_cheatsheet.md) - 真實 model port 範例(在
examples/加 end-to-end demo)
MIT — 見 LICENSE.
由 @akaiHuang 啟動。框架的 cookbook/ 跟 hard/ 不是憑空設計 — 全部從 25 個已驗證可運作的 MLX port(trellis-mlx / sam3-mlx / lama-mlx / propainter-mlx / meadow-mamba / sharp-mlx-q4 / mlx-special 全家桶…)抽出來的 (PyTorch, MLX) 配對範例。踩過的雷已經填好在 cookbook 裡,你不用重踩。
