diff --git a/src/ocean.cu b/src/ocean.cu index baaa9b7be6..8ce6eb57da 100644 --- a/src/ocean.cu +++ b/src/ocean.cu @@ -2,6 +2,7 @@ // Included by pufferlib.cu — requires precision_t, PrecisionTensor, Allocator, puf_mm, etc. #include "cudnn_conv2d.cu" +#include "kernels.cu" // ---- NMMO3 constants ---- @@ -24,8 +25,101 @@ static cudnnDataType_t n3_cudnn_dtype() { return (PRECISION_SIZE == 2) ? CUDNN_DATA_BFLOAT16 : CUDNN_DATA_FLOAT; } +struct FastDivMod { + uint32_t d_; + uint32_t M_; + uint32_t l_; + + __host__ FastDivMod(int d) { + d_ = d <= 0 ? 1u : (uint32_t)d; + uint32_t l = 0; + for (; l < 32; ++l) + if ((1u << l) >= d_) break; + l_ = l; + const uint64_t one = 1; + uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1; + M_ = (uint32_t)m; + } + + __device__ __forceinline__ int div(int n) const { + uint32_t u = (uint32_t)n; // n must be >= 0 + uint32_t t = __umulhi(M_, u); + return (int)((t + u) >> l_); + } + + __device__ __forceinline__ int mod(int n) const { + return n - div(n) * (int)d_; + } + + __device__ __forceinline__ void divmod(int n, int& q, int& r) const { + q = div(n); + r = n - q * (int)d_; + } +}; + +struct Im2ColFastMods { + FastDivMod dm_col_w; + FastDivMod dm_oh_ow; + FastDivMod dm_ow; + FastDivMod dm_kk; + FastDivMod dm_k; + FastDivMod dm_oc; + FastDivMod dm_iw; + FastDivMod dm_ih; + FastDivMod dm_ic; + FastDivMod dm_s; + FastDivMod dm_n3_hw; + FastDivMod dm_n3_hwf; + FastDivMod dm_n3_w; + int total_no_batch; + int oh_ow; + int oc_spatial; + int col_cols; + int n3_hw; + int n3_hwf; + int n3_multihot_plane; + int IC, IH, IW, OC, K, S, OH, OW; + + __host__ Im2ColFastMods(int ic, int ih, int iw, int oc, int k, int s, int oh, int ow) + : dm_col_w(ic * k * k), dm_oh_ow(oh * ow), dm_ow(ow), dm_kk(k * k), dm_k(k), dm_oc(oc), + dm_iw(iw), dm_ih(ih), dm_ic(ic), dm_s(s), + dm_n3_hw(N3_MAP_H * N3_MAP_W), + dm_n3_hwf(N3_MAP_H * N3_MAP_W * N3_NFEAT), + dm_n3_w(N3_MAP_W), + total_no_batch((oh * ow) * (ic * k * k)), oh_ow(oh * ow), col_cols(ic * k * k), + oc_spatial(oc * oh * ow), n3_hw(N3_MAP_H * N3_MAP_W), n3_hwf(N3_MAP_H * N3_MAP_W * N3_NFEAT), + n3_multihot_plane(N3_MULTIHOT * N3_MAP_H * N3_MAP_W), + IC(ic), IH(ih), IW(iw), OC(oc), K(k), S(s), OH(oh), OW(ow) {} +}; + +static const Im2ColFastMods kIm2ColModsC1( + N3_C1_IC, N3_MAP_H, N3_MAP_W, N3_C1_OC, N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW); +static const Im2ColFastMods kIm2ColModsC2( + N3_C2_IC, N3_C1_OH, N3_C1_OW, N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW); +static const FastDivMod kDmN3Player(N3_PLAYER); +static const FastDivMod kDmConvBiasSpatialC1(N3_C1_OH * N3_C1_OW); +static const FastDivMod kDmConvBiasSpatialC2(N3_C2_OH * N3_C2_OW); + // ---- NMMO3 kernels ---- +// One thread per (b, f, h, w): dm_n3_hwf splits idx -> (b, rem_hwf), dm_n3_hw -> (f, rem_sp), dm_n3_w -> (h, w). 
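+// Worked example with hypothetical dims (not the real NMMO3 constants): on a 4x4 map with
+// 2 features, n3_hw = 16 and n3_hwf = 32, so idx = 37 decomposes as b = 37 / 32 = 1 with
+// rem_hwf = 5, then f = 5 / 16 = 0 with rem_sp = 5, then h = 5 / 4 = 1 and w = 1. That is
+// exactly the (b, f, h, w) the / and % chain in n3_multihot_kernel computes, minus the divides.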
+__global__ void n3_multihot_kernel_fast(
+    precision_t* __restrict__ out, const precision_t* __restrict__ obs, int B, int obs_size,
+    const FastDivMod dm_n3_hwf, const FastDivMod dm_n3_hw, const FastDivMod dm_n3_w, int n3_hw, int n3_hwf,
+    int n3_multihot_plane) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= B * n3_hwf) return;
+  int b, rem_hwf;
+  dm_n3_hwf.divmod(idx, b, rem_hwf);
+  int f, rem_sp;
+  dm_n3_hw.divmod(rem_hwf, f, rem_sp);
+  int h, w;
+  dm_n3_w.divmod(rem_sp, h, w);
+  const precision_t* src = obs + (int64_t)b * obs_size + (int64_t)(h * N3_MAP_W + w) * N3_NFEAT;
+  precision_t* dst = out + (int64_t)b * n3_multihot_plane;
+  dst[(N3_OFFSETS[f] + (int)to_float(src[f])) * n3_hw + h * N3_MAP_W + w] = from_float(1.0f);
+}
+
 __global__ void n3_multihot_kernel(
     precision_t* __restrict__ out, const precision_t* __restrict__ obs, int B, int obs_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -38,6 +132,29 @@ __global__ void n3_multihot_kernel(
   dst[(N3_OFFSETS[f] + (int)to_float(src[f])) * N3_MAP_H * N3_MAP_W + h * N3_MAP_W + w] = from_float(1.0f);
 }
 
+__global__ void n3_embedding_kernel_fast(
+    precision_t* __restrict__ out, const precision_t* __restrict__ obs,
+    const precision_t* __restrict__ embed_w, int B, int obs_size, const FastDivMod dm_n3_player) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= B * N3_PLAYER) return;
+  int b, f;
+  dm_n3_player.divmod(idx, b, f);
+  int val = (int)to_float(obs[b * obs_size + N3_MAP_SIZE + f]);
+  const precision_t* src = embed_w + val * N3_EMBED_DIM;
+  precision_t* dst = out + b * N3_PLAYER_EMBED + f * N3_EMBED_DIM;
+#ifdef PRECISION_FLOAT
+  const float4* src4 = reinterpret_cast<const float4*>(src);
+  float4* dst4 = reinterpret_cast<float4*>(dst);
+#pragma unroll
+  for (int i = 0; i < N3_EMBED_DIM / 4; i++) dst4[i] = src4[i];
+#else
+  const uint4* src4 = reinterpret_cast<const uint4*>(src);
+  uint4* dst4 = reinterpret_cast<uint4*>(dst);
+#pragma unroll
+  for (int i = 0; i < N3_EMBED_DIM / 8; i++) dst4[i] = src4[i];
+#endif
+}
+
 __global__ void n3_embedding_kernel(
     precision_t* __restrict__ out, const precision_t* __restrict__ obs,
     const precision_t* __restrict__ embed_w, int B, int obs_size) {
@@ -107,7 +224,82 @@ __global__ void bias_grad_kernel(
   }
 }
 
-// NCHW bias grad: sum over (B, OH, OW) for each OC channel
+// NCHW bias grad: sum over (B, OH, OW) for each OC channel (dm_spatial.d_ must equal spatial).
+// For the NMMO3 spatial sizes (12 for conv1, 2 for conv2) it stripes over b with vector loads
+// along the contiguous s dimension; any other spatial falls back to a flat index plus FastDivMod.
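+// A host-side reference for validating this kernel (a sketch under the same NCHW layout
+// assumptions, not part of the hot path):
+//   for (int oc = 0; oc < OC; ++oc) {
+//     float s = 0.0f;
+//     for (int b = 0; b < B; ++b)
+//       for (int i = 0; i < spatial; ++i)
+//         s += to_float(grad[((int64_t)b * OC + oc) * spatial + i]);
+//     bgrad[oc] = from_float(s);
+//   }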
+__global__ void n3_conv_bias_grad_nchw_fast(
+    precision_t* __restrict__ bgrad, const precision_t* __restrict__ grad,
+    int B, int OC, const FastDivMod dm_spatial) {
+  int oc = blockIdx.x;
+  if (oc >= OC) return;
+  const int spatial = (int)dm_spatial.d_;
+  const int sp_c1 = N3_C1_OH * N3_C1_OW;
+  const int sp_c2 = N3_C2_OH * N3_C2_OW;
+  float sum = 0.0f;
+
+  if (spatial == sp_c1) {
+#ifdef PRECISION_FLOAT
+    for (int b = threadIdx.x; b < B; b += blockDim.x) {
+      const float* row = grad + ((int64_t)b * OC + oc) * spatial;
+      float4 a0 = *reinterpret_cast<const float4*>(row);
+      float4 a1 = *reinterpret_cast<const float4*>(row + 4);
+      float4 a2 = *reinterpret_cast<const float4*>(row + 8);
+      sum += a0.x + a0.y + a0.z + a0.w + a1.x + a1.y + a1.z + a1.w + a2.x + a2.y + a2.z + a2.w;
+    }
+#else
+    for (int b = threadIdx.x; b < B; b += blockDim.x) {
+      const __nv_bfloat16* row = grad + ((int64_t)b * OC + oc) * spatial;
+      const uint64_t* p = reinterpret_cast<const uint64_t*>(row);
+      #pragma unroll
+      for (int j = 0; j < 3; ++j) {
+        union {
+          uint64_t u;
+          __nv_bfloat16 h[4];
+        } w;
+        w.u = p[j];
+        sum += to_float(w.h[0]) + to_float(w.h[1]) + to_float(w.h[2]) + to_float(w.h[3]);
+      }
+    }
+#endif
+  } else if (spatial == sp_c2) {
+#ifdef PRECISION_FLOAT
+    for (int b = threadIdx.x; b < B; b += blockDim.x) {
+      const float* row = grad + ((int64_t)b * OC + oc) * spatial;
+      float2 v = *reinterpret_cast<const float2*>(row);
+      sum += v.x + v.y;
+    }
+#else
+    for (int b = threadIdx.x; b < B; b += blockDim.x) {
+      const __nv_bfloat16* row = grad + ((int64_t)b * OC + oc) * spatial;
+      union {
+        uint32_t u;
+        __nv_bfloat16 h[2];
+      } w;
+      w.u = *reinterpret_cast<const uint32_t*>(row);
+      sum += to_float(w.h[0]) + to_float(w.h[1]);
+    }
+#endif
+  } else {
+    int total = B * spatial;
+    for (int i = threadIdx.x; i < total; i += blockDim.x) {
+      int bb, s;
+      dm_spatial.divmod(i, bb, s);
+      sum += to_float(grad[(int64_t)bb * OC * spatial + oc * spatial + s]);
+    }
+  }
+  for (int offset = 16; offset > 0; offset >>= 1)
+    sum += __shfl_down_sync(0xffffffff, sum, offset);
+  __shared__ float sdata[32];
+  int lane = threadIdx.x % 32, warp = threadIdx.x / 32;
+  if (lane == 0) sdata[warp] = sum;
+  __syncthreads();
+  if (warp == 0) {
+    sum = (lane < (blockDim.x + 31) / 32) ? sdata[lane] : 0.0f;
+    for (int offset = 16; offset > 0; offset >>= 1)
+      sum += __shfl_down_sync(0xffffffff, sum, offset);
+    if (lane == 0) bgrad[oc] = from_float(sum);
+  }
+}
+
 __global__ void n3_conv_bias_grad_nchw(
     precision_t* __restrict__ bgrad, const precision_t* __restrict__ grad,
     int B, int OC, int spatial) {
@@ -213,7 +405,6 @@ __global__ void conv_bias_relu_kernel(precision_t* __restrict__ data,
 // NCHW layout throughout. Weight stored as (OC, IC*K*K).
 // im2col produces (B*OH*OW, IC*K*K), matmul with W^T gives (B*OH*OW, OC),
 // then reshape to NCHW (B, OC, OH, OW).
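+// Shape walk-through with hypothetical sizes (illustrative only, not the NMMO3 constants):
+// B=2, IC=3, K=3, S=1, OH=OW=4 gives a col buffer of (2*4*4, 3*3*3) = (32, 27); multiplying
+// by the (27, OC) transposed weight yields (32, OC), one channel vector per (b, oh, ow) row.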
- __global__ void im2col_kernel( const precision_t* __restrict__ input, precision_t* __restrict__ col, int B, int IC, int IH, int IW, int K, int S, int OH, int OW @@ -233,9 +424,73 @@ __global__ void im2col_kernel( col[idx] = input[b * IC * IH * IW + ic * IH * IW + ih * IW + iw]; } + + +__global__ void im2col_kernel_fast( + const precision_t* __restrict__ input, precision_t* __restrict__ col, + int B, int IC, int IH, int IW, int K, int S, int OH, int OW, + const FastDivMod dm_col_w, const FastDivMod dm_oh_ow, + const FastDivMod dm_ow, const FastDivMod dm_kk, const FastDivMod dm_k, + const int total_no_batch +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = B * total_no_batch; + if (idx >= total) return; + int row, c; + dm_col_w.divmod(idx, row, c); + int b, rem; + dm_oh_ow.divmod(row, b, rem); + int oh, ow; + dm_ow.divmod(rem, oh, ow); + int ic, kk; + dm_kk.divmod(c, ic, kk); + int kh, kw; + dm_k.divmod(kk, kh, kw); + int ih = oh * S + kh, iw = ow * S + kw; + int _IH_IW = IH * IW; + col[idx] = input[b * IC * _IH_IW + ic * _IH_IW + ih * IW + iw]; +} + // Backward: col2im — input-centric gather to avoid atomics. // Each thread owns one (b, ic, ih, iw) element and sums contributions from all // (oh, ow, kh, kw) patches that map to it. +// col2im fast path: dm_iw/dm_ih/dm_ic/dm_s and col_cols/oh_ow from Im2ColFastMods. +__global__ void col2im_kernel_fast( + const precision_t* __restrict__ col, precision_t* __restrict__ grad_input, + int B, int IC, int IH, int IW, int K, int OH, int OW, + const FastDivMod dm_iw, const FastDivMod dm_ih, const FastDivMod dm_ic, + const FastDivMod dm_s, int col_cols, int oh_ow +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = B * IC * IH * IW; + if (idx >= total) return; + int q0, iw, q1, ih, b, ic; + dm_iw.divmod(idx, q0, iw); + dm_ih.divmod(q0, q1, ih); + dm_ic.divmod(q1, b, ic); + int bohow_ickk = b * oh_ow * col_cols + ic * (K * K); + float sum = 0.0f; + for (int kh = 0; kh < K; kh++) { + int ih_off = ih - kh; + if (ih_off < 0) continue; + int oh, ih_rem; + dm_s.divmod(ih_off, oh, ih_rem); + if (ih_rem != 0 || oh >= OH) continue; + int ohowcc_khk = oh * OW * col_cols + kh * K; + int inner_value = bohow_ickk + ohowcc_khk; + for (int kw = 0; kw < K; kw++) { + int iw_off = iw - kw; + if (iw_off < 0) continue; + int ow, iw_rem; + dm_s.divmod(iw_off, ow, iw_rem); + if (iw_rem != 0 || ow >= OW) continue; + int col_idx = inner_value + ow * col_cols + kw; + sum += to_float(col[col_idx]); + } + } + grad_input[idx] = from_float(sum); +} + __global__ void col2im_kernel( const precision_t* __restrict__ col, precision_t* __restrict__ grad_input, int B, int IC, int IH, int IW, int K, int S, int OH, int OW @@ -266,6 +521,20 @@ __global__ void col2im_kernel( } // Transpose (B, OC, OH, OW) -> (B*OH*OW, OC) [NCHW to row-major spatial-first] +// Same idx layout as rows_to_nchw_kernel_fused; dm_oh_ow = spatial, dm_oc = OC. 
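+// Index example (hypothetical B=1, OC=2, spatial=3): idx = 4 is src[b=0][oc=1][s=1] in NCHW;
+// dm_oh_ow.divmod(4) gives q = 1, s = 1 and dm_oc.divmod(1) gives b = 0, oc = 1, so the value
+// lands at dst[(0*3 + 1)*2 + 1] = dst[3], i.e. row b*spatial + s, column oc.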
+__global__ void nchw_to_rows_kernel_fast(
+    const precision_t* __restrict__ src, precision_t* __restrict__ dst,
+    int B, int OC, int spatial,
+    const FastDivMod dm_oh_ow, const FastDivMod dm_oc
+) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int total = B * OC * spatial;
+  if (idx >= total) return;
+  int q, s, b, oc;
+  dm_oh_ow.divmod(idx, q, s);
+  dm_oc.divmod(q, b, oc);
+  dst[(b * spatial + s) * OC + oc] = src[idx];
+}
 __global__ void nchw_to_rows_kernel(
     const precision_t* __restrict__ src, precision_t* __restrict__ dst,
     int B, int OC, int spatial
@@ -280,6 +549,33 @@ __global__ void nchw_to_rows_kernel(
 }
 
 // Transpose (B*OH*OW, OC) -> (B, OC, OH, OW) [row-major spatial-first to NCHW]
+__global__ void rows_to_nchw_kernel_fused(
+    const precision_t* __restrict__ src,
+    const precision_t* __restrict__ bias,
+    precision_t* __restrict__ data,
+    int B,
+    int spatial, int oc_spatial, int OC,
+    const FastDivMod dm_oh_ow,
+    const FastDivMod dm_oc,
+    bool relu
+) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int total = B * oc_spatial;
+  if (idx >= total) return;
+
+  int b, q, s, oc;
+  dm_oh_ow.divmod(idx, q, s);
+  dm_oc.divmod(q, b, oc);
+
+  float value = to_float(src[(b * spatial + s) * OC + oc]);
+  float oc_bias = to_float(bias[oc]);
+  float value_bias = value + oc_bias;
+  if (relu) {
+    data[idx] = from_float(fmaxf(0.0f, value_bias));
+  } else {
+    data[idx] = from_float(value_bias);
+  }
+}
 __global__ void rows_to_nchw_kernel(
     const precision_t* __restrict__ src, precision_t* __restrict__ dst,
     int B, int OC, int spatial
@@ -330,6 +626,30 @@ static void gemm_conv_forward(
   }
 }
 
+// NMMO3 conv1/conv2 only: pass kIm2ColModsC1 or kIm2ColModsC2 (built once from N3_*).
+static void gemm_conv_forward_fast(
+    PrecisionTensor* weight, PrecisionTensor* bias,
+    precision_t* input, precision_t* output,
+    precision_t* col_buf, precision_t* mm_buf,
+    int B, const Im2ColFastMods& m, bool relu, cudaStream_t stream
+) {
+  int col_rows = B * m.oh_ow;
+  int total_col = col_rows * m.col_cols;
+  int total_out = B * m.OC * m.oh_ow;
+
+  im2col_kernel_fast<<<(total_col + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+      input, col_buf, B, m.IC, m.IH, m.IW, m.K, m.S, m.OH, m.OW,
+      m.dm_col_w, m.dm_oh_ow, m.dm_ow, m.dm_kk, m.dm_k, m.total_no_batch);
+
+  PrecisionTensor col_t = {.data = col_buf, .shape = {col_rows, m.col_cols}};
+  PrecisionTensor mm_t = {.data = mm_buf, .shape = {col_rows, m.OC}};
+  puf_mm(&col_t, weight, &mm_t, stream);
+
+  rows_to_nchw_kernel_fused<<<(total_out + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+      mm_buf, bias->data, output, B, m.oh_ow, m.oc_spatial, m.OC,
+      m.dm_oh_ow, m.dm_oc, relu);
+}
+
 // Backward: weight grad + optional input grad via im2col/col2im + cuBLAS.
 // grad_output is NCHW (B, OC, OH, OW). saved_input is NCHW.
 // Caller handles relu backward and bias grad (same as cuDNN path).
@@ -369,6 +689,38 @@ static void gemm_conv_backward(
   }
 }
 
+// NMMO3 conv1/conv2 only: pass kIm2ColModsC1 or kIm2ColModsC2 (same as gemm_conv_forward_fast).
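+// Backward shapes, for reference: with R = B*OH*OW, the weight grad is mm^T @ col,
+// (OC, R) x (R, IC*K*K) -> (OC, IC*K*K); the optional input grad is mm @ W,
+// (R, OC) x (OC, IC*K*K) -> (R, IC*K*K), then gathered back to NCHW by col2im_kernel_fast.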
+static void gemm_conv_backward_fast(
+    PrecisionTensor* weight,
+    precision_t* saved_input, precision_t* grad_output,
+    precision_t* wgrad, precision_t* input_grad,
+    precision_t* col_buf, precision_t* mm_buf,
+    int B, const Im2ColFastMods& m, cudaStream_t stream
+) {
+  int col_rows = B * m.oh_ow;
+  int total_col = col_rows * m.col_cols;
+  int total_out = B * m.OC * m.oh_ow;
+
+  nchw_to_rows_kernel_fast<<<(total_out + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+      grad_output, mm_buf, B, m.OC, m.oh_ow, m.dm_oh_ow, m.dm_oc);
+
+  im2col_kernel_fast<<<(total_col + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+      saved_input, col_buf, B, m.IC, m.IH, m.IW, m.K, m.S, m.OH, m.OW,
+      m.dm_col_w, m.dm_oh_ow, m.dm_ow, m.dm_kk, m.dm_k, m.total_no_batch);
+
+  PrecisionTensor mm_t = {.data = mm_buf, .shape = {col_rows, m.OC}};
+  PrecisionTensor col_t = {.data = col_buf, .shape = {col_rows, m.col_cols}};
+  PrecisionTensor wg_t = {.data = wgrad, .shape = {m.OC, m.col_cols}};
+  puf_mm_tn(&mm_t, &col_t, &wg_t, stream);
+
+  if (input_grad) {
+    puf_mm_nn(&mm_t, weight, &col_t, stream);
+    col2im_kernel_fast<<<(B * m.IC * m.IH * m.IW + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+        col_buf, input_grad, B, m.IC, m.IH, m.IW, m.K, m.OH, m.OW,
+        m.dm_iw, m.dm_ih, m.dm_ic, m.dm_s, m.col_cols, m.oh_ow);
+  }
+}
+
 // ---- NMMO3 encoder structs ----
 
 struct NMMO3EncoderWeights {
@@ -403,24 +755,23 @@ static PrecisionTensor nmmo3_encoder_forward(void* w, void* activations, Precisi
   if (a->saved_obs.data) puf_copy(&a->saved_obs, &input, stream);
   cudaMemsetAsync(a->multihot.data, 0,
       (int64_t)B * N3_MULTIHOT * N3_MAP_H * N3_MAP_W * sizeof(precision_t), stream);
-  n3_multihot_kernel<<<(B * N3_MAP_H * N3_MAP_W * N3_NFEAT + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
-      a->multihot.data, input.data, B, ew->obs_size);
+  n3_multihot_kernel_fast<<<(B * kIm2ColModsC1.n3_hwf + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+      a->multihot.data, input.data, B, ew->obs_size, kIm2ColModsC1.dm_n3_hwf, kIm2ColModsC1.dm_n3_hw,
+      kIm2ColModsC1.dm_n3_w, kIm2ColModsC1.n3_hw, kIm2ColModsC1.n3_hwf, kIm2ColModsC1.n3_multihot_plane);
 
-  gemm_conv_forward(&ew->conv1.w, &ew->conv1.b, a->multihot.data, a->conv1.out.data,
-      a->col1.data, a->mm1.data, B, N3_C1_IC, N3_MAP_H, N3_MAP_W,
-      N3_C1_OC, N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW, true, stream);
+  gemm_conv_forward_fast(&ew->conv1.w, &ew->conv1.b, a->multihot.data, a->conv1.out.data,
+      a->col1.data, a->mm1.data, B, kIm2ColModsC1, true, stream);
   if (a->conv1.saved_input.data)
     cudaMemcpyAsync(a->conv1.saved_input.data, a->multihot.data,
         (int64_t)B * N3_C1_IC * N3_MAP_H * N3_MAP_W * sizeof(precision_t), cudaMemcpyDeviceToDevice, stream);
 
-  gemm_conv_forward(&ew->conv2.w, &ew->conv2.b, a->conv1.out.data, a->conv2.out.data,
-      a->col2.data, a->mm2.data, B, N3_C2_IC, N3_C1_OH, N3_C1_OW,
-      N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW, false, stream);
+  gemm_conv_forward_fast(&ew->conv2.w, &ew->conv2.b, a->conv1.out.data, a->conv2.out.data,
+      a->col2.data, a->mm2.data, B, kIm2ColModsC2, false, stream);
   if (a->conv2.saved_input.data)
     cudaMemcpyAsync(a->conv2.saved_input.data, a->conv1.out.data,
        (int64_t)B * N3_C2_IC * N3_C1_OH * N3_C1_OW * sizeof(precision_t), cudaMemcpyDeviceToDevice, stream);
 
-  n3_embedding_kernel<<<(B * N3_PLAYER + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
-      a->embed_out.data, input.data, ew->embed_w.data, B, ew->obs_size);
+  n3_embedding_kernel_fast<<<(B * N3_PLAYER + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+      a->embed_out.data, input.data, ew->embed_w.data, B, ew->obs_size, kDmN3Player);
 
   n3_concat_kernel<<>>(
       a->concat.data, a->conv2.out.data, a->embed_out.data, input.data, B, ew->obs_size);
@@ -447,24 +798,20 @@ static void nmmo3_encoder_backward(void* w, void* activations, PrecisionTensor g
   n3_concat_backward_conv_kernel<<>>(
       a->conv2.grad.data, grad_concat.data, B);
 
-  n3_conv_bias_grad_nchw<<<ew->conv2.OC, 256, 0, stream>>>(
-      a->conv2.bgrad.data, a->conv2.grad.data,
-      B, ew->conv2.OC, ew->conv2.OH * ew->conv2.OW);
-  gemm_conv_backward(&ew->conv2.w, a->conv2.saved_input.data, a->conv2.grad.data,
+  n3_conv_bias_grad_nchw_fast<<<ew->conv2.OC, 256, 0, stream>>>(
+      a->conv2.bgrad.data, a->conv2.grad.data, B, ew->conv2.OC, kDmConvBiasSpatialC2);
+  gemm_conv_backward_fast(&ew->conv2.w, a->conv2.saved_input.data, a->conv2.grad.data,
       a->conv2.wgrad.data, a->conv1.grad.data,
-      a->col2.data, a->mm2.data, B, N3_C2_IC, N3_C1_OH, N3_C1_OW,
-      N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW, stream);
+      a->col2.data, a->mm2.data, B, kIm2ColModsC2, stream);
 
   n3_relu_backward_kernel<<<(B * ew->conv1.OC * ew->conv1.OH * ew->conv1.OW + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
       a->conv1.grad.data, a->conv1.out.data, B * ew->conv1.OC * ew->conv1.OH * ew->conv1.OW);
 
-  n3_conv_bias_grad_nchw<<<ew->conv1.OC, 256, 0, stream>>>(
-      a->conv1.bgrad.data, a->conv1.grad.data,
-      B, ew->conv1.OC, ew->conv1.OH * ew->conv1.OW);
-  gemm_conv_backward(&ew->conv1.w, a->conv1.saved_input.data, a->conv1.grad.data,
+  n3_conv_bias_grad_nchw_fast<<<ew->conv1.OC, 256, 0, stream>>>(
+      a->conv1.bgrad.data, a->conv1.grad.data, B, ew->conv1.OC, kDmConvBiasSpatialC1);
+  gemm_conv_backward_fast(&ew->conv1.w, a->conv1.saved_input.data, a->conv1.grad.data,
       a->conv1.wgrad.data, NULL,
-      a->col1.data, a->mm1.data, B, N3_C1_IC, N3_MAP_H, N3_MAP_W,
-      N3_C1_OC, N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW, stream);
+      a->col1.data, a->mm1.data, B, kIm2ColModsC1, stream);
 
   // Embedding backward: scatter-add from concat gradient into float buffer, then cast
   int embed_n = N3_EMBED_VOCAB * N3_EMBED_DIM;
diff --git a/tests/bench_gemm_conv_end2end.cu b/tests/bench_gemm_conv_end2end.cu
new file mode 100644
index 0000000000..564759cd5d
--- /dev/null
+++ b/tests/bench_gemm_conv_end2end.cu
@@ -0,0 +1,613 @@
+// End-to-end: gemm conv (slow) vs gemm_fast vs cudnn — forward & backward timed separately, layers 1 & 2 (NMMO3).
+// Multihot microbench: n3_multihot_kernel (reference, not fast) vs n3_multihot_kernel_fast — correctness + timing,
+// B in {1024..32768} powers of two (run_n3_multihot_bench_B).
+// Embedding microbench: n3_embedding_kernel vs n3_embedding_kernel_fast — own CLI flag --embedding-only (same B grid).
+// Conv bias grad: n3_conv_bias_grad_nchw vs n3_conv_bias_grad_nchw_fast (FastDivMod) — --conv-bias-grad-only.
+// Build/run: tests/bench_gemm_conv_end2end.sh [--float|--bf16] [--multihot-only|--embedding-only|--conv-bias-grad-only]
+
+#include <cuda_runtime.h>
+
+#include "models.cu"
+#include "ocean.cu"
+
+#ifndef PRECISION_FLOAT
+#include <cuda_bf16.h>
+#endif
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <vector>
+
+static void fill_rand_host(float* h, int n, unsigned s) {
+  for (int i = 0; i < n; ++i) {
+    s = s * 1103515245u + 12345u;
+    h[i] = ((s >> 16) & 0x7fff) / 16384.0f - 1.0f;
+  }
+}
+
+static void stats_diff(const float* a, const float* b, int n, float* max_abs, float* mean_abs) {
+  float mx = 0.0f, sum = 0.0f;
+  for (int i = 0; i < n; ++i) {
+    float d = fabsf(a[i] - b[i]);
+    if (d > mx) mx = d;
+    sum += d;
+  }
+  *max_abs = mx;
+  *mean_abs = sum / (float)std::max(1, n);
+}
+
+// max_i |a-b| / max(|a|, eps) — scale-free compare vs reference `a`.
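+// e.g. stats_rel_max(ref, out, n, 1e-5f) returns 0.0f for bit-identical buffers and roughly
+// 1.0f where out drops a value ref holds at O(1) magnitude; eps only guards the divide when
+// the reference entry itself is near zero.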
+static float stats_rel_max(const float* a, const float* b, int n, float eps) {
+  float mx = 0.0f;
+  for (int i = 0; i < n; ++i) {
+    float den = fmaxf(fabsf(a[i]), eps);
+    float r = fabsf(a[i] - b[i]) / den;
+    if (r > mx) mx = r;
+  }
+  return mx;
+}
+
+static void copy_precision_d2h(const precision_t* d, int n, std::vector<float>* hf) {
+  hf->resize((size_t)n);
+#ifdef PRECISION_FLOAT
+  cudaMemcpy(hf->data(), d, (size_t)n * sizeof(float), cudaMemcpyDeviceToHost);
+#else
+  std::vector<__nv_bfloat16> h((size_t)n);
+  cudaMemcpy(h.data(), d, (size_t)n * sizeof(precision_t), cudaMemcpyDeviceToHost);
+  for (int i = 0; i < n; ++i) (*hf)[i] = __bfloat162float(h[i]);
+#endif
+}
+
+static void copy_fp32_h2d(const float* h, precision_t* d, int n) {
+#ifdef PRECISION_FLOAT
+  cudaMemcpy(d, h, (size_t)n * sizeof(float), cudaMemcpyHostToDevice);
+#else
+  std::vector<__nv_bfloat16> hb((size_t)n);
+  for (int i = 0; i < n; ++i) hb[i] = __float2bfloat16(h[i]);
+  cudaMemcpy(d, hb.data(), (size_t)n * sizeof(precision_t), cudaMemcpyHostToDevice);
+#endif
+}
+
+template <typename F>
+static float time_kernel_ms(cudaStream_t stream, F fn, int warmup, int iters) {
+  for (int i = 0; i < warmup; ++i) {
+    fn(stream);
+    cudaStreamSynchronize(stream);
+  }
+  cudaEvent_t ev0, ev1;
+  cudaEventCreate(&ev0);
+  cudaEventCreate(&ev1);
+  cudaEventRecord(ev0, stream);
+  for (int i = 0; i < iters; ++i) fn(stream);
+  cudaEventRecord(ev1, stream);
+  cudaEventSynchronize(ev1);
+  float ms = 0.0f;
+  cudaEventElapsedTime(&ms, ev0, ev1);
+  cudaEventDestroy(ev0);
+  cudaEventDestroy(ev1);
+  return ms / (float)iters;
+}
+
+struct BenchDims {
+  int B;
+  int IC, OC, K, S, IH, IW;
+  bool relu;
+};
+
+static void dims_conv1(BenchDims* d, int B) {
+  d->B = B;
+  d->IC = N3_C1_IC;
+  d->OC = N3_C1_OC;
+  d->K = N3_C1_K;
+  d->S = N3_C1_S;
+  d->IH = N3_MAP_H;
+  d->IW = N3_MAP_W;
+  d->relu = true;
+}
+
+static void dims_conv2(BenchDims* d, int B) {
+  d->B = B;
+  d->IC = N3_C2_IC;
+  d->OC = N3_C2_OC;
+  d->K = N3_C2_K;
+  d->S = N3_C2_S;
+  d->IH = N3_C1_OH;
+  d->IW = N3_C1_OW;
+  d->relu = false;
+}
+
+static int run_layer(int layer, int warmup, int iters) {
+  BenchDims dim{};
+  if (layer == 1) dims_conv1(&dim, 1024);
+  else if (layer == 2) dims_conv2(&dim, 1024);
+  else return 1;
+
+  const Im2ColFastMods& m = (layer == 1) ? kIm2ColModsC1 : kIm2ColModsC2;
+  int OH = (dim.IH - dim.K) / dim.S + 1;
+  int OW = (dim.IW - dim.K) / dim.S + 1;
+  if (OH != m.OH || OW != m.OW || dim.IC != m.IC || dim.IH != m.IH || dim.IW != m.IW || dim.OC != m.OC
+      || dim.K != m.K || dim.S != m.S) {
+    fprintf(stderr, "layer %d: dim mismatch vs Im2ColFastMods\n", layer);
+    return 1;
+  }
+
+  ConvWeights cw{};
+  conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, dim.relu);
+
+  Allocator param_alloc{};
+  conv_reg_params(&cw, &param_alloc);
+  if (alloc_create(&param_alloc) != cudaSuccess) return 1;
+  uint64_t seed = 1000u + (unsigned)layer;
+  conv_init_weights(&cw, &seed, 0);
+  cudaDeviceSynchronize();
+
+  int B = dim.B;
+  int out_elems = B * dim.OC * OH * OW;
+  int in_elems = B * dim.IC * dim.IH * dim.IW;
+  int w_elems = (int)numel(cw.w.shape);
+  int col_rows = B * OH * OW;
+  int col_cols = dim.IC * dim.K * dim.K;
+
+  Allocator act_g{};
+  PrecisionTensor out_gemm{}, out_fast{}, col{}, mm{};
+  PrecisionTensor saved_in{}, grad_out{}, wgrad_g{}, dinput_g{};
+  out_gemm = {.shape = {(int64_t)out_elems}};
+  out_fast = {.shape = {(int64_t)out_elems}};
+  col = {.shape = {col_rows, col_cols}};
+  mm = {.shape = {col_rows, dim.OC}};
+  saved_in = {.shape = {B, dim.IC, dim.IH, dim.IW}};
+  grad_out = {.shape = {(int64_t)out_elems}};
+  wgrad_g = {.shape = {cw.w.shape[0], cw.w.shape[1]}};
+  dinput_g = {.shape = {B, dim.IC, dim.IH, dim.IW}};
+  alloc_register(&act_g, &out_gemm);
+  alloc_register(&act_g, &out_fast);
+  alloc_register(&act_g, &col);
+  alloc_register(&act_g, &mm);
+  alloc_register(&act_g, &saved_in);
+  alloc_register(&act_g, &grad_out);
+  alloc_register(&act_g, &wgrad_g);
+  alloc_register(&act_g, &dinput_g);
+  if (alloc_create(&act_g) != cudaSuccess) return 1;
+
+  Allocator acts{}, grads{};
+  ConvActivations ca{};
+  conv_reg_train(&cw, &ca, &acts, &grads, B, n3_cudnn_dtype());
+  if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1;
+
+  std::vector<float> hin(in_elems), hg_up(out_elems);
+  fill_rand_host(hin.data(), in_elems, 201u + (unsigned)layer);
+  fill_rand_host(hg_up.data(), out_elems, 303u + (unsigned)layer);
+  copy_fp32_h2d(hin.data(), saved_in.data, in_elems);
+  cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice);
+
+  cudaStream_t stream = 0;
+  const float rel_eps = 1e-5f;
+
+  printf("layer %d B=%d IC=%d OC=%d %dx%d K=%d S=%d -> %dx%d relu=%d\n", layer, B, dim.IC, dim.OC, dim.IH,
+      dim.IW, dim.K, dim.S, OH, OW, (int)dim.relu);
+  printf("  --- correctness (reference = gemm_conv forward/backward) ---\n");
+
+  gemm_conv_forward(&cw.w, &cw.b, saved_in.data, out_gemm.data, col.data, mm.data, B, dim.IC, dim.IH, dim.IW,
+      dim.OC, dim.K, dim.S, OH, OW, dim.relu, stream);
+  cudaDeviceSynchronize();
+  gemm_conv_forward_fast(&cw.w, &cw.b, saved_in.data, out_fast.data, col.data, mm.data, B, m, dim.relu, stream);
+  cudaDeviceSynchronize();
+  conv_forward(&cw, &ca, saved_in.data, B, stream);
+  cudaDeviceSynchronize();
+
+  std::vector<float> h_gemm_o, h_fast_o, h_cdnn_o;
+  copy_precision_d2h(out_gemm.data, out_elems, &h_gemm_o);
+  copy_precision_d2h(out_fast.data, out_elems, &h_fast_o);
+  copy_precision_d2h(ca.out.data, out_elems, &h_cdnn_o);
+  float mx, mn, rel;
+  stats_diff(h_gemm_o.data(), h_fast_o.data(), out_elems, &mx, &mn);
+  rel = stats_rel_max(h_gemm_o.data(), h_fast_o.data(), out_elems, rel_eps);
+  printf("  forward gemm_fast vs gemm:  max|diff| %.6g  mean|diff| %.6g  max rel err %.6g\n", mx, mn, rel);
+  stats_diff(h_gemm_o.data(), h_cdnn_o.data(), out_elems, &mx, &mn);
+  rel = stats_rel_max(h_gemm_o.data(), h_cdnn_o.data(), out_elems, rel_eps);
+  printf("  forward cudnn vs gemm:      max|diff| %.6g  mean|diff| %.6g  max rel err %.6g\n", mx, mn, rel);
+
+  copy_fp32_h2d(hg_up.data(), grad_out.data, out_elems);
+  cudaMemcpy(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice);
+
+  cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t));
+  cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t));
+  gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, B,
+      dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream);
+  cudaDeviceSynchronize();
+  std::vector<float> hwg_ref, hdi_ref;
+  copy_precision_d2h(wgrad_g.data, w_elems, &hwg_ref);
+  copy_precision_d2h(dinput_g.data, in_elems, &hdi_ref);
+
+  cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t));
+  cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t));
+  gemm_conv_backward_fast(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, B,
+      m, stream);
+  cudaDeviceSynchronize();
+  std::vector<float> hwg_f, hdi_f;
+  copy_precision_d2h(wgrad_g.data, w_elems, &hwg_f);
+  copy_precision_d2h(dinput_g.data, in_elems, &hdi_f);
+  stats_diff(hwg_ref.data(), hwg_f.data(), w_elems, &mx, &mn);
+  rel = stats_rel_max(hwg_ref.data(), hwg_f.data(), w_elems, rel_eps);
+  printf("  backward wgrad gemm_fast vs gemm:    max|diff| %.6g  mean|diff| %.6g  max rel err %.6g\n", mx, mn, rel);
+  stats_diff(hdi_ref.data(), hdi_f.data(), in_elems, &mx, &mn);
+  rel = stats_rel_max(hdi_ref.data(), hdi_f.data(), in_elems, rel_eps);
+  printf("  backward d_input gemm_fast vs gemm:  max|diff| %.6g  mean|diff| %.6g  max rel err %.6g\n", mx, mn, rel);
+
+  cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t));
+  cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t));
+  conv_backward(&cw, &ca, dinput_g.data, B, stream);
+  cudaDeviceSynchronize();
+  std::vector<float> hwg_c, hdi_c;
+  copy_precision_d2h(ca.wgrad.data, w_elems, &hwg_c);
+  copy_precision_d2h(dinput_g.data, in_elems, &hdi_c);
+  stats_diff(hwg_ref.data(), hwg_c.data(), w_elems, &mx, &mn);
+  rel = stats_rel_max(hwg_ref.data(), hwg_c.data(), w_elems, rel_eps);
+  printf("  backward wgrad cudnn vs gemm:    max|diff| %.6g  mean|diff| %.6g  max rel err %.6g\n", mx, mn, rel);
+  stats_diff(hdi_ref.data(), hdi_c.data(), in_elems, &mx, &mn);
+  rel = stats_rel_max(hdi_ref.data(), hdi_c.data(), in_elems, rel_eps);
+  printf("  backward d_input cudnn vs gemm:  max|diff| %.6g  mean|diff| %.6g  max rel err %.6g\n", mx, mn, rel);
+
+  printf("  --- timing (%d warmup / %d iters): forward and backward measured separately ---\n", warmup, iters);
+
+  auto run_gemm_fwd = [&](cudaStream_t s) {
+    gemm_conv_forward(&cw.w, &cw.b, saved_in.data, out_gemm.data, col.data, mm.data, B, dim.IC, dim.IH, dim.IW,
+        dim.OC, dim.K, dim.S, OH, OW, dim.relu, s);
+  };
+  auto run_gemm_fast_fwd = [&](cudaStream_t s) {
+    gemm_conv_forward_fast(&cw.w, &cw.b, saved_in.data, out_fast.data, col.data, mm.data, B, m, dim.relu, s);
+  };
+  auto run_cudnn_fwd = [&](cudaStream_t s) { conv_forward(&cw, &ca, saved_in.data, B, s); };
+
+  auto run_gemm_bwd = [&](cudaStream_t s) {
+    cudaMemsetAsync(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t), s);
+    cudaMemsetAsync(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t), s);
+    gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, B,
+        dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S,
OH, OW, s); + }; + auto run_gemm_fast_bwd = [&](cudaStream_t s) { + cudaMemsetAsync(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t), s); + cudaMemsetAsync(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t), s); + gemm_conv_backward_fast(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, + B, m, s); + }; + auto run_cudnn_bwd = [&](cudaStream_t s) { + cudaMemcpyAsync(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), + cudaMemcpyDeviceToDevice, s); + cudaMemcpyAsync(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice, + s); + cudaMemsetAsync(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t), s); + cudaMemsetAsync(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t), s); + conv_backward(&cw, &ca, dinput_g.data, B, s); + }; + + float ms_gf = time_kernel_ms(stream, run_gemm_fwd, warmup, iters); + float ms_ff = time_kernel_ms(stream, run_gemm_fast_fwd, warmup, iters); + float ms_cf = time_kernel_ms(stream, run_cudnn_fwd, warmup, iters); + float ms_gb = time_kernel_ms(stream, run_gemm_bwd, warmup, iters); + float ms_fb = time_kernel_ms(stream, run_gemm_fast_bwd, warmup, iters); + float ms_cb = time_kernel_ms(stream, run_cudnn_bwd, warmup, iters); + + printf(" forward:\n"); + printf(" gemm (slow): %8.4f us/iter\n", ms_gf * 1000.0f); + printf(" gemm_fast: %8.4f us/iter (%.2fx vs gemm)\n", ms_ff * 1000.0f, ms_gf / ms_ff); + printf(" cudnn: %8.4f us/iter (%.2fx vs gemm, %.2fx vs gemm_fast)\n", ms_cf * 1000.0f, ms_gf / ms_cf, + ms_ff / ms_cf); + printf(" backward:\n"); + printf(" gemm (slow): %8.4f us/iter\n", ms_gb * 1000.0f); + printf(" gemm_fast: %8.4f us/iter (%.2fx vs gemm)\n", ms_fb * 1000.0f, ms_gb / ms_fb); + printf(" cudnn: %8.4f us/iter (%.2fx vs gemm, %.2fx vs gemm_fast)\n", ms_cb * 1000.0f, ms_gb / ms_cb, + ms_fb / ms_cb); + printf("\n"); + + alloc_free(¶m_alloc); + alloc_free(&act_g); + alloc_free(&acts); + alloc_free(&grads); + return 0; +} + +// Host copy of N3_OFFSETS (device __constant__ in ocean.cu) for valid multihot index ranges. +static const int kN3OffsetsHost[10] = {0, 4, 8, 25, 30, 33, 38, 43, 48, 55}; + +static int n3_feat_max_v(int f) { + int next = (f + 1 < 10) ? 
+      kN3OffsetsHost[f + 1] : N3_MULTIHOT;
+  return next - kN3OffsetsHost[f] - 1;
+}
+
+static int run_n3_multihot_bench_B(int B, int warmup, int iters) {
+  const int obs_size = N3_MAP_SIZE + N3_PLAYER + N3_REWARD;
+  const int n3_hw = N3_MAP_H * N3_MAP_W;
+  const int multihot_elems = B * N3_MULTIHOT * n3_hw;
+  const Im2ColFastMods& m = kIm2ColModsC1;
+
+  Allocator alloc{};
+  PrecisionTensor d_obs{}, d_out_ref{}, d_out_fast{};
+  d_obs = {.shape = {B, obs_size}};
+  d_out_ref = {.shape = {(int64_t)multihot_elems}};
+  d_out_fast = {.shape = {(int64_t)multihot_elems}};
+  alloc_register(&alloc, &d_obs);
+  alloc_register(&alloc, &d_out_ref);
+  alloc_register(&alloc, &d_out_fast);
+  if (alloc_create(&alloc) != cudaSuccess) return 1;
+
+  std::vector<float> h_obs((size_t)B * (size_t)obs_size, 0.0f);
+  unsigned seed = 4242u;
+  for (int b = 0; b < B; ++b) {
+    for (int rem = 0; rem < n3_hw; ++rem) {
+      for (int f = 0; f < N3_NFEAT; ++f) {
+        int mv = n3_feat_max_v(f);
+        seed = seed * 1103515245u + 12345u;
+        int v = (int)((seed >> 16) % (unsigned)(mv + 1));
+        h_obs[(size_t)b * (size_t)obs_size + (size_t)rem * (size_t)N3_NFEAT + (size_t)f] = (float)v;
+      }
+    }
+  }
+  copy_fp32_h2d(h_obs.data(), d_obs.data, B * obs_size);
+
+  cudaStream_t stream = 0;
+  const float rel_eps = 1e-5f;
+
+  auto launch_ref = [&](cudaStream_t s) {
+    n3_multihot_kernel<<<(B * m.n3_hwf + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, s>>>(
+        d_out_ref.data, d_obs.data, B, obs_size);
+  };
+  auto launch_fast = [&](cudaStream_t s) {
+    n3_multihot_kernel_fast<<<(B * m.n3_hwf + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, s>>>(
+        d_out_fast.data, d_obs.data, B, obs_size,
+        m.dm_n3_hwf, m.dm_n3_hw, m.dm_n3_w, m.n3_hw, m.n3_hwf, m.n3_multihot_plane);
+  };
+
+  const size_t multihot_bytes = (size_t)multihot_elems * sizeof(precision_t);
+  cudaMemsetAsync(d_out_ref.data, 0, multihot_bytes, stream);
+  cudaMemsetAsync(d_out_fast.data, 0, multihot_bytes, stream);
+  cudaStreamSynchronize(stream);
+  launch_ref(stream);
+  cudaDeviceSynchronize();
+  cudaMemsetAsync(d_out_fast.data, 0, multihot_bytes, stream);
+  cudaStreamSynchronize(stream);
+  launch_fast(stream);
+  cudaDeviceSynchronize();
+
+  std::vector<float> h_ref, h_fast;
+  copy_precision_d2h(d_out_ref.data, multihot_elems, &h_ref);
+  copy_precision_d2h(d_out_fast.data, multihot_elems, &h_fast);
+  float mx, mn;
+  stats_diff(h_ref.data(), h_fast.data(), multihot_elems, &mx, &mn);
+  float rel = stats_rel_max(h_ref.data(), h_fast.data(), multihot_elems, rel_eps);
+
+  printf("n3_multihot B=%d obs_size=%d multihot_elems=%d\n", B, obs_size, multihot_elems);
+  printf("  reference=n3_multihot_kernel fast=n3_multihot_kernel_fast (dm_n3_hwf+dm_n3_hw+dm_n3_w, n3_hwf=%d)\n",
+      m.n3_hwf);
+  printf("  correctness fast vs reference: max|diff| %.6g  mean|diff| %.6g  max rel err %.6g\n", mx, mn, rel);
+  printf("  timing (%d warmup / %d iters), kernel only (correctness above uses cudaMemsetAsync like the encoder):\n", warmup, iters);
+
+  float ms_ref = time_kernel_ms(stream, launch_ref, warmup, iters);
+  float ms_f = time_kernel_ms(stream, launch_fast, warmup, iters);
+  printf("  n3_multihot_kernel (ref):  %8.4f us/iter\n", ms_ref * 1000.0f);
+  printf("  n3_multihot_kernel_fast:   %8.4f us/iter (%.2fx vs ref)\n", ms_f * 1000.0f, ms_ref / ms_f);
+  printf("\n");
+
+  alloc_free(&alloc);
+  return 0;
+}
+
+static int run_n3_embedding_bench_B(int B, int warmup, int iters) {
+  const int obs_size = N3_MAP_SIZE + N3_PLAYER + N3_REWARD;
+  const int out_elems = B * N3_PLAYER_EMBED;
+  const int embed_elems = N3_EMBED_VOCAB * N3_EMBED_DIM;
+
+  Allocator alloc{};
+  PrecisionTensor d_obs{}, d_embed{}, d_out_ref{}, d_out_fast{};
+  d_obs = {.shape = {B, obs_size}};
+  d_embed = {.shape = {N3_EMBED_VOCAB, N3_EMBED_DIM}};
+  d_out_ref = {.shape = {B, N3_PLAYER_EMBED}};
+  d_out_fast = {.shape = {B, N3_PLAYER_EMBED}};
+  alloc_register(&alloc, &d_obs);
+  alloc_register(&alloc, &d_embed);
+  alloc_register(&alloc, &d_out_ref);
+  alloc_register(&alloc, &d_out_fast);
+  if (alloc_create(&alloc) != cudaSuccess) return 1;
+
+  std::vector<float> h_obs((size_t)B * (size_t)obs_size, 0.0f);
+  std::vector<float> h_embed((size_t)embed_elems);
+  fill_rand_host(h_embed.data(), embed_elems, 9191u);
+  unsigned seed = 5150u;
+  for (int b = 0; b < B; ++b) {
+    for (int f = 0; f < N3_PLAYER; ++f) {
+      seed = seed * 1103515245u + 12345u;
+      int v = (int)((seed >> 16) % (unsigned)N3_EMBED_VOCAB);
+      h_obs[(size_t)b * (size_t)obs_size + (size_t)N3_MAP_SIZE + (size_t)f] = (float)v;
+    }
+  }
+  copy_fp32_h2d(h_obs.data(), d_obs.data, B * obs_size);
+  copy_fp32_h2d(h_embed.data(), d_embed.data, embed_elems);
+
+  cudaStream_t stream = 0;
+  const float rel_eps = 1e-5f;
+
+  auto launch_ref = [&](cudaStream_t s) {
+    n3_embedding_kernel<<<(B * N3_PLAYER + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, s>>>(
+        d_out_ref.data, d_obs.data, d_embed.data, B, obs_size);
+  };
+  auto launch_fast = [&](cudaStream_t s) {
+    n3_embedding_kernel_fast<<<(B * N3_PLAYER + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, s>>>(
+        d_out_fast.data, d_obs.data, d_embed.data, B, obs_size, kDmN3Player);
+  };
+
+  const size_t out_bytes = (size_t)out_elems * sizeof(precision_t);
+  cudaMemsetAsync(d_out_ref.data, 0, out_bytes, stream);
+  cudaMemsetAsync(d_out_fast.data, 0, out_bytes, stream);
+  cudaStreamSynchronize(stream);
+  launch_ref(stream);
+  cudaDeviceSynchronize();
+  cudaMemsetAsync(d_out_fast.data, 0, out_bytes, stream);
+  cudaStreamSynchronize(stream);
+  launch_fast(stream);
+  cudaDeviceSynchronize();
+
+  std::vector<float> h_ref, h_fast;
+  copy_precision_d2h(d_out_ref.data, out_elems, &h_ref);
+  copy_precision_d2h(d_out_fast.data, out_elems, &h_fast);
+  float mx, mn;
+  stats_diff(h_ref.data(), h_fast.data(), out_elems, &mx, &mn);
+  float rel = stats_rel_max(h_ref.data(), h_fast.data(), out_elems, rel_eps);
+
+  printf("n3_embedding B=%d obs_size=%d out_elems=%d embed_elems=%d\n", B, obs_size, out_elems, embed_elems);
+  printf("  reference=n3_embedding_kernel fast=n3_embedding_kernel_fast (kDmN3Player + vec copy)\n");
+  printf("  correctness fast vs reference: max|diff| %.6g  mean|diff| %.6g  max rel err %.6g\n", mx, mn, rel);
+  printf("  timing (%d warmup / %d iters), kernel only:\n", warmup, iters);
+
+  float ms_ref = time_kernel_ms(stream, launch_ref, warmup, iters);
+  float ms_f = time_kernel_ms(stream, launch_fast, warmup, iters);
+  printf("  n3_embedding_kernel (ref):  %8.4f us/iter\n", ms_ref * 1000.0f);
+  printf("  n3_embedding_kernel_fast:   %8.4f us/iter (%.2fx vs ref)\n", ms_f * 1000.0f, ms_ref / ms_f);
+  printf("\n");
+
+  alloc_free(&alloc);
+  return 0;
+}
+
+static int run_n3_multihot_bench(int warmup, int iters) {
+  for (int B = 1024; B <= 32768; B *= 2) {
+    if (run_n3_multihot_bench_B(B, warmup, iters)) return 1;
+  }
+  return 0;
+}
+
+static int run_n3_embedding_bench(int warmup, int iters) {
+  for (int B = 1024; B <= 32768; B *= 2) {
+    if (run_n3_embedding_bench_B(B, warmup, iters)) return 1;
+  }
+  return 0;
+}
+
+static int run_n3_conv_bias_grad_bench_B(
+    int B, int OC, int spatial, const FastDivMod dm_spatial, const char* tag, int warmup, int iters) {
+  const int grad_elems = B * OC * spatial;
+  cudaStream_t stream = 0;
+  const float rel_eps = 1e-5f;
+
+  Allocator alloc{};
+  PrecisionTensor d_grad{}, d_bref{}, d_bfast{};
+  d_grad = {.shape = {(int64_t)grad_elems}};
+  d_bref = {.shape = {OC}};
+  d_bfast = {.shape = {OC}};
+  alloc_register(&alloc, &d_grad);
+  alloc_register(&alloc, &d_bref);
+  alloc_register(&alloc, &d_bfast);
+  if (alloc_create(&alloc) != cudaSuccess) return 1;
+
+  std::vector<float> h_grad((size_t)grad_elems);
+  fill_rand_host(h_grad.data(), grad_elems, 77077u + (unsigned)B + (unsigned)spatial);
+  copy_fp32_h2d(h_grad.data(), d_grad.data, grad_elems);
+
+  cudaMemsetAsync(d_bref.data, 0, (size_t)OC * sizeof(precision_t), stream);
+  n3_conv_bias_grad_nchw<<<OC, 256, 0, stream>>>(d_bref.data, d_grad.data, B, OC, spatial);
+  cudaDeviceSynchronize();
+
+  cudaMemsetAsync(d_bfast.data, 0, (size_t)OC * sizeof(precision_t), stream);
+  n3_conv_bias_grad_nchw_fast<<<OC, 256, 0, stream>>>(d_bfast.data, d_grad.data, B, OC, dm_spatial);
+  cudaDeviceSynchronize();
+
+  std::vector<float> href, hfast;
+  copy_precision_d2h(d_bref.data, OC, &href);
+  copy_precision_d2h(d_bfast.data, OC, &hfast);
+  float mx, mn;
+  stats_diff(href.data(), hfast.data(), OC, &mx, &mn);
+  float rel = stats_rel_max(href.data(), hfast.data(), OC, rel_eps);
+
+  printf("n3_conv_bias_grad %s B=%d OC=%d spatial=%d grad_elems=%d\n", tag, B, OC, spatial, grad_elems);
+  printf("  reference=n3_conv_bias_grad_nchw fast=n3_conv_bias_grad_nchw_fast (FastDivMod)\n");
+  printf("  correctness fast vs reference: max|diff| %.6g  mean|diff| %.6g  max rel err %.6g\n", mx, mn, rel);
+  printf("  timing (%d warmup / %d iters), kernel only:\n", warmup, iters);
+
+  auto launch_ref = [&](cudaStream_t s) {
+    n3_conv_bias_grad_nchw<<<OC, 256, 0, s>>>(d_bref.data, d_grad.data, B, OC, spatial);
+  };
+  auto launch_fast = [&](cudaStream_t s) {
+    n3_conv_bias_grad_nchw_fast<<<OC, 256, 0, s>>>(d_bfast.data, d_grad.data, B, OC, dm_spatial);
+  };
+
+  float ms_ref = time_kernel_ms(stream, launch_ref, warmup, iters);
+  float ms_f = time_kernel_ms(stream, launch_fast, warmup, iters);
+  printf("  n3_conv_bias_grad_nchw (ref):  %8.4f us/iter\n", ms_ref * 1000.0f);
+  printf("  n3_conv_bias_grad_nchw_fast:   %8.4f us/iter (%.2fx vs ref)\n", ms_f * 1000.0f, ms_ref / ms_f);
+  printf("\n");
+
+  alloc_free(&alloc);
+  return 0;
+}
+
+static int run_n3_conv_bias_grad_bench(int warmup, int iters) {
+  for (int B = 1024; B <= 32768; B *= 2) {
+    if (run_n3_conv_bias_grad_bench_B(B, N3_C2_OC, N3_C2_OH * N3_C2_OW, kDmConvBiasSpatialC2, "conv2", warmup, iters)) return 1;
+    if (run_n3_conv_bias_grad_bench_B(B, N3_C1_OC, N3_C1_OH * N3_C1_OW, kDmConvBiasSpatialC1, "conv1", warmup, iters)) return 1;
+  }
+  return 0;
+}
+
+int main(int argc, char** argv) {
+  const int warmup = 50;
+  const int iters = 200;
+  bool run_conv = true, run_multihot = true, run_embedding = true, run_conv_bias_grad = false;
+  for (int i = 1; i < argc; ++i) {
+    if (strcmp(argv[i], "--float") == 0 || strcmp(argv[i], "--fp32") == 0 || strcmp(argv[i], "--bf16") == 0
+        || strcmp(argv[i], "--half") == 0) {
+      continue;
+    }
+    if (strcmp(argv[i], "--multihot-only") == 0) {
+      run_conv = false;
+      run_multihot = true;
+      run_embedding = false;
+      run_conv_bias_grad = false;
+      continue;
+    }
+    if (strcmp(argv[i], "--embedding-only") == 0) {
+      run_conv = false;
+      run_multihot = false;
+      run_embedding = true;
+      run_conv_bias_grad = false;
+      continue;
+    }
+    if (strcmp(argv[i], "--conv-bias-grad-only") == 0) {
+      run_conv = false;
+      run_multihot = false;
+      run_embedding = false;
+      run_conv_bias_grad = true;
+      continue;
+    }
+    if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
+      printf("Usage: %s [--multihot-only] [--embedding-only] [--conv-bias-grad-only]\n", argv[0]);
+      printf("  Precision: compile via tests/bench_gemm_conv_end2end.sh --float|--bf16\n");
+      printf("  Default: conv layers + n3_multihot ref vs fast + n3_embedding ref vs fast (B=1024..32768).\n");
(B=1024..32768).\n"); + printf(" --multihot-only multihot ref vs fast only\n"); + printf(" --embedding-only embedding ref vs fast only\n"); + printf(" --conv-bias-grad-only n3_conv_bias_grad_nchw vs _fast (conv1+conv2 shapes)\n"); + return 0; + } + fprintf(stderr, "Unknown arg: %s\n", argv[i]); + return 1; + } + +#ifdef PRECISION_FLOAT + printf("bench_gemm_conv_end2end precision=fp32 warmup=%d iters=%d\n\n", warmup, iters); +#else + printf("bench_gemm_conv_end2end precision=bf16 warmup=%d iters=%d\n\n", warmup, iters); +#endif + + if (run_conv) { + if (run_layer(1, warmup, iters)) return 1; + if (run_layer(2, warmup, iters)) return 1; + } + if (run_multihot) { + if (run_n3_multihot_bench(warmup, iters)) return 1; + } + if (run_embedding) { + if (run_n3_embedding_bench(warmup, iters)) return 1; + } + if (run_conv_bias_grad) { + if (run_n3_conv_bias_grad_bench(warmup, iters)) return 1; + } + return 0; +} diff --git a/tests/bench_gemm_conv_end2end.sh b/tests/bench_gemm_conv_end2end.sh new file mode 100755 index 0000000000..1eec994315 --- /dev/null +++ b/tests/bench_gemm_conv_end2end.sh @@ -0,0 +1,75 @@ +#!/usr/bin/env bash +# Build and run end-to-end conv benchmark: gemm vs gemm_fast vs cudnn (fwd+bwd), layers 1 & 2, +# plus NMMO3 microbenches (each optional via flags; default runs conv + multihot + embedding). +# Args: +# --float | --fp32 (default) or --bf16 | --half +# --multihot-only n3_multihot ref vs fast only (B=1024..32768) +# --embedding-only n3_embedding ref vs fast only (same B grid) +# --conv-bias-grad-only n3_conv_bias_grad_nchw vs _fast, conv1+conv2 NMMO3 shapes +# +# ./tests/bench_gemm_conv_end2end.sh +# ./tests/bench_gemm_conv_end2end.sh --bf16 +# ./tests/bench_gemm_conv_end2end.sh --multihot-only +# ./tests/bench_gemm_conv_end2end.sh --embedding-only +# ./tests/bench_gemm_conv_end2end.sh --conv-bias-grad-only +set -euo pipefail +ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" +cd "$ROOT" + +CUDA_HOME="${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(command -v nvcc)")")}}" +NVCC="${NVCC:-$CUDA_HOME/bin/nvcc}" +ARCH="${NVCC_ARCH:-native}" + +PRECISION_FLAG="-DPRECISION_FLOAT" +EXTRA_ARGS=() +for arg in "$@"; do + case "$arg" in + --bf16|--half) PRECISION_FLAG="" ;; + --float|--fp32) PRECISION_FLAG="-DPRECISION_FLOAT" ;; + --multihot-only|--embedding-only|--conv-bias-grad-only) EXTRA_ARGS+=("$arg") ;; + *) + echo "Unknown argument: $arg (use --float, --bf16, --multihot-only, --embedding-only, or --conv-bias-grad-only)" >&2 + exit 1 + ;; + esac +done + +CUDNN_IFLAG="" +CUDNN_LFLAG="" +for dir in "$CUDA_HOME/include" /usr/local/cuda/include /usr/include; do + if [[ -f "$dir/cudnn.h" ]]; then + CUDNN_IFLAG="-I$dir" + break + fi +done +for dir in "$CUDA_HOME/lib64" "$CUDA_HOME/lib" /usr/lib/x86_64-linux-gnu; do + if [[ -f "$dir/libcudnn.so" ]] || [[ -f "$dir/libcudnn.dylib" ]]; then + CUDNN_LFLAG="-L$dir" + break + fi +done +if [[ -z "$CUDNN_IFLAG" ]]; then + CUDNN_IFLAG=$(python3 -c "import nvidia.cudnn, os; print('-I' + os.path.join(nvidia.cudnn.__path__[0], 'include'))" 2>/dev/null || true) +fi +if [[ -z "$CUDNN_LFLAG" ]]; then + CUDNN_LFLAG=$(python3 -c "import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2>/dev/null || true) +fi + +OUT="${ROOT}/tests/bench_gemm_conv_end2end" +if [[ -n "$PRECISION_FLAG" ]]; then + echo "nvcc $ARCH $OUT (fp32)" +else + echo "nvcc $ARCH $OUT (bf16)" +fi +"$NVCC" -O2 -std=c++17 "-arch=$ARCH" \ + "-I${ROOT}/src" \ + "-I${CUDA_HOME}/include" \ + $CUDNN_IFLAG \ + $PRECISION_FLAG \ + "${ROOT}/tests/bench_gemm_conv_end2end.cu" \ + -o "$OUT" \ + $CUDNN_LFLAG \ + -L"${CUDA_HOME}/lib64" -L"${CUDA_HOME}/lib" \ + -lcublas -lcudnn -lcurand + +exec "$OUT" "${EXTRA_ARGS[@]}"