diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index b26405e..05b3edf 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -5,6 +5,7 @@ #include "utils.cuh" #include "shmem_wrapper.cuh" +using namespace rocshmem; namespace deep_ep { namespace internode { @@ -554,7 +555,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv kNVLReceivers }; -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(ROCM_EXPLICIT_CTX) __shared__ shmem_ctx_t ctx; shmem_wg_ctx_create(&ctx); #endif @@ -674,10 +675,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; } -#if defined(ROCM_DISABLE_CTX) - shmemx_int_put_nbi_warp( -#else +#if defined(ROCM_EXPLICIT_CTX) + shmem_ctx_int_put_nbi_warp(rocshmem_ctx_array[sm_id], +#elif !defined(ROCM_DISABLE_CTX) shmem_ctx_int_put_nbi_warp(ctx, +#else + shmemx_int_put_nbi_warp( #endif rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), @@ -685,11 +688,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); } -#if defined(ROCM_DISABLE_CTX) - shmem_fence(); -#else + if (warp_id < kNumRDMARanks) { +#if defined(ROCM_EXPLICIT_CTX) + shmem_ctx_quiet(rocshmem_ctx_array[sm_id]); +#elif !defined(ROCM_DISABLE_CTX) shmem_ctx_quiet(ctx); +#else + shmem_fence(); #endif + } sync_rdma_sender_smem(); // Iterate over tokens and copy into buffer @@ -832,10 +839,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (dst_rdma_rank != rdma_rank) { auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); 
-#if defined(ROCM_DISABLE_CTX) - shmemx_int8_put_nbi_warp( -#else +#if defined(ROCM_EXPLICIT_CTX) + shmem_ctx_schar_put_nbi_warp(rocshmem_ctx_array[sm_id], +#elif !defined(ROCM_DISABLE_CTX) shmem_ctx_schar_put_nbi_warp(ctx, +#else + shmemx_int8_put_nbi_warp( #endif rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token, @@ -854,10 +863,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (lane_id == dst_rdma_rank) { last_issued_tail += num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue; -#if defined(ROCM_DISABLE_CTX) - shmem_signal_op_add( -#else +#if defined(ROCM_EXPLICIT_CTX) + shmem_ctx_ulong_atomic_add(rocshmem_ctx_array[sm_id], +#elif !defined(ROCM_DISABLE_CTX) shmem_ctx_ulong_atomic_add(ctx, +#else + shmem_signal_op_add( #endif rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, @@ -1083,10 +1094,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { -#if defined(ROCM_DISABLE_CTX) - shmem_signal_op_add( -#else +#if defined(ROCM_EXPLICIT_CTX) + shmem_ctx_ulong_atomic_add(rocshmem_ctx_array[sm_id], +#elif !defined(ROCM_DISABLE_CTX) shmem_ctx_ulong_atomic_add(ctx, +#else + shmem_signal_op_add( #endif rdma_channel_head.buffer(rdma_rank), min_head - last_head, translate_dst_rdma_rank(lane_id, nvl_rank)); @@ -1218,7 +1231,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv recv_topk_idx[i] = -1; } -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(ROCM_EXPLICIT_CTX) shmem_wg_ctx_destroy(&ctx); #endif } @@ -1636,7 +1649,7 @@ combine(int4* combined_x, float* combined_topk_weights, init_workgroup_warp_barrier(&combine_large_warp_barriers[bi]); } __syncthreads(); -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(ROCM_EXPLICIT_CTX) 
__shared__ shmem_ctx_t ctx; shmem_wg_ctx_create(&ctx); #endif @@ -1896,18 +1909,22 @@ combine(int4* combined_x, float* combined_topk_weights, if (dst_rdma_rank != rdma_rank) { auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; const auto num_bytes_per_msg = num_chunked_tokens * num_bytes_per_token; -#if defined(ROCM_DISABLE_CTX) - shmemx_int8_put_nbi_warp( -#else +#if defined(ROCM_EXPLICIT_CTX) + shmem_ctx_schar_put_nbi_warp(rocshmem_ctx_array[sm_id], +#elif !defined(ROCM_DISABLE_CTX) shmem_ctx_schar_put_nbi_warp(ctx, +#else + shmemx_int8_put_nbi_warp( #endif rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_token, rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_token, num_bytes_per_msg, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); -#if defined(ROCM_DISABLE_CTX) - shmem_fence(); -#else +#if defined(ROCM_EXPLICIT_CTX) + shmem_ctx_quiet(rocshmem_ctx_array[sm_id]); +#elif !defined(ROCM_DISABLE_CTX) shmem_ctx_quiet(ctx); +#else + shmem_fence(); #endif } else { memory_fence(); @@ -1916,10 +1933,12 @@ combine(int4* combined_x, float* combined_topk_weights, // Write new RDMA tail syncwarp(); if (lane_id == 0) { -#if defined(ROCM_DISABLE_CTX) - shmem_signal_op_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); -#else +#if defined(ROCM_EXPLICIT_CTX) + shmem_ctx_ulong_atomic_add(rocshmem_ctx_array[sm_id], rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); +#elif !defined(ROCM_DISABLE_CTX) shmem_ctx_ulong_atomic_add(ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); +#else + shmem_signal_op_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); #endif } } @@ -1944,10 +1963,14 @@ combine(int4* combined_x, float* combined_topk_weights, if (not 
rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { -#if defined(ROCM_DISABLE_CTX) - shmem_signal_op_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); +#if defined(ROCM_EXPLICIT_CTX) + shmem_ctx_ulong_atomic_add(rocshmem_ctx_array[sm_id], + rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); +#elif !defined(ROCM_DISABLE_CTX) + shmem_ctx_ulong_atomic_add(ctx, + rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); #else - shmem_ctx_ulong_atomic_add(ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + shmem_signal_op_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); #endif last_rdma_head = min_head; } @@ -2003,7 +2026,7 @@ combine(int4* combined_x, float* combined_topk_weights, if (lane_id == 0) rdma_receiver_retired[warp_id] = true; } } -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(ROCM_EXPLICIT_CTX) /* must match the guard on shmem_wg_ctx_create(&ctx) */ shmem_wg_ctx_destroy(&ctx); #endif }