[module_fused_split_gdr_update] refactor #3777

Open
amd-ruitang3 wants to merge 1 commit into ROCm:main from amd-ruitang3:module_fused_split_gdr_update_refactor

Conversation

@amd-ruitang3 amd-ruitang3 commented Jun 17, 2026

Contributor

Motivation

Remove the torch dependency from the module. Build time drops from 35.9s to 10.6s (-70.5%).

Technical Details

Test Plan

Test Result

Submission Checklist

@amd-ruitang3 amd-ruitang3 requested review from a team and Copilot June 17, 2026 08:57
@github-actions

Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3777 --add-label <label>

Copilot AI left a comment

Contributor


Pull request overview

This PR refactors the module_fused_split_gdr_update op to remove Torch/C10 dependencies from the C++/CUDA(HIP) compilation units by switching the kernel interface to aiter_tensor_t and using AITER’s thread-local HIP stream plumbing.

Changes:

  • Replaced torch::Tensor/TORCH_CHECK interfaces with aiter_tensor_t/AITER_CHECK and added HipDeviceGuard + aiter::getCurrentHIPStream() usage in the kernel wrapper.
  • Updated the pybind module to expose _set_current_hip_stream (develop-mode stream handoff) and updated Python to pre-allocate outputs / default tensors before calling into the torch-free binding.
  • Adjusted the pybind arg list for fused_split_gdr_update to accept an explicit output tensor (no allocation on the C++ side).
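
The Python-side allocation and defaulting described above can be sketched roughly as follows. This is a torch-free illustration, not the PR's actual wrapper: `BufferSpec` and `plan_buffers` are hypothetical names, and the output layout `(B, T, HV, V)` is an assumption that the review does not confirm. The bfloat16 output dtype, the int32 zero-filled indices of length `B`, and the empty float32 state source follow the snippets quoted in the review comments.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(frozen=True)
class BufferSpec:
    """Shape and dtype of a buffer the Python wrapper would allocate (hypothetical helper)."""
    shape: Tuple[int, ...]
    dtype: str


def plan_buffers(
    mixed_qkv_shape: Tuple[int, ...],
    num_heads_v: int,
    head_dim: int,
    num_state_indices: Optional[int] = None,
    has_state_source: bool = False,
) -> Dict[str, BufferSpec]:
    """List the buffers Python must create before calling a binding that allocates nothing."""
    B = mixed_qkv_shape[0]
    T = mixed_qkv_shape[2]
    HV = num_heads_v
    V = head_dim

    plan = {
        # ASSUMPTION: the output is laid out as (B, T, HV, V); the source only
        # shows that the output pointer is cast to bfloat16.
        "output": BufferSpec((B, T, HV, V), "bfloat16"),
    }

    # Missing or empty indices fall back to a zero-filled int32 vector of length B.
    if num_state_indices is None or num_state_indices == 0:
        plan["initial_state_indices"] = BufferSpec((B,), "int32")

    # A missing state source becomes an empty float32 placeholder so the
    # binding always receives a tensor argument.
    if not has_state_source:
        plan["initial_state_source"] = BufferSpec((0,), "float32")

    return plan
```

In a real wrapper each planned entry would become a `torch.empty` or `torch.zeros` call on `mixed_qkv.device`, and the resulting tensors would be passed to the binding together with the explicit `output` argument.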

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Show a summary per file
| File | Description |
| --- | --- |
| csrc/pybind/fused_split_gdr_update_pybind.cu | Adds AITER stream setter export to support develop-mode stream handoff. |
| csrc/kernels/fused_split_gdr_update.cu | Converts op wrapper to aiter_tensor_t and uses AITER stream + device guard; removes torch includes. |
| csrc/include/rocm_ops.hpp | Updates pybind binding signature/args for the refactored op. |
| csrc/include/fused_split_gdr_update.h | Updates public header signature to aiter_tensor_t and void return. |
| aiter/ops/fused_split_gdr_update.py | Adds Python-side allocation/defaulting and forwards to the new torch-free binding. |

Comment on lines +72 to +75
B = mixed_qkv.shape[0]
T = mixed_qkv.shape[2]
HV = num_heads_v
V = head_dim
Comment on lines +82 to +90
if initial_state_indices is None or initial_state_indices.numel() == 0:
    initial_state_indices = torch.zeros(
        (B,), dtype=torch.int32, device=mixed_qkv.device
    )

if initial_state_source is None:
    initial_state_source = torch.empty(
        0, dtype=torch.float32, device=mixed_qkv.device
    )
Comment thread csrc/include/rocm_ops.hpp
Comment on lines +2236 to +2240
py::arg("softplus_beta"), \
py::arg("softplus_threshold"), \
py::arg("scale"), \
py::arg("use_qk_l2norm_in_kernel"), \
py::arg("output"));
0, \
stream, \
reinterpret_cast<const __hip_bfloat16*>(mixed_qkv.data_ptr()), \
reinterpret_cast<float*>(A_log.data_ptr()), \
reinterpret_cast<const __hip_bfloat16*>(b_gate.data_ptr()), \
reinterpret_cast<__hip_bfloat16*>(o.data_ptr()), \
use_initial_state ? reinterpret_cast<float*>(initial_state_source.data_ptr()) : nullptr, \
reinterpret_cast<int32_t*>(initial_state_indices_ptr.data_ptr()), \

2 participants