HipKittens MXFP8 GEMM Support #566
| [](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) { | ||
| return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param)); | ||
| return MKN(std::get<0>(info.param)) + "x" + | ||
| std::to_string(std::get<1>(info.param)) + "x" + |
What is the point of these? They are only ever set to false.
| return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device) | ||
| key = (device, ub, grouped_gemm) | ||
| ws = _workspace_cache.get(key) |
Why don't we rely on torch memory caching?
I have made this change. I will need to do an E2E run to make sure that performance isn't affected, but it should be OK given my understanding of torch.empty().
| if (use_hipkittens) { | ||
| auto param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); | ||
| hipStream_t s = use_service_stream ? ss_ctl.stream : stream; |
Same as with is_mxfp8: there is no point in defining it for only one branch.
| @@ -743,12 +786,15 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16) | |||
| INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite, | |||
If you end up with a separate prefix for MXFP8, it has to be used for this suite too, for consistency.
| @@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = { | |||
| std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = { | |||
test_case_sizes_mxfp8 is only used for DqGEMMTest. Is it intentional to add sizes there?
Yes, I wanted to add the minimum possible size that HipKittens supports, which is 256x256x256.
| is_mxfp8 = isinstance(A, MXFP8TensorStorage) or isinstance(B, MXFP8TensorStorage) | ||
| if is_mxfp8 and _use_hipkittens(): | ||
| a_size = A.size() if hasattr(A, "size") and callable(A.size) else A.shape |
MXFP8TensorStorage has a callable size(). What other object could end up here that would require this condition?
I was considering a scenario where A or B was not MXFP8, but we always have both of them as MXFP8, so I think it is OK to simplify the logic.
| static int fp8_code(int dt) { | ||
| switch (dt) { | ||
| case KITTENS_FP8E4M3: return 0; |
Are these codes and the ones below just arbitrary indexes, or are they special values?
The fp8_code values are the ones used within v_mfma_scale_f32_16x16x128_f8f6f4 to designate whether we are using e5m2 or e4m3.
outcode is arbitrary, and is used by the switch to cast to 16-bit dtypes when needed.
I found it: they go down to dispatch_fp8_types. Please add a similar comment here about the meaning of 0/1 or, better, replace 0/1 with an enum or defines so they are easier to track.
I have added a comment there to help keep things clear.
| return ws | ||
| def _use_hipkittens() -> bool: |
Maybe add a cache so the env is read only once.
Done. Cached in both JAX and PyTorch.
| raise | ||
| def _use_hipkittens() -> bool: |
Maybe decorate it with a cache so the env is not re-read every time.
Adds an MXFP8 GEMM built on HipKittens that outperforms hipBLASLt and offers additional epilogues such as BIAS and GELU AUX.
Requires a workspace whose size scales with the model. It is often larger than the hipBLASLt workspace, but comes with significant performance improvements. Only builds for gfx950, and requires M and N to be multiples of 256.
Adds the HipKittens header library as a submodule.
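The shape requirement in the description can be illustrated with a small guard. This is a sketch under stated assumptions: the function name `hipkittens_shape_ok` is hypothetical, and reading the constraint on M and N as divisibility by 256 is an interpretation of the description. Any constraint on K is not stated in the description and is left out.

```python
TILE = 256  # tile size taken from the PR description


def hipkittens_shape_ok(m: int, n: int, tile: int = TILE) -> bool:
    """Return True when M and N are positive multiples of the tile size.

    Assumption: the description's constraint on M and N means divisibility
    by 256. K is not checked because the description does not constrain it.
    """
    return m >= tile and n >= tile and m % tile == 0 and n % tile == 0
```

A caller could use such a check to fall back to the hipBLASLt path for shapes that do not qualify, for example `hipkittens_shape_ok(384, 256)` evaluates to `False`.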