Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions csrc/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ struct LowLatencyLayout {
return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
}

LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
const int num_scales = hidden / 128;
LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts,
int quant_group_size = 128) {
const int num_scales = hidden / quant_group_size;

// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
Expand All @@ -143,7 +144,7 @@ struct LowLatencyLayout {

// Message sizes
// NOTES: you should add a control `int4` for combine messages if you want to do data transformation
// NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max
// NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-`quant_group_size`-channel min/max (one pair per quantization group)
EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16);
Expand Down Expand Up @@ -187,8 +188,9 @@ struct LowLatencyLayout {
}
};

// Returns a size hint, in bytes, for the low-latency RDMA buffer: builds a `LowLatencyLayout` over a null
// buffer pointer, reads its `total_bytes`, and rounds the result to a multiple of `NUM_BUFFER_ALIGNMENT_BYTES`.
// NOTE(review): this span comes from a rendered diff, so the pre-change signature/initializer and their
// replacements (which add `quant_group_size`, defaulting to 128) appear back to back.
size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts,
int quant_group_size = 128) {
auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size).total_bytes;
// A full alignment unit is added before the integer division, so the result is always strictly greater
// than `num_bytes`: an already-aligned size still grows by one `NUM_BUFFER_ALIGNMENT_BYTES`.
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
}

Expand Down
122 changes: 95 additions & 27 deletions csrc/deep_ep.cpp

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions csrc/deep_ep.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ struct Buffer {

torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;

// Accessor that returns the `comm_stream` member by value.
// NOTE(review): rendered diff — the `torch::Stream` signature is presumably the pre-change line and the
// `at::cuda::CUDAStream` signature its replacement; only the return type differs.
torch::Stream get_comm_stream() const {
at::cuda::CUDAStream get_comm_stream() const {
return comm_stream;
}

Expand All @@ -177,6 +177,8 @@ struct Buffer {
torch::Tensor,
torch::Tensor,
torch::Tensor,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor& x,
const std::optional<torch::Tensor>& x_scales,
Expand All @@ -194,7 +196,10 @@ struct Buffer {
std::optional<EventHandle>& previous_event,
bool async,
bool allocate_on_comm_stream,
bool skip_x_record_stream = false);
bool skip_x_record_stream = false,
int quant_group_size = 128,
bool use_mask_prmt = false,
int max_tokens_per_expert = 0);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> intranode_combine(
const torch::Tensor& x,
Expand Down Expand Up @@ -283,7 +288,8 @@ struct Buffer {
bool round_scale,
bool use_ue8m0,
bool async,
bool return_recv_hook);
bool return_recv_hook,
int quant_group_size = 128);

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> low_latency_combine(
const torch::Tensor& x,
Expand Down
15 changes: 9 additions & 6 deletions csrc/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,35 @@ struct EventHandle {

// Default constructor: allocates a CUDA `torch::Event` and records it on the current CUDA stream.
// NOTE(review): rendered diff — the two `record` lines are presumably the before/after of a single change;
// the later one passes the result of `.stream()` rather than the `CUDAStream` object itself.
EventHandle() {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(at::cuda::getCurrentCUDAStream());
event->record(at::cuda::getCurrentCUDAStream().stream());
}

// Allocates a CUDA `torch::Event` and records it on the given `stream` instead of the current one.
// `explicit` prevents a stream from being implicitly converted into an event handle.
// NOTE(review): rendered diff — the two `record` lines are presumably the before/after of a single change;
// the later one passes `stream.stream()` rather than the `CUDAStream` object itself.
explicit EventHandle(const at::cuda::CUDAStream& stream) {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(stream);
event->record(stream.stream());
}

// Defaulted member-wise copy. NOTE(review): `event` is assigned from `std::make_shared` in the other
// constructors, so copies presumably share one underlying event — confirm against the member declaration.
EventHandle(const EventHandle& other) = default;

// Makes the current CUDA stream wait on this handle's event.
// NOTE(review): rendered diff — the one-line definition is presumably the pre-change form
// (`unwrap().wait(*event)`); the multi-line definition replaces it with a direct `cudaStreamWaitEvent`
// call whose return code is checked by `C10_CUDA_CHECK`.
void current_stream_wait() const { at::cuda::getCurrentCUDAStream().unwrap().wait(*event); }
void current_stream_wait() const {
C10_CUDA_CHECK(cudaStreamWaitEvent(at::cuda::getCurrentCUDAStream().stream(), event->cuda_event()));
}
};

// Creates a CUDA `torch::Event`, records it on stream `s`, and returns it by value.
// NOTE(review): rendered diff — the two `record` lines are presumably the before/after of a single change
// (`s` versus `s.stream()`).
torch::Event create_event(const at::cuda::CUDAStream& s) {
auto event = torch::Event(torch::kCUDA);
event.record(s);
event.record(s.stream());
return event;
}

// Makes stream `s_0` wait on an event freshly recorded on stream `s_1` (via `create_event`).
// The two streams must be different; this is asserted by comparing their ids.
// NOTE(review): rendered diff — the `unwrap().wait(...)` line is presumably the pre-change form; the
// following two lines replace it with an explicit event plus a checked `cudaStreamWaitEvent` call.
void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
EP_HOST_ASSERT(s_0.id() != s_1.id());
s_0.unwrap().wait(create_event(s_1));
auto ev = create_event(s_1);
C10_CUDA_CHECK(cudaStreamWaitEvent(s_0.stream(), ev.cuda_event()));
}

// Overload: makes stream `s` wait on the event held by an existing `EventHandle`.
// NOTE(review): rendered diff — the `unwrap().wait(...)` line is presumably the pre-change form; the
// `cudaStreamWaitEvent` line (checked by `C10_CUDA_CHECK`) is its replacement.
void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
s.unwrap().wait(*event.event);
C10_CUDA_CHECK(cudaStreamWaitEvent(s.stream(), event.event->cuda_event()));
}

} // namespace deep_ep
15 changes: 12 additions & 3 deletions csrc/kernels/api.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ void cached_notify_dispatch(const int* rank_prefix_matrix,
cudaStream_t stream);

void dispatch(void* recv_x,
float* recv_x_scales,
void* recv_x_scales,
int* recv_src_idx,
topk_idx_t* recv_topk_idx,
float* recv_topk_weights,
int* recv_channel_offset,
int* send_head,
const void* x,
const float* x_scales,
const void* x_scales,
const topk_idx_t* topk_idx,
const float* topk_weights,
const bool* is_token_in_rank,
Expand All @@ -102,7 +102,15 @@ void dispatch(void* recv_x,
cudaStream_t stream,
int num_sms,
int num_max_send_tokens,
int num_recv_buffer_tokens);
int num_recv_buffer_tokens,
int quant_group_size,
bool use_mask_prmt = false,
int32_t* permuted_indice_map = nullptr,
int32_t* token_nums_per_expert = nullptr,
int max_tokens_per_expert = 0,
int num_local_experts = 0,
int hidden_scale = 0,
int kb_dim = 0);

void cached_notify_combine(void** buffer_ptrs,
int* send_head,
Expand Down Expand Up @@ -307,6 +315,7 @@ void dispatch(void* packed_recv_x,
bool use_fp8,
bool round_scale,
bool use_ue8m0,
int quant_group_size,
void* workspace,
int num_device_sms,
cudaStream_t stream,
Expand Down
Loading