Skip to content

akaiHuang/cuda2mlx

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

8 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

cuda2mlx

把 CUDA / PyTorch 模型 port 到 Apple Silicon MLX —— 4 層自動化框架

cuda2mlx PT vs MLX benchmark

License: MIT Python 3.10+ MLX 0.18+ Tests Smoke Status

LLM · Vision · 3D · Sparse · Custom Metal Kernel · Signal — 一站式 CUDA / PyTorch → MLX port toolkit,覆蓋所有模型類型.


🍎 什麼是 MLX?為什麼要 port 過去?

MLXApple 2023 年底開源的 ML 陣列框架,專為 Apple Silicon(M1/M2/M3/M4)設計。可以想成「Apple 自家的 PyTorch + JAX 混血」。

特色 意義
Unified memory M-chip 上 CPU/GPU 共用記憶體,省掉 PyTorch tensor.to('cuda') 的 host↔device 拷貝
Lazy evaluation 像 JAX 那樣構圖再執行,可融合 ops
NumPy-like API mx.arraymx.fft.rfft2mx.random.normal,PyTorch 熟手幾乎 0 學習成本
Metal-backed 走 Apple 自家 GPU API,比 PyTorch MPS backend 通常更快、更穩
CUDA backend (2026-) MLX 程式碼可直接跑 NVIDIA Linux server,寫一次跨平台
mx.fast.metal_kernel 內聯 Metal C++,等同 CUDA Triton 那種 fused kernel 能力

Port CUDA/PyTorch → MLX 的實際好處

  1. M-series Mac 跑 inference 不用買 NVIDIA GPU(本機 dev / 小團隊省 5000-30000 USD)
  2. 省 host↔device 拷貝,small-batch latency 體感快很多
  3. 能用 ANE(Apple Neural Engine 走 Core ML 路徑,更省電)
  4. 同一份 MLX 程式碼也能在 Linux NVIDIA 跑(透過 mlx[cuda] backend)

但是 — 每個模型 port 一次很痛(NCHW↔NHWC、無 complex、無 spconv、custom kernel 沒對應…)。所以才有 cuda2mlx


為什麼需要這個

每個 PyTorch model 都要重新 port 到 MLX 一次。LLM / vision / 3D / 影像修復 / 音訊 / 自訂 kernel 都在重複造輪子。

HuggingFace transformers-to-mlx Skill 只解 transformers LLM;mflux 只解 FLUX;其他 80% 沒人做框架。

cuda2mlx一站式:四層轉換 Tier 從規則式 layout 改寫到 mx.fast.metal_kernel 翻譯,全 model 類型涵蓋 — LLM、vision foundation、3D 生成、Mamba SSM、Gaussian Splatting、特殊函數。

基於 25 個既有 MLX port 累積的 paired snippet 抽取而成。


✅ 實測結果

完整 PyTorch ViT → MLX port,數值差 8.49 × 10⁻⁷

PT/MLX TinyViT model:
  arch         : dim=48 depth=4 heads=4 patches=16
  params       : 123,802
  input shape  : PT=(2, 3, 32, 32) NCHW, MLX=(2, 32, 32, 3) NHWC
  max_abs_diff : 8.49e-07
  status       : PASS ✓

完整 LLaMA-style transformer block → MLX port,數值差 2.38 × 10⁻⁷

PT/MLX TinyLLaMA block:
  arch         : dim=64 n_heads=8 n_kv_heads=2 hidden=128 seq_len=16
  params       : 34,944
  components   : RMSNorm + 1D RoPE + GQA(8/2) + SwiGLU + causal mask
  max_abs_diff : 2.38e-07
  status       : PASS ✓

11 / 11 E2E parity test — 全部對 PyTorch 數值對齊

11/11 E2E parity test results

9 個 cookbook entry 對 PyTorch ground truth 數值對齊

Test Op max_abs_diff 狀態
1 SDPA 4.77e-07
2 reflect_pad 0.00 (完全一致)
3 rfft2 round-trip 9.54e-07
4 GroupNorm 4.77e-07
5 axial RoPE (norm-preserving) 0.00
6 grid_sample (bilinear) 1.19e-07
7 variable-length SDPA 4.77e-07
8 scatter_add 0.00
9 masked_set 0.00

📊 效能 — PyTorch (CPU) vs MLX (Metal) on M-series Mac

我們實測了 cuda2mlx port 後的 LLaMA-style block 在 Apple Silicon 上的 forward latency(這台機器,非 benchmark 引用):

Block 大小 params PT (CPU) MLX (Metal) 加速
dim=256, hidden=512, seq=64, B=4 558K 2.12 ms 2.07 ms 1.03×
dim=512, hidden=1024, seq=128, B=4 2.2M 5.11 ms 3.22 ms 1.59×
dim=1024, hidden=2752, seq=256, B=2 11M(≈ LLaMA-1B 單 block 級) 14.74 ms 6.84 ms 2.16×
dim=2048, hidden=5504, seq=512, B=1 44M(≈ LLaMA-7B 單 block 級) 45.85 ms 42.93 ms 1.07×

觀察:MLX 在 2M–11M params 範圍 1.5–2× 快;極小模型有 kernel launch overhead;超大模型 fp32 上 memory-bound。 用 python -m cuda2mlx.tests.bench_pt_vs_mlx 在自己 Mac 重跑。 Caveat:這是 PT-CPU vs MLX-Metal 對比;本機沒 NVIDIA 無法直接跟 CUDA 比。MLX 跟 NVIDIA H100 的對比可參考 Apple 官方 LLaMA benchmark

analyze CLI 在 5 個真實 CUDA repo 跑過

Repo 檔數 總命中 Tier 1/2/3/4 關鍵 markers
TRELLIS 117 126 37/29/30/30 flash_attn ×30, spconv ×27
ProPainter 98 199 183/14/0/2 deform_conv2d
LaMa 98 170 169/1/0/0 Conv-heavy
InstantMesh 44 22 17/5/0/0 AdaLN ×4
LGM 16 44 32/11/0/1 diff_gaussian

💬 推薦用法 — 丟給 Claude 一句話就 port 好

cuda2mlx 是設計給 LLM agent(Claude / Cursor / Claude Code) 當工具庫用的。最簡單的使用方式:

複製這個 repo URL,貼到 Claude,告訴它你想 port 哪個模型。

範例 prompt:

我想把這個模型 port 到 Apple Silicon MLX:
https://github.com/<作者>/<某個 CUDA/PyTorch 模型>

請參考 cuda2mlx 的轉換規則跟 cookbook:
https://github.com/akaiHuang/cuda2mlx

幫我:
1. 先跑 cuda2mlx.analyze 看這個 model 需要哪些 Tier 工作
2. 套 Tier 1 規則轉 state_dict + weight reshape
3. 用 cookbook 對應的 op 取代 PyTorch 寫法
4. 跑 cuda2mlx.tests.parity 確認數值對齊(atol < 1e-4)

Claude 會自動:

  • 抓 28+8 個 marker 出 coverage 報告(告訴你 ~80% 工作可自動化)
  • rules/ 處理 NCHW→NHWC、Conv weight transpose、Sequentiallayers.N、刪 dropout
  • cookbook/llmcookbook/ 套對應 pattern(RMSNorm / RoPE / SDPA / FFT / GroupNorm / scatter…)
  • Tier 3 / Tier 4 卡關時參考 hard/ 5 個 paired snippet 當 few-shot
  • tests/parity.py 給你數值差報告

不用看完所有文件、不用懂 Metal kernel,讓 agent 用框架幫你 port


🚀 30 秒上手(手動)

如果想自己跑,不靠 agent:

git clone https://github.com/akaiHuang/cuda2mlx
cd cuda2mlx
pip install mlx numpy
# 分析任何 PyTorch repo,3 秒看完它需要做哪些轉換工作
python -m cuda2mlx.analyze.analyze /path/to/pytorch/repo
# 把 PyTorch state_dict 轉成 MLX 格式(自動處理 Conv weight reshape + key rename)
from cuda2mlx.rules import StateDictRenamer, auto_reshape

renamer = StateDictRenamer.default()
mlx_state = {
    renamer.apply(k): auto_reshape(k, v.numpy())
    for k, v in torch_state_dict.items()
}
# 用 cookbook 的 SDPA / RoPE / GroupNorm 取代 PyTorch 對應 op
from cuda2mlx.cookbook.attention.sdpa import scaled_dot_product_attention
from cuda2mlx.cookbook.attention.axial_rope import compute_axial_cis, apply_rotary_enc
from cuda2mlx.cookbook.conv.reflect_pad import reflect_pad_nhwc
# 跑 numerical parity 確認 port 正確
from cuda2mlx.tests.parity import compare_modules

result = compare_modules(pt_model, mlx_model, input_shape=(1, 3, 224, 224),
                         nchw_to_nhwc=True)
print(result)  # ParityResult(PASS | max_abs=8.49e-07 ...)

🏗️ 四層轉換 Tier(80% 工作可自動化)

5 real CUDA repos Tier breakdown
實測:5 個真實 PyTorch repo 各自需要的 Tier 工作量分布(561 markers 總命中,跑 python -m cuda2mlx.analyze.analyze <repo> 可重現)

整個 port 流程一張圖

flowchart LR
    A[PyTorch / CUDA<br/>source repo] --> B[analyze<br/>36-marker scan]
    B --> C{Tier 分派}
    C -->|45%| T1[Tier 1<br/>Rules<br/>auto rewrite]
    C -->|30%| T2[Tier 2<br/>Cookbook<br/>pattern templates]
    C -->|15%| T3[Tier 3<br/>Metal kernels<br/>mx.fast.metal_kernel]
    C -->|10%| T4[Tier 4<br/>Hard manual<br/>LLM few-shot]
    T1 --> M[MLX model]
    T2 --> M
    T3 --> M
    T4 --> M
    M --> P[parity test<br/>vs PyTorch]
    P -->|max_abs_diff &lt; 1e-4| S[✅ shipped]
    P -->|fail| C

    style A fill:#FF6B6B,color:#fff,stroke-width:0px
    style M fill:#0A84FF,color:#fff,stroke-width:0px
    style S fill:#30D158,color:#fff,stroke-width:0px
    style P fill:#FF9F0A,color:#fff,stroke-width:0px
    style T1 fill:#30D158,color:#fff,stroke-width:0px
    style T2 fill:#0A84FF,color:#fff,stroke-width:0px
    style T3 fill:#BF5AF2,color:#fff,stroke-width:0px
    style T4 fill:#FF9F0A,color:#fff,stroke-width:0px
Loading
Tier 內容 自動化 程式碼比例
1 — Rules 純規則:layout NCHW↔NHWC、Conv weight reshape、state_dict regex rename、dropout 移除、Sequential→layers ≥90% ~45%
2 — Cookbook Pattern templates:LLM(RMSNorm / 1D RoPE / GQA / SwiGLU / KV-cache / causal mask / sampling)+ Vision(SDPA / axial RoPE / reflect/grid_sample / FFT NCHW↔NHWC / GroupNorm / AdaLN / scatter) 60–80% ~30%
3 — Metal kernels 業界空白區 — CUDA C++ → mx.fast.metal_kernel:fused element-wise、Mamba selective scan、Flow Matching sampler 40–60% ~15%
4 — Hard manual Paired (PT, MLX) snippet library 給 LLM few-shot:DCNv2、FFC、Morton sparse attention、submanifold 3D conv、scipy.special 0–20% ~10%

📦 內容

cuda2mlx/
├── analyze/          # 36-marker grep → Tier 分類報告
├── rules/            # 5 規則模組(layout / weight / state_dict / dropout / sequential)
├── cookbook/         # 18 pattern entry
│   ├── llm/          # RMSNorm / 1D RoPE / SwiGLU / GQA / KV-cache / causal mask / sampling
│   ├── attention/    # SDPA / axial RoPE / variable-length
│   ├── conv/         # reflect_pad / grid_sample / dilation
│   ├── fft/          # rfft2 NCHW↔NHWC
│   ├── norm/         # GroupNorm / AdaLN
│   └── indexing/     # scatter / bool_mask
├── metal_kernels/    # CUDA→Metal 翻譯範例 + cheatsheet
├── hard/             # 5 paired snippet(LLM few-shot pool)
├── tests/            # parity harness + 11 個 E2E test(含 TinyViT + TinyLLaMA)
└── docs/

6382 行 Python · 52 檔案 · 6 個 Markdown · 0 runtime 依賴(除 mlx + numpy)

Cookbook 全索引(18 entry)

類別 入口 對應 PyTorch / CUDA
🔤 LLM cookbook/llm/rmsnorm.py LlamaRMSNorm, T5LayerNorm
cookbook/llm/rope_1d.py LLaMA LlamaRotaryEmbedding (1D, torch.polar)
cookbook/llm/swiglu.py LLaMA LlamaMLP (gate/up/down_proj)
cookbook/llm/gqa.py HF repeat_kv + grouped attention
cookbook/llm/kv_cache.py HF DynamicCache / StaticCache
cookbook/llm/causal_mask.py torch.triu(... -inf), is_causal=True
cookbook/llm/sampling.py LogitsProcessor, top_p/top_k/temperature
🎯 Attention cookbook/attention/sdpa.py F.scaled_dot_product_attention
cookbook/attention/axial_rope.py SAM 3 2D axial RoPE (torch.polar, view_as_complex)
cookbook/attention/variable_length.py nested tensor / per-sample loop
🖼️ Conv cookbook/conv/reflect_pad.py F.pad(mode='reflect')
cookbook/conv/grid_sample.py F.grid_sample(align_corners=True)
cookbook/conv/dilation.py Conv2d(dilation=d) effective padding
🌊 FFT cookbook/fft/rfft2_nchw.py torch.fft.rfft2 NCHW↔NHWC
📏 Norm cookbook/norm/groupnorm.py nn.GroupNorm (PT-compatible 模式)
cookbook/norm/adaln.py DiT AdaLN modulation
🔢 Indexing cookbook/indexing/scatter.py torch.scatter_add
cookbook/indexing/bool_mask.py tensor[mask] = value

🆚 跟其他工具比

工具 範圍 自動化深度
cuda2mlx(本框架) LLM · Vision · 3D · Sparse · Metal kernel · Signal(全) Tier 1 + 2 + 3 + 4
HF transformers-to-mlx Skill 只 transformers LLM LLM-assisted(單一範圍)
mflux 只 FLUX 圖像生成 單模型族
torch2mlx (SynapticSage) 36 個標準 NLP 架構 只 Tier 1
Xforge PT→MLX/CoreML GUI 只標準層
mlx[cuda] backend 反方向(MLX→CUDA 跑) n/a

唯一覆蓋所有模型類型的 port 框架 — LLM、vision、3D、signal、custom kernel 同一套工具一次解決。


🧪 跑測試

# 全模組 self-test(37/37 過)
for f in $(find cuda2mlx -name "*.py" -not -name "__init__.py"); do python3 $f; done

# E2E cookbook parity(9 個對 PyTorch 比對)
python -m cuda2mlx.tests.test_e2e_cookbook_parity

# E2E TinyViT 完整 vision model port parity
python -m cuda2mlx.tests.test_e2e_tiny_vit

# E2E TinyLLaMA transformer block port parity(RMSNorm+RoPE+GQA+SwiGLU)
python -m cuda2mlx.tests.test_e2e_tiny_llama

🗺️ Roadmap

  • v0.1 — rules + cookbook (vision) + metal kernel + hard examples + analyze CLI + parity harness
  • v0.2(本版)cookbook/llm/ 全 7 個 LLM ops + TinyLLaMA E2E parity
  • 🔜 v0.3 — LLM-assisted port runner(吃 PyTorch 檔,自動產 MLX 草稿 + iterative parity loop)
  • 🔜 v0.4 — VLM 支援、PyPI 發布、Apple Silicon CI
  • 📅 v0.5 — 訓練側組件(autocast、gradient checkpointing、fused optimizer)

💡 25 個既有 MLX port 資產

完整清單見 INVENTORY.md。下面是已抽進 framework 的代表:

  • 3D 生成 — Trellis(含 trellis2-apple/metal_kernels.py 業界少見 mx.fast.metal_kernel 實戰)、Hunyuan3D
  • Vision foundation — DINOv3、SAM 3(2D axial RoPE 無 complex)
  • 影像修復 — LaMa(FFC FFT conv)、ProPainter(DCNv2、RAFT optical flow)
  • SSM — Mamba selective scan MLX 實作(學界稀缺)
  • Signalmlx-stftmlx-special 全家桶(Bessel、Gamma、Hyp2F1、Airy、Wigner)
  • GS — Gaussian Splatting Q4 (sharp-mlx-q4)

🤝 貢獻

歡迎開 Issue / PR。重點需要的方向:

  • 補更多 Tier 2 cookbook entry(每加一個 pattern,受益的 model 數量是非線性的)
  • Tier 3 CUDA→Metal kernel 翻譯(看 metal_kernels/templates/cuda_to_metal_cheatsheet.md
  • 真實 model port 範例(在 examples/ 加 end-to-end demo)

📜 授權

MIT — 見 LICENSE.

來源

@akaiHuang 啟動。框架的 cookbook/hard/ 不是憑空設計 — 全部從 25 個已驗證可運作的 MLX port(trellis-mlx / sam3-mlx / lama-mlx / propainter-mlx / meadow-mamba / sharp-mlx-q4 / mlx-special 全家桶…)抽出來的 (PyTorch, MLX) 配對範例。踩過的雷已經填好在 cookbook 裡,你不用重踩。

About

CUDA/PyTorch → MLX port toolkit. LLM · Vision · 3D · Sparse · Custom Metal Kernel · Signal. 4-Tier conversion framework with 11/11 E2E parity tests.

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages