HipKittens MXFP8 GEMM Support #566

Open
alextmagro wants to merge 33 commits into dev from hipkittens_mxfp8
@alextmagro alextmagro commented Apr 28, 2026

Contributor

Creates an MXFP8 GEMM with HipKittens that outperforms hipBLASLt and offers additional epilogues such as BIAS and GELU AUX.

Requires a workspace sized relative to the model, often larger than hipBLASLt's, but with significant performance improvements. Only builds for gfx950, and requires M and N to be multiples of 256.

Adds hipKittens header library as a submodule.
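
The shape requirement above can be sketched as a small eligibility check. This is only an illustration: `hipkittens_shape_ok` and `TILE` are hypothetical names that do not come from the PR, and the sketch assumes "M / 256 and N / 256" means both dimensions must be exact multiples of 256.

```python
TILE = 256  # assumed granularity, taken from the "M / 256 and N / 256" requirement


def hipkittens_shape_ok(m: int, n: int, tile: int = TILE) -> bool:
    """Return True if M and N are positive multiples of `tile`.

    Hypothetical helper for illustration; the PR's real dispatch logic
    may check further conditions (architecture, dtypes, K).
    """
    return m > 0 and n > 0 and m % tile == 0 and n % tile == 0


if __name__ == "__main__":
    print(hipkittens_shape_ok(256, 256))   # smallest shape that passes the check
    print(hipkittens_shape_ok(4096, 300))  # N is not a multiple of 256
```

A check like this would typically sit in front of the HipKittens path so that unsupported shapes fall back to hipBLASLt.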

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/jax/utils.py
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.hip Outdated
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py Outdated
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
@alextmagro alextmagro requested a review from wangye805 May 5, 2026 20:26
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
  [](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
-   return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
+   return MKN(std::get<0>(info.param)) + "x" +
+          std::to_string(std::get<1>(info.param)) + "x" +

Collaborator
What is the point? They are only ever set to false.

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.h
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/CMakeLists.txt Outdated

-    return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
+    key = (device, ub, grouped_gemm)
+    ws = _workspace_cache.get(key)

Collaborator
Why don't we rely on torch memory caching?

Contributor Author
I have made this change. I will need to do an E2E run to make sure that performance isn't affected, but it should be OK given my understanding of torch.empty().
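
The trade-off discussed in this thread, an explicit module-level workspace cache versus allocating on every call and letting the framework's caching allocator reuse memory, can be sketched without PyTorch. In this sketch `alloc` stands in for `torch.empty(...)`, the key layout follows the snippet above, and `get_workspace_cached` / `get_workspace_direct` are hypothetical names. The "grow when too small" rule is an assumption, not something shown in the PR.

```python
from typing import Callable, Dict, Hashable

# Module-level cache, keyed as in the snippet above: (device, ub, grouped_gemm).
_workspace_cache: Dict[Hashable, bytearray] = {}


def get_workspace_cached(key: Hashable, nbytes: int,
                         alloc: Callable[[int], bytearray] = bytearray) -> bytearray:
    """Explicit cache: allocate once per key, then return the same buffer.

    Reallocates only if the cached buffer is smaller than requested
    (an assumed policy for this sketch).
    """
    ws = _workspace_cache.get(key)
    if ws is None or len(ws) < nbytes:
        ws = alloc(nbytes)
        _workspace_cache[key] = ws
    return ws


def get_workspace_direct(nbytes: int,
                         alloc: Callable[[int], bytearray] = bytearray) -> bytearray:
    """No explicit cache: allocate per call and rely on the allocator
    (in PyTorch, the caching allocator behind torch.empty) to reuse memory."""
    return alloc(nbytes)
```

The explicit cache pins the buffer for the lifetime of the process, while the direct version keeps the code simpler and leaves reuse to the allocator, which is the direction the thread settles on.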

@alextmagro alextmagro requested review from aris134 and ipanfilo May 12, 2026 13:24
@alextmagro alextmagro requested a review from ipanfilo May 14, 2026 17:18
@alextmagro alextmagro added ci-level 3 CI test level 3 and removed ci-level 1 CI test level 1 labels May 14, 2026
Comment thread transformer_engine/jax/cpp_extensions/gemm.py
Comment thread transformer_engine/common/gemm/rocm_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/rocm_gemm.cu
Comment thread transformer_engine/common/gemm/rocm_gemm.cu
if (use_hipkittens) {
  auto param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

  hipStream_t s = use_service_stream ? ss_ctl.stream : stream;

Collaborator
Same as with is_mxfp8: there is no point in defining it for only one branch.

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
@@ -743,12 +786,15 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16)

INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,

Collaborator
If you end up with a separate prefix for MXFP8, it has to be used for this suite as well, for consistency.

Contributor Author
Done.

@@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {

Collaborator
test_case_sizes_mxfp8 is only used for DqGEMMTest. Is it intentional to add sizes there?

Contributor Author
Yes, I wanted to add the minimum possible size that hipKittens supports, which is 256x256x256

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.h
Comment thread transformer_engine/jax/cpp_extensions/gemm.py

is_mxfp8 = isinstance(A, MXFP8TensorStorage) or isinstance(B, MXFP8TensorStorage)
if is_mxfp8 and _use_hipkittens():
    a_size = A.size() if hasattr(A, "size") and callable(A.size) else A.shape

Collaborator
MXFP8TensorStorage has a callable size(). What other object could appear here that would require this condition?

Contributor Author
I was considering a scenario where A or B was not MXFP8, but we always have them both as MXFP8 so I think it is ok to simplify the logic

@alextmagro alextmagro requested a review from ipanfilo May 18, 2026 20:43
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
@alextmagro alextmagro requested a review from ipanfilo May 29, 2026 19:34

static int fp8_code(int dt) {
  switch (dt) {
    case KITTENS_FP8E4M3: return 0;

Collaborator
Are these codes, and the ones below, just arbitrary indexes or special values?

Contributor Author
fp8_code are the values used within v_mfma_scale_f32_16x16x128_f8f6f4 to designate whether we are using e5m2 or e4m3.
outcode is arbitrary, and is used for the switch to cast to 16-bit dtypes when needed.

Collaborator
Found it: they go down to dispatch_fp8_types. Please add a similar comment here about what 0/1 mean, or better, replace 0/1 with an enum or defines so they are easier to track.

Contributor Author
I have added a comment there to help keep things clear.
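
The reviewer's alternative, naming the format selectors instead of returning bare 0/1, could look roughly like the sketch below. It is written in Python for illustration only (the code under review is C++), `Fp8Format` is a hypothetical name, and the value 1 for e5m2 is an assumption inferred from the "0/1" remark; only the mapping of e4m3 to 0 appears in the snippet above.

```python
from enum import IntEnum


class Fp8Format(IntEnum):
    """Format selector handed to the scaled MFMA instruction.

    E4M3 = 0 matches the snippet above; E5M2 = 1 is an assumption.
    """
    E4M3 = 0
    E5M2 = 1


def fp8_code(dtype_name: str) -> Fp8Format:
    """Map a dtype name such as "e4m3" to its selector (hypothetical helper)."""
    try:
        return Fp8Format[dtype_name.upper()]
    except KeyError:
        raise ValueError(f"unsupported FP8 dtype: {dtype_name!r}") from None
```

With named members, a dispatch site can compare against `Fp8Format.E4M3` rather than a literal, which makes the hardware-defined values easier to trace than the arbitrary output codes.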

@alextmagro alextmagro requested a review from ipanfilo June 17, 2026 21:09

return ws


def _use_hipkittens() -> bool:

Collaborator
maybe add cache so the env is read once

Contributor Author
Done. Cached both jax and pytorch.

raise


def _use_hipkittens() -> bool:

Collaborator
Maybe decorate it with a cache so the env isn't re-read every time.

@alextmagro alextmagro requested a review from ipanfilo June 18, 2026 22:47
Labels

ci-level 3 CI test level 3

5 participants