Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 54 additions & 31 deletions csrc/kernels/internode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "utils.cuh"
#include "shmem_wrapper.cuh"

using namespace rocshmem;
namespace deep_ep {

namespace internode {
Expand Down Expand Up @@ -554,7 +555,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
kNVLReceivers
};

#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX) && !defined(ROCM_EXPLICIT_CTX)
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
Expand Down Expand Up @@ -674,22 +675,28 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
} else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
}
#if defined(ROCM_DISABLE_CTX)
shmemx_int_put_nbi_warp(
#else
#if defined(ROCM_EXPLICIT_CTX)
shmem_ctx_int_put_nbi_warp(rocshmem_ctx_array[sm_id],
#elif !defined(ROCM_DISABLE_CTX)
shmem_ctx_int_put_nbi_warp(ctx,
#else
shmemx_int_put_nbi_warp(
#endif
rdma_channel_meta.recv_buffer(rdma_rank),
rdma_channel_meta.send_buffer(dst_rdma_rank),
NUM_MAX_NVL_PEERS * 2 + 2,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank,
nvl_rank));
}
#if defined(ROCM_DISABLE_CTX)
shmem_fence();
#else
if (warp_id < kNumRDMARanks) {
#if defined(ROCM_EXPLICIT_CTX)
shmem_ctx_quiet(rocshmem_ctx_array[sm_id]);
#elif !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx);
#else
shmem_fence();
#endif
}
sync_rdma_sender_smem();

// Iterate over tokens and copy into buffer
Expand Down Expand Up @@ -832,10 +839,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
#if defined(ROCM_DISABLE_CTX)
shmemx_int8_put_nbi_warp(
#else
#if defined(ROCM_EXPLICIT_CTX)
shmem_ctx_schar_put_nbi_warp(rocshmem_ctx_array[sm_id],
#elif !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
#endif
rdma_channel_data.recv_buffer(rdma_rank) +
dst_slot_idx * num_bytes_per_rdma_token,
Expand All @@ -854,10 +863,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
#if defined(ROCM_DISABLE_CTX)
shmem_signal_op_add(
#else
#if defined(ROCM_EXPLICIT_CTX)
shmem_ctx_ulong_atomic_add(rocshmem_ctx_array[sm_id],
#elif !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_tail.buffer(rdma_rank),
num_tokens_to_issue,
Expand Down Expand Up @@ -1083,10 +1094,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (min_head != std::numeric_limits<int>::max() and
min_head >= last_head + num_max_rdma_chunked_send_tokens and
lane_id < kNumRDMARanks) {
#if defined(ROCM_DISABLE_CTX)
shmem_signal_op_add(
#else
#if defined(ROCM_EXPLICIT_CTX)
shmem_ctx_ulong_atomic_add(rocshmem_ctx_array[sm_id],
#elif !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
Expand Down Expand Up @@ -1218,7 +1231,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
recv_topk_idx[i] = -1;
}

#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX) && !defined(ROCM_EXPLICIT_CTX)
shmem_wg_ctx_destroy(&ctx);
#endif
}
Expand Down Expand Up @@ -1636,7 +1649,7 @@ combine(int4* combined_x, float* combined_topk_weights,
init_workgroup_warp_barrier(&combine_large_warp_barriers[bi]);
}
__syncthreads();
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX) && !defined(ROCM_EXPLICIT_CTX)
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
Expand Down Expand Up @@ -1896,18 +1909,22 @@ combine(int4* combined_x, float* combined_topk_weights,
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
const auto num_bytes_per_msg = num_chunked_tokens * num_bytes_per_token;
#if defined(ROCM_DISABLE_CTX)
shmemx_int8_put_nbi_warp(
#else
#if defined(ROCM_EXPLICIT_CTX)
shmem_ctx_schar_put_nbi_warp(rocshmem_ctx_array[sm_id],
#elif !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
#endif
rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_token,
num_bytes_per_msg, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#if defined(ROCM_DISABLE_CTX)
shmem_fence();
#else
#if defined(ROCM_EXPLICIT_CTX)
shmem_ctx_quiet(rocshmem_ctx_array[sm_id]);
#elif !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx);
#else
shmem_fence();
#endif
} else {
memory_fence();
Expand All @@ -1916,10 +1933,12 @@ combine(int4* combined_x, float* combined_topk_weights,
// Write new RDMA tail
syncwarp();
if (lane_id == 0) {
#if defined(ROCM_DISABLE_CTX)
shmem_signal_op_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#else
#if defined(ROCM_EXPLICIT_CTX)
shmem_ctx_ulong_atomic_add(rocshmem_ctx_array[sm_id], rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#elif !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#else
shmem_signal_op_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#endif
}
}
Expand All @@ -1944,10 +1963,14 @@ combine(int4* combined_x, float* combined_topk_weights,
if (not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
#if defined(ROCM_DISABLE_CTX)
shmem_signal_op_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#if defined(ROCM_EXPLICIT_CTX)
shmem_ctx_ulong_atomic_add(rocshmem_ctx_array[sm_id],
rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#elif !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#else
shmem_ctx_ulong_atomic_add(ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
shmem_signal_op_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#endif
last_rdma_head = min_head;
}
Expand Down Expand Up @@ -2003,7 +2026,7 @@ combine(int4* combined_x, float* combined_topk_weights,
if (lane_id == 0) rdma_receiver_retired[warp_id] = true;
}
}
#if !defined(ROCM_DISABLE_CTX)
#if defined(USE_ROCM) && !defined(ROCM_DISABLE_CTX) && !defined(ROCM_EXPLICIT_CTX)
shmem_wg_ctx_destroy(&ctx);
#endif
}
Expand Down