diff --git a/csrc/config.hpp b/csrc/config.hpp
index 0e4f5b06..c67ba936 100644
--- a/csrc/config.hpp
+++ b/csrc/config.hpp
@@ -133,8 +133,9 @@ struct LowLatencyLayout {
         return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
     }
 
-    LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
-        const int num_scales = hidden / 128;
+    LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts,
+                     int quant_group_size = 128) {
+        const int num_scales = hidden / quant_group_size;
 
         // Dispatch and combine layout:
         // - 2 symmetric odd/even send buffer
@@ -143,7 +144,7 @@ struct LowLatencyLayout {
         // Message sizes
         // NOTES: you should add a control `int4` for combine messages if you want to do data transformation
-        // NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max
+        // NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-quant-group min/max
         EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
         size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
         size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16);
@@ -187,8 +188,9 @@ struct LowLatencyLayout {
     }
 };
 
-size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
-    auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
+size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts,
+                                      int quant_group_size = 128) {
+    auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size).total_bytes;
     return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
 }
diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 0c6108c9..a1c8f473 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -151,7 +151,7 @@ Buffer::Buffer(int rank,
                place, phi::distributed::CommType::ALLTOALL);
            calc_ctx = reinterpret_cast<phi::GPUContext*>(
                reinterpret_cast<paddle::distributed::ProcessGroup*>(pg)->GetDeviceContext(place, true));
-           return at::cuda::getStreamFromExternal(comm_ctx->GetStream(), device_id);
+           return at::cuda::CUDAStream(comm_ctx->GetStream());
        }()),
        shared_memory_allocator(use_fabric) {
     // Metadata memory
@@ -409,7 +409,7 @@ Buffer::get_dispatch_layout(
     // Allocate all tensors on comm stream if set
     // NOTES: do not allocate tensors upfront!
-    auto compute_stream = at::cuda::getStreamFromExternal(calc_ctx->stream(), device_id);
+    auto compute_stream = at::cuda::CUDAStream(calc_ctx->stream());
     if (allocate_on_comm_stream) {
         EP_HOST_ASSERT(previous_event.has_value() and async);
         deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
@@ -476,6 +476,8 @@ std::tuple<torch::Tensor,
            torch::Tensor,
            torch::Tensor,
            torch::Tensor,
+           std::optional<torch::Tensor>,
+           std::optional<torch::Tensor>,
            std::optional<EventHandle>>
 Buffer::intranode_dispatch(const torch::Tensor& x,
                            const std::optional<torch::Tensor>& x_scales,
@@ -493,7 +495,10 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
                            std::optional<EventHandle>& previous_event,
                            bool async,
                            bool allocate_on_comm_stream,
-                           bool skip_x_record_stream) {
+                           bool skip_x_record_stream,
+                           int quant_group_size,
+                           bool use_mask_prmt,
+                           int max_tokens_per_expert) {
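+    // Note on the new arguments (illustrative summary):
+    //  - quant_group_size: number of hidden channels sharing one FP8 scale (128 preserves the original behaviour)
+    //  - use_mask_prmt: masked-permute output layout, i.e. recv_x is [num_local_experts * max_tokens_per_expert, hidden]
+    //    and recv_x_scales is [num_local_experts, max_tokens_per_expert, kb_dim] uint8 (CUTLASS SfAtom)
+    //  - max_tokens_per_expert: per-local-expert token capacity used by the masked-permute layout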
     bool cached_mode = cached_rank_prefix_matrix.has_value();
 
     // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving.
@@ -538,6 +543,15 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
     auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
     auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;
 
+    // use_mask_prmt checks
+    if (use_mask_prmt) {
+        EP_HOST_ASSERT(quant_group_size == 32);
+        EP_HOST_ASSERT(x_scales.has_value());
+        EP_HOST_ASSERT(topk_idx.has_value());
+        EP_HOST_ASSERT(max_tokens_per_expert > 0);
+        EP_HOST_ASSERT(num_local_experts > 0);
+    }
+
     // Top-k checks
     int num_topk = 0;
     topk_idx_t* topk_idx_ptr = nullptr;
@@ -556,22 +570,23 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
     }
 
     // FP8 scales checks
-    float* x_scales_ptr = nullptr;
+    void* x_scales_ptr = nullptr;
     int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
     if (x_scales.has_value()) {
         EP_HOST_ASSERT(x.element_size() == 1);
-        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
+        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt or
+                       x_scales->scalar_type() == torch::kByte);
         EP_HOST_ASSERT(x_scales->dim() == 2);
         EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
         num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
-        x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
+        x_scales_ptr = x_scales->data_ptr();
         scale_token_stride = static_cast<int>(x_scales->stride(0));
         scale_hidden_stride = static_cast<int>(x_scales->stride(1));
     }
 
     // Allocate all tensors on comm stream if set
     // NOTES: do not allocate tensors upfront!
-    auto compute_stream = at::cuda::getStreamFromExternal(calc_ctx->stream(), device_id);
+    auto compute_stream = at::cuda::CUDAStream(calc_ctx->stream());
     if (allocate_on_comm_stream) {
         EP_HOST_ASSERT(previous_event.has_value() && async);
         deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
@@ -664,17 +679,24 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
     }
 
     // Allocate new tensors
-    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
+    // When use_mask_prmt: recv_x is [E*M, hidden], recv_x_scales is [E, M, kb_dim] uint8 (SfAtom layout)
+    auto recv_x = use_mask_prmt
+                      ? torch::empty({num_local_experts * max_tokens_per_expert, hidden}, x.options())
+                      : torch::empty({num_recv_tokens, hidden}, x.options());
     auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA));
     auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(), recv_x_scales = std::optional<torch::Tensor>();
     auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
     auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
 
+    // Mask PRMT additional tensors
+    auto permuted_indice_map_tensor = std::optional<torch::Tensor>();
+    auto token_nums_per_expert_tensor = std::optional<torch::Tensor>();
+
     // Assign pointers
     topk_idx_t* recv_topk_idx_ptr = nullptr;
     float* recv_topk_weights_ptr = nullptr;
-    float* recv_x_scales_ptr = nullptr;
+    void* recv_x_scales_ptr = nullptr;
     if (topk_idx.has_value()) {
         recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
         recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
@@ -682,12 +704,25 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
         recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
     }
     if (x_scales.has_value()) {
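+        // Worked example (illustrative): hidden = 4096 with quant_group_size = 32 gives x_scales 4096 / 32 = 128
+        // columns, so hidden_scale = 128 / 4 = 32 and kb_dim = 32 * 4 = 128; the masked-permute branch below then
+        // allocates recv_x_scales as [num_local_experts, max_tokens_per_expert, 128] uint8.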
-        recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options())
-                                             : torch::empty({num_recv_tokens, num_scales}, x_scales->options());
-        recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
+        if (use_mask_prmt) {
+            // SfAtom layout: output_scale [E, M, kb_dim] as uint8
+            int hidden_scale = static_cast<int>(x_scales->size(1)) / 4;
+            int kb_dim_val = hidden_scale * 4;
+            recv_x_scales = torch::empty({num_local_experts, max_tokens_per_expert, kb_dim_val},
+                                         dtype(torch::kByte).device(torch::kCUDA));
+        } else {
+            recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options())
+                                                 : torch::empty({num_recv_tokens, num_scales}, x_scales->options());
+        }
+        recv_x_scales_ptr = recv_x_scales->data_ptr();
+    }
+    if (use_mask_prmt) {
+        permuted_indice_map_tensor = torch::full({num_recv_tokens, num_topk}, -1, dtype(torch::kInt32).device(torch::kCUDA));
+        token_nums_per_expert_tensor = torch::zeros({num_local_experts}, dtype(torch::kInt32).device(torch::kCUDA));
     }
 
     // Dispatch
+    int scale_elem_size = (quant_group_size == 32) ? sizeof(uint8_t) : sizeof(float);
     EP_HOST_ASSERT(
         num_ranks * num_ranks * sizeof(int) +                                                               // Size prefix matrix
         num_channels * num_ranks * sizeof(int) +                                                            // Channel start offset
@@ -697,8 +732,15 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
         num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) +                   // Source index buffer
         num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(topk_idx_t) + // Top-k index buffer
         num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) +      // Top-k weight buffer
-        num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales      // FP8 scale buffer
+        num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * scale_elem_size * num_scales    // FP8 scale buffer
         <= num_nvl_bytes);
 
+    // Compute hidden_scale and kb_dim for the mask PRMT path
+    int hidden_scale = 0, kb_dim = 0;
+    if (use_mask_prmt) {
+        hidden_scale = static_cast<int>(x_scales->size(1)) / 4;
+        kb_dim = hidden_scale * 4;
+    }
+
     intranode::dispatch(recv_x.data_ptr(),
                         recv_x_scales_ptr,
                         recv_src_idx.data_ptr<int>(),
@@ -726,7 +768,15 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
                         comm_stream,
                         config.num_sms,
                         config.num_max_nvl_chunked_send_tokens,
-                        config.num_max_nvl_chunked_recv_tokens);
+                        config.num_max_nvl_chunked_recv_tokens,
+                        quant_group_size,
+                        use_mask_prmt,
+                        use_mask_prmt ? permuted_indice_map_tensor->data_ptr<int>() : nullptr,
+                        use_mask_prmt ? token_nums_per_expert_tensor->data_ptr<int>() : nullptr,
+                        max_tokens_per_expert,
+                        num_local_experts,
+                        hidden_scale,
+                        kb_dim);
 
     // Wait streams
     std::optional<EventHandle> event;
@@ -758,7 +808,9 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
                      cached_rank_prefix_matrix,
                      recv_topk_idx,
                      recv_topk_weights,
-                     recv_x_scales}) {
+                     recv_x_scales,
+                     permuted_indice_map_tensor,
+                     token_nums_per_expert_tensor}) {
         to.has_value() ? to->record_stream(comm_stream) : void();
         if (allocate_on_comm_stream)
            to.has_value() ? to->record_stream(compute_stream) : void();
@@ -783,6 +835,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
             recv_channel_prefix_matrix,
             recv_src_idx,
             send_head,
+            permuted_indice_map_tensor,
+            token_nums_per_expert_tensor,
             event};
 }
 
@@ -822,7 +876,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>> Buffer::intranode_combine(
     // Allocate all tensors on comm stream if set
     // NOTES: do not allocate tensors upfront!
-    auto compute_stream = at::cuda::getStreamFromExternal(calc_ctx->stream(), device_id);
+    auto compute_stream = at::cuda::CUDAStream(calc_ctx->stream());
     if (allocate_on_comm_stream) {
         EP_HOST_ASSERT(previous_event.has_value() && async);
         deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
@@ -1064,7 +1118,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
     // Allocate all tensors on comm stream if set
     // NOTES: do not allocate tensors upfront!
-    auto compute_stream = at::cuda::getStreamFromExternal(calc_ctx->stream(), device_id);
+    auto compute_stream = at::cuda::CUDAStream(calc_ctx->stream());
     if (allocate_on_comm_stream) {
         EP_HOST_ASSERT(previous_event.has_value() && async);
         deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
@@ -1382,7 +1436,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>> Buffer::internode_combine(
     // Allocate all tensors on comm stream if set
     // NOTES: do not allocate tensors upfront!
-    auto compute_stream = at::cuda::getStreamFromExternal(calc_ctx->stream(), device_id);
+    auto compute_stream = at::cuda::CUDAStream(calc_ctx->stream());
     if (allocate_on_comm_stream) {
         EP_HOST_ASSERT(previous_event.has_value() && async);
         deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
@@ -1567,14 +1621,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
                              bool round_scale,
                              bool use_ue8m0,
                              bool async,
-                             bool return_recv_hook) {
+                             bool return_recv_hook,
+                             int quant_group_size) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
 
     // Tensor checks
     // By default using `ptp128c` FP8 cast
     EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
-    EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0);
+    EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % quant_group_size == 0);
     EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
     EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
     EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType<topk_idx_t>::value);
@@ -1597,7 +1652,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
     auto num_local_experts = num_experts / num_ranks;
 
     // Buffer control
-    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size);
     EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
     auto buffer = layout.buffers[low_latency_buffer_idx];
     auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
@@ -1625,16 +1680,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
     if (use_fp8) {
         // TODO: support unaligned cases
-        EP_HOST_ASSERT(hidden % 512 == 0);
-        if (not use_ue8m0) {
-            packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
+        EP_HOST_ASSERT(hidden % quant_group_size == 0);
+        const auto num_scales = hidden / quant_group_size;
+        const auto mn_dim = num_ranks * num_max_dispatch_tokens_per_rank;
+
+        if (quant_group_size != 128 and use_ue8m0) {
+            // CUTLASS SfAtom layout: pad token dim to 128-tile boundary, store as flat uint8
+            EP_HOST_ASSERT(round_scale);
+            EP_HOST_ASSERT(num_scales % 4 == 0 and "CUTLASS SfAtom requires num_scales to be a multiple of 4");
+            const auto padded_mn = ((mn_dim + 127) / 128) * 128;
+            packed_recv_x_scales = torch::empty({num_local_experts, padded_mn, num_scales},
+                                                torch::dtype(torch::kByte).device(torch::kCUDA));
+            // No transpose - kernel writes directly in CUTLASS SfAtom order
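+            // Sizing example (illustrative): num_ranks = 8 with num_max_dispatch_tokens_per_rank = 128 gives
+            // mn_dim = 1024, already a multiple of 128, so padded_mn = 1024; hidden = 4096 with quant_group_size = 32
+            // gives num_scales = 128, i.e. one 1024 x 128 uint8 scale block per local expert.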
+        } else if (not use_ue8m0) {
+            packed_recv_x_scales = torch::empty({num_local_experts, num_scales, mn_dim},
                                                 torch::dtype(torch::kFloat32).device(torch::kCUDA));
+            packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         } else {
             EP_HOST_ASSERT(round_scale);
-            packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
+            packed_recv_x_scales = torch::empty({num_local_experts, num_scales / 4, mn_dim},
                                                 torch::dtype(torch::kInt).device(torch::kCUDA));
+            packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         }
-        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
     }
@@ -1667,6 +1734,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
                              use_fp8,
                              round_scale,
                              use_ue8m0,
+                             quant_group_size,
                              workspace,
                              num_device_sms,
                              launch_stream.stream(),
@@ -1900,7 +1968,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("get_comm_stream", [](deep_ep::Buffer &self) {
             int device_id = self.get_local_device_id();
-            cudaStream_t comm_stream = at::cuda::CUDAStream(self.get_comm_stream()).stream();
+            cudaStream_t comm_stream = self.get_comm_stream().stream();
             auto s = phi::Stream(reinterpret_cast<phi::StreamId>(comm_stream));
 #if defined(PADDLE_WITH_CUDA)
             return phi::CUDAStream(phi::GPUPlace(device_id), s);
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index 426ce042..5fd88c52 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -150,7 +150,7 @@ struct Buffer {
 
     torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
 
-    torch::Stream get_comm_stream() const {
+    at::cuda::CUDAStream get_comm_stream() const {
         return comm_stream;
     }
 
@@ -177,6 +177,8 @@ struct Buffer {
                torch::Tensor,
                torch::Tensor,
                torch::Tensor,
+               std::optional<torch::Tensor>,
+               std::optional<torch::Tensor>,
                std::optional<EventHandle>>
     intranode_dispatch(const torch::Tensor& x,
                        const std::optional<torch::Tensor>& x_scales,
@@ -194,7 +196,10 @@ struct Buffer {
                        std::optional<EventHandle>& previous_event,
                        bool async,
                        bool allocate_on_comm_stream,
-                       bool skip_x_record_stream = false);
+                       bool skip_x_record_stream = false,
+                       int quant_group_size = 128,
+                       bool use_mask_prmt = false,
+                       int max_tokens_per_expert = 0);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>> intranode_combine(
         const torch::Tensor& x,
@@ -283,7 +288,8 @@ struct Buffer {
                         bool round_scale,
                         bool use_ue8m0,
                         bool async,
-                        bool return_recv_hook);
+                        bool return_recv_hook,
+                        int quant_group_size = 128);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> low_latency_combine(
         const torch::Tensor& x,
diff --git a/csrc/event.hpp b/csrc/event.hpp
index c4138e38..4cd7994b 100644
--- a/csrc/event.hpp
+++ b/csrc/event.hpp
@@ -12,32 +12,35 @@ struct EventHandle {
 
     EventHandle() {
         event = std::make_shared<torch::Event>(torch::kCUDA);
-        event->record(at::cuda::getCurrentCUDAStream());
+        event->record(at::cuda::getCurrentCUDAStream().stream());
     }
 
     explicit EventHandle(const at::cuda::CUDAStream& stream) {
         event = std::make_shared<torch::Event>(torch::kCUDA);
-        event->record(stream);
+        event->record(stream.stream());
     }
 
     EventHandle(const EventHandle& other) = default;
 
-    void current_stream_wait() const { at::cuda::getCurrentCUDAStream().unwrap().wait(*event); }
+    void current_stream_wait() const {
+        C10_CUDA_CHECK(cudaStreamWaitEvent(at::cuda::getCurrentCUDAStream().stream(), event->cuda_event()));
+    }
 };
 
 torch::Event create_event(const at::cuda::CUDAStream& s) {
     auto event = torch::Event(torch::kCUDA);
-    event.record(s);
+    event.record(s.stream());
     return event;
 }
 
 void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
     EP_HOST_ASSERT(s_0.id() != s_1.id());
-    s_0.unwrap().wait(create_event(s_1));
+    auto ev = create_event(s_1);
+    C10_CUDA_CHECK(cudaStreamWaitEvent(s_0.stream(), ev.cuda_event()));
 }
 
 void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
-    s.unwrap().wait(*event.event);
+    C10_CUDA_CHECK(cudaStreamWaitEvent(s.stream(), event.event->cuda_event()));
 }
 
 } // namespace deep_ep
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index 9bbe096a..439dfe9f 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -76,14 +76,14 @@ void cached_notify_dispatch(const int* rank_prefix_matrix,
                             cudaStream_t stream);
 
 void dispatch(void* recv_x,
-              float* recv_x_scales,
+              void* recv_x_scales,
               int* recv_src_idx,
               topk_idx_t* recv_topk_idx,
               float* recv_topk_weights,
               int* recv_channel_offset,
               int* send_head,
               const void* x,
-              const float* x_scales,
+              const void* x_scales,
               const topk_idx_t* topk_idx,
               const float* topk_weights,
               const bool* is_token_in_rank,
@@ -102,7 +102,15 @@ void dispatch(void* recv_x,
               cudaStream_t stream,
               int num_sms,
               int num_max_send_tokens,
-              int num_recv_buffer_tokens);
+              int num_recv_buffer_tokens,
+              int quant_group_size,
+              bool use_mask_prmt = false,
+              int32_t* permuted_indice_map = nullptr,
+              int32_t* token_nums_per_expert = nullptr,
+              int max_tokens_per_expert = 0,
+              int num_local_experts = 0,
+              int hidden_scale = 0,
+              int kb_dim = 0);
 
 void cached_notify_combine(void** buffer_ptrs,
                            int* send_head,
@@ -307,6 +315,7 @@ void dispatch(void* packed_recv_x,
               bool use_fp8,
               bool round_scale,
               bool use_ue8m0,
+              int quant_group_size,
               void* workspace,
               int num_device_sms,
               cudaStream_t stream,
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index e9fd473b..2f366a86 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -126,7 +126,7 @@ void clean_low_latency_buffer(int* clean_0,
                               sync_buffer_ptr);
 }
 
-template <bool kUseFP8, bool kUseUE8M0, int kHidden>
+template <bool kUseFP8, bool kUseUE8M0, int kHidden, int kQuantGroupSize>
 __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
                                                     void* packed_recv_x_scales,
                                                     int* packed_recv_src_info,
@@ -169,16 +169,20 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
     using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
     EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
 
+    // CUTLASS SfAtom layout for small group sizes with UE8M0
+    constexpr bool kUseCutlassSfLayout = kUseFP8 && kUseUE8M0 && (kQuantGroupSize != 128);
+
     // FP8 staffs
-    constexpr int kNumPerChannels = 128;
-    const int num_scales = kHidden / kNumPerChannels;
+    const int num_scales = kHidden / kQuantGroupSize;
     const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
     const size_t hidden_int4 = hidden_bytes / sizeof(int4);
 
     // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
     // NOTES: currently we have 3 reserved int fields for future use
+    // For UE8M0 with non-128 group size, pack scales as uint8 to reduce RDMA transfer
     using vec_t = std::conditional_t<kUseFP8, int2, int4>;
-    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
+    constexpr size_t kScaleElemBytes = kUseCutlassSfLayout ? sizeof(uint8_t) : sizeof(float);
+    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * kScaleElemBytes) : (kHidden * sizeof(nv_bfloat16)));
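+    // Message-size example (illustrative): kHidden = 4096 with 32-wide groups gives num_scales = 128, so an FP8
+    // message is 16 + 4096 + 128 = 4240 bytes with packed uint8 scales vs. 16 + 4096 + 128 * 4 = 4624 bytes with
+    // float scales; both are multiples of sizeof(int4).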
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
@@ -196,7 +200,8 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
     if (warp_id < num_warps - 1) {
         constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
         EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden");
-        EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
+        EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kQuantGroupSize == 0, "Invalid vectorization");
+        constexpr int kNumLanesPerGroup = kQuantGroupSize / kNumElemsPerRead;
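+        // Lane mapping example (illustrative): each lane loads kNumElemsPerRead = 8 bf16 values, so kQuantGroupSize = 32
+        // spans kNumLanesPerGroup = 4 lanes (lanes 0, 4, 8, ... emit a scale), while the original 128-wide grouping
+        // spans 16 lanes with lanes 0 and 16 emitting scales.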
         const auto num_threads = (num_warps - 1) * 32;
         const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
@@ -204,7 +209,7 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
             const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
             const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
             const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
-            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
+            const auto rdma_x_scales_area = reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes;
 
             // Overlap top-k index read and source token index writes
             auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
@@ -229,11 +234,17 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
             }
 
             // Reduce amax and scale
-            EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
-            amax = warp_reduce_max<16>(amax);
+            amax = warp_reduce_max<kNumLanesPerGroup>(amax);
             calculate_fp8_scales(amax, scale, scale_inv, round_scale);
-            if (lane_id == 0 or lane_id == 16)
-                rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
+            if (lane_id % kNumLanesPerGroup == 0) {
+                const auto scale_idx = i * kNumElemsPerRead / kQuantGroupSize;
+                if constexpr (kUseCutlassSfLayout) {
+                    // Pack as uint8 in RDMA message to reduce transfer size
+                    rdma_x_scales_area[scale_idx] = extract_required_scale_format<kUseUE8M0>(scale_inv);
+                } else {
+                    reinterpret_cast<float*>(rdma_x_scales_area)[scale_idx] = scale_inv;
+                }
+            }
 
             // Cast into send buffer
             vec_t int2_value;
@@ -371,8 +382,13 @@ LOW_LATENCY_DISPATCH_RECV:
     const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
     const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
     const auto num_aligned_scales = align_up(num_scales, sizeof(float) / sizeof(scale_t));
+    // For CUTLASS SfAtom layout (kQuantGroupSize != 128 with UE8M0), pad token dim to 128
+    constexpr bool kUseCutlassSfLayout = (kQuantGroupSize != 128) && kUseUE8M0;
+    const auto scale_mn_dim = kUseCutlassSfLayout ?
+                              ((num_ranks * num_max_dispatch_tokens_per_rank + 127) & ~127) :
+                              num_ranks * num_max_dispatch_tokens_per_rank;
     const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
-                               local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
+                               local_expert_idx * scale_mn_dim * num_aligned_scales;
 
     // Shared between sub-warps in warp groups
     __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
@@ -422,7 +438,6 @@ LOW_LATENCY_DISPATCH_RECV:
         recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
 
         // Copy tokens
-        EP_DEVICE_ASSERT(num_scales <= 64);
         for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
             // Copy source info
             const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
@@ -438,24 +453,41 @@ LOW_LATENCY_DISPATCH_RECV:
             // Copy scales
             if constexpr (kUseFP8) {
-                // Equivalent CuTe layout:
-                //   (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
-                const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
-                const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
                 const auto token_idx = recv_token_begin_idx + i;
-                const auto token_stride = num_elems_per_pack;
-                const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
-                if (lane_id < num_scales) {
-                    const auto pack_idx = lane_id / num_elems_per_pack;
-                    const auto elem_idx = lane_id % num_elems_per_pack;
-                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
-                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
-                }
-                if (lane_id + 32 < num_scales) {
-                    const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
-                    const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
-                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
-                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
+
+                if constexpr (kUseCutlassSfLayout) {
+                    // CUTLASS SfAtom interleaved layout with packed uint8 scales in message
+                    // Atom shape: ((32, 4), (SFVecSize, 4)), stride: ((16, 4), (0, 1))
+                    // Physical atom size: 128 MN x 4 K = 512 bytes
+                    const auto src_scales_u8 = reinterpret_cast<const uint8_t*>(src_data) + hidden_bytes;
+                    const auto kb_dim = num_aligned_scales;
+                    const auto num_k_tiles = kb_dim / 4;
+                    const int n_tile = token_idx / 128;
+                    const int n_local = token_idx % 128;
+                    const int mn_i = n_local % 32;
+                    const int mn_j = n_local / 32;
+                    const int base_offset = n_tile * num_k_tiles * 512 + mn_i * 16 + mn_j * 4;
+                    // Vectorized: each lane handles one k_tile (4 uint8 scales)
+                    // Read 4 consecutive uint8 as uint32 from packed message, write 4 bytes at once
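+                    // Index example (illustrative): with kb_dim = 128, num_k_tiles = 32; token_idx = 130 gives
+                    // n_tile = 1, n_local = 2, mn_i = 2, mn_j = 0, so base_offset = 1 * 32 * 512 + 2 * 16 + 0 = 16416,
+                    // and lane k stores its 4 packed scale bytes at recv_x_scales + 16416 + k * 512.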
+                    #pragma unroll 1
+                    for (int k = lane_id; k < num_k_tiles; k += 32) {
+                        uint32_t packed = ld_nc_global(reinterpret_cast<const uint32_t*>(src_scales_u8 + k * 4));
+                        *reinterpret_cast<uint32_t*>(recv_x_scales + base_offset + k * 512) = packed;
+                    }
+                } else {
+                    // Original CuTe interleaved layout with float scales in message
+                    //   (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
+                    const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
+                    const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
+                    const auto token_stride = num_elems_per_pack;
+                    const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
+                    #pragma unroll 1
+                    for (int s = lane_id; s < num_scales; s += 32) {
+                        const auto pack_idx = s / num_elems_per_pack;
+                        const auto elem_idx = s % num_elems_per_pack;
+                        auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + s));
+                        recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
+                    }
                 }
             }
         }
@@ -487,6 +519,7 @@ void dispatch(void* packed_recv_x,
               bool use_fp8,
               bool round_scale,
               bool use_ue8m0,
+              int quant_group_size,
               void* workspace,
               int num_device_sms,
               cudaStream_t stream,
@@ -509,44 +542,53 @@ void dispatch(void* packed_recv_x,
     // FP8 checks
     if (use_ue8m0)
         EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
-
-#define DISPATCH_LAUNCH_CASE(hidden) \
-    {                                \
-        auto dispatch_func = dispatch