From 71c29773dc9e3d6bf00c96f3cf691cc3002966ee Mon Sep 17 00:00:00 2001 From: jonah Date: Wed, 8 Apr 2026 08:41:23 -0700 Subject: [PATCH 1/9] initial --- pyproject.toml | 3 + src/ocean.cu | 134 +++++- tests/bench_conv_gemm_vs_cudnn.cu | 693 ++++++++++++++++++++++++++++++ tests/bench_conv_gemm_vs_cudnn.sh | 68 +++ tests/tune_cublas_gemm.cu | 296 +++++++++++++ tests/tune_cublas_gemm.sh | 36 ++ 6 files changed, 1229 insertions(+), 1 deletion(-) create mode 100755 tests/bench_conv_gemm_vs_cudnn.cu create mode 100755 tests/bench_conv_gemm_vs_cudnn.sh create mode 100644 tests/tune_cublas_gemm.cu create mode 100755 tests/tune_cublas_gemm.sh diff --git a/pyproject.toml b/pyproject.toml index d8c18a0158..f9754941ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,9 @@ dependencies = [ "pybind11", ] +[[tool.uv.index]] +url = "https://download.pytorch.org/whl/cu128" + [project.scripts] puffer = "pufferlib.pufferl:main" diff --git a/src/ocean.cu b/src/ocean.cu index baaa9b7be6..abe0a19208 100644 --- a/src/ocean.cu +++ b/src/ocean.cu @@ -213,7 +213,6 @@ __global__ void conv_bias_relu_kernel(precision_t* __restrict__ data, // NCHW layout throughout. Weight stored as (OC, IC*K*K). // im2col produces (B*OH*OW, IC*K*K), matmul with W^T gives (B*OH*OW, OC), // then reshape to NCHW (B, OC, OH, OW). - __global__ void im2col_kernel( const precision_t* __restrict__ input, precision_t* __restrict__ col, int B, int IC, int IH, int IW, int K, int S, int OH, int OW @@ -233,6 +232,88 @@ __global__ void im2col_kernel( col[idx] = input[b * IC * IH * IW + ic * IH * IW + ih * IW + iw]; } +struct FastDivMod { + uint32_t d_; + uint32_t M_; + uint32_t l_; + + __host__ FastDivMod(int d) { + d_ = d <= 0 ? 
1u : (uint32_t)d; + uint32_t l = 0; + for (; l < 32; ++l) + if ((1u << l) >= d_) break; + l_ = l; + const uint64_t one = 1; + uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1; + M_ = (uint32_t)m; + } + + __device__ __forceinline__ int div(int n) const { + uint32_t u = (uint32_t)n; // n must be >= 0 + uint32_t t = __umulhi(M_, u); + return (int)((t + u) >> l_); + } + + __device__ __forceinline__ int mod(int n) const { + return n - div(n) * (int)d_; + } + + __device__ __forceinline__ void divmod(int n, int& q, int& r) const { + q = div(n); + r = n - q * (int)d_; + } +}; + +struct Im2ColFastMods { + FastDivMod dm_col_w; + FastDivMod dm_oh_ow; + FastDivMod dm_ow; + FastDivMod dm_kk; + FastDivMod dm_k; + FastDivMod dm_oc; + int total_no_batch; + int oh_ow; + int oc_spatial; + int col_cols; + int IC, IH, IW, OC, K, S, OH, OW; + + __host__ Im2ColFastMods(int ic, int ih, int iw, int oc, int k, int s, int oh, int ow) + : dm_col_w(ic * k * k), dm_oh_ow(oh * ow), dm_ow(ow), dm_kk(k * k), dm_k(k), dm_oc(oc), + total_no_batch((oh * ow) * (ic * k * k)), oh_ow(oh * ow), col_cols(ic * k * k), + oc_spatial(oc * oh * ow), + IC(ic), IH(ih), IW(iw), OC(oc), K(k), S(s), OH(oh), OW(ow) {} +}; + +static const Im2ColFastMods kIm2ColModsC1( + N3_C1_IC, N3_MAP_H, N3_MAP_W, N3_C1_OC, N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW); +static const Im2ColFastMods kIm2ColModsC2( + N3_C2_IC, N3_C1_OH, N3_C1_OW, N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW); + +__global__ void im2col_kernel_fast( + const precision_t* __restrict__ input, precision_t* __restrict__ col, + int B, int IC, int IH, int IW, int K, int S, int OH, int OW, + const FastDivMod dm_col_w, const FastDivMod dm_oh_ow, + const FastDivMod dm_ow, const FastDivMod dm_kk, const FastDivMod dm_k, + const int total_no_batch +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = B * total_no_batch; + if (idx >= total) return; + int row, c; + dm_col_w.divmod(idx, row, c); + int b, rem; + dm_oh_ow.divmod(row, b, rem); + int oh, 
ow; + dm_ow.divmod(rem, oh, ow); + int ic, kk; + dm_kk.divmod(c, ic, kk); + int kh, kw; + dm_k.divmod(kk, kh, kw); + int ih = oh * S + kh, iw = ow * S + kw; + int _IH_IW = IH * IW; + col[idx] = input[b * IC * _IH_IW + ic * _IH_IW + ih * IW + iw]; +} + // Backward: col2im — input-centric gather to avoid atomics. // Each thread owns one (b, ic, ih, iw) element and sums contributions from all // (oh, ow, kh, kw) patches that map to it. @@ -280,6 +361,33 @@ __global__ void nchw_to_rows_kernel( } // Transpose (B*OH*OW, OC) -> (B, OC, OH, OW) [row-major spatial-first to NCHW] +__global__ void rows_to_nchw_kernel_fused( + const precision_t* __restrict__ src, + const precision_t* __restrict__ bias, + precision_t* __restrict__ data, + int B, + int spatial, int oc_spatial, int OC, + const FastDivMod dm_oh_ow, + const FastDivMod dm_oc, + bool relu +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = B * oc_spatial; + if (idx >= total) return; + + int b, q, s, oc; + dm_oh_ow.divmod(idx, q, s); + dm_oc.divmod(q, b, oc); + + float value = to_float(src[(b * spatial + s) * OC + oc]); + float oc_bias = to_float(bias[oc]); + float value_bias = value + oc_bias; + if (relu) { + data[idx] = from_float(fmaxf(0.0f, value_bias)); + } else { + data[idx] = from_float(value_bias); + } +} __global__ void rows_to_nchw_kernel( const precision_t* __restrict__ src, precision_t* __restrict__ dst, int B, int OC, int spatial @@ -330,6 +438,30 @@ static void gemm_conv_forward( } } +// NMMO3 conv1/conv2 only: pass kIm2ColModsC1 or kIm2ColModsC2 (built once from N3_*). 
+static void gemm_conv_forward_fast( + PrecisionTensor* weight, PrecisionTensor* bias, + precision_t* input, precision_t* output, + precision_t* col_buf, precision_t* mm_buf, + int B, const Im2ColFastMods& m, bool relu, cudaStream_t stream +) { + int col_rows = B * m.oh_ow; + int total_col = col_rows * m.col_cols; + int total_out = B * m.OC * m.oh_ow; + + im2col_kernel_fast<<>>( + input, col_buf, B, m.IC, m.IH, m.IW, m.K, m.S, m.OH, m.OW, + m.dm_col_w, m.dm_oh_ow, m.dm_ow, m.dm_kk, m.dm_k, m.total_no_batch); + + PrecisionTensor col_t = {.data = col_buf, .shape = {col_rows, m.col_cols}}; + PrecisionTensor mm_t = {.data = mm_buf, .shape = {col_rows, m.OC}}; + puf_mm(&col_t, weight, &mm_t, stream); + + rows_to_nchw_kernel_fused<<>>( + mm_buf, bias->data, output, B, m.oh_ow, m.oc_spatial, m.OC, + m.dm_oh_ow, m.dm_oc, relu); +} + // Backward: weight grad + optional input grad via im2col/col2im + cuBLAS. // grad_output is NCHW (B, OC, OH, OW). saved_input is NCHW. // Caller handles relu backward and bias grad (same as cuDNN path). diff --git a/tests/bench_conv_gemm_vs_cudnn.cu b/tests/bench_conv_gemm_vs_cudnn.cu new file mode 100755 index 0000000000..babf763468 --- /dev/null +++ b/tests/bench_conv_gemm_vs_cudnn.cu @@ -0,0 +1,693 @@ +// Benchmark: gemm_conv_* (im2col + cuBLAS) vs conv_* (cuDNN), same weights/inputs. +// Baseline for correctness: gemm path (per project convention). 
+// +// Build: see tests/bench_conv_gemm_vs_cudnn.sh (-DPRECISION_FLOAT => fp32; omit => bf16) + +#include +#include "models.cu" +#include "ocean.cu" + +#ifndef PRECISION_FLOAT +#include +#endif +#include +#include +#include +#include +#include +#include + +static void fill_rand_host(float* h, int n, unsigned s) { + for (int i = 0; i < n; ++i) { + s = s * 1103515245u + 12345u; + h[i] = ((s >> 16) & 0x7fff) / 16384.0f - 1.0f; + } +} + +static void stats_diff(const float* a, const float* b, int n, float* max_abs, float* mean_abs) { + float mx = 0.0f, sum = 0.0f; + for (int i = 0; i < n; ++i) { + float d = fabsf(a[i] - b[i]); + if (d > mx) mx = d; + sum += d; + } + *max_abs = mx; + *mean_abs = sum / (float)std::max(1, n); +} + +static void copy_precision_d2h(const precision_t* d, int n, std::vector* hf) { + hf->resize((size_t)n); +#ifdef PRECISION_FLOAT + cudaMemcpy(hf->data(), d, (size_t)n * sizeof(float), cudaMemcpyDeviceToHost); +#else + std::vector h((size_t)n); + cudaMemcpy(h.data(), d, (size_t)n * sizeof(precision_t), cudaMemcpyDeviceToHost); + for (int i = 0; i < n; ++i) (*hf)[i] = __bfloat162float(h[i]); +#endif +} + +static void copy_fp32_h2d(const float* h, precision_t* d, int n) { +#ifdef PRECISION_FLOAT + cudaMemcpy(d, h, (size_t)n * sizeof(float), cudaMemcpyHostToDevice); +#else + std::vector hb((size_t)n); + for (int i = 0; i < n; ++i) hb[i] = __float2bfloat16(h[i]); + cudaMemcpy(d, hb.data(), (size_t)n * sizeof(precision_t), cudaMemcpyHostToDevice); +#endif +} + +template +static float time_kernel_ms(cudaStream_t stream, F fn, int warmup, int iters) { + for (int i = 0; i < warmup; ++i) { + fn(stream); + cudaStreamSynchronize(stream); + } + cudaEvent_t ev0, ev1; + cudaEventCreate(&ev0); + cudaEventCreate(&ev1); + cudaEventRecord(ev0, stream); + for (int i = 0; i < iters; ++i) fn(stream); + cudaEventRecord(ev1, stream); + cudaEventSynchronize(ev1); + float ms = 0.0f; + cudaEventElapsedTime(&ms, ev0, ev1); + cudaEventDestroy(ev0); + cudaEventDestroy(ev1); 
+ return ms / (float)iters; +} + +struct BenchDims { + int B; + int IC, OC, K, S, IH, IW; + bool relu; +}; + +static void dims_conv1(BenchDims* d, int B) { + d->B = B; + d->IC = N3_C1_IC; + d->OC = N3_C1_OC; + d->K = N3_C1_K; + d->S = N3_C1_S; + d->IH = N3_MAP_H; + d->IW = N3_MAP_W; + d->relu = true; +} + +static void dims_conv2(BenchDims* d, int B) { + d->B = B; + d->IC = N3_C2_IC; + d->OC = N3_C2_OC; + d->K = N3_C2_K; + d->S = N3_C2_S; + d->IH = N3_C1_OH; + d->IW = N3_C1_OW; + d->relu = false; +} + +static int run_forward(const BenchDims& dim, int warmup, int iters, bool cudnn_save_input) { + ConvWeights cw{}; + conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, dim.relu); + + Allocator param_alloc{}; + conv_reg_params(&cw, ¶m_alloc); + if (alloc_create(¶m_alloc) != cudaSuccess) { + fprintf(stderr, "alloc_create params failed\n"); + return 1; + } + uint64_t seed = 42; + conv_init_weights(&cw, &seed, 0); + cudaDeviceSynchronize(); + + int OH = cw.OH; + int OW = cw.OW; + int out_elems = dim.B * dim.OC * OH * OW; + int in_elems = dim.B * dim.IC * dim.IH * dim.IW; + + Allocator act_g{}; + PrecisionTensor out_g{}, col{}, mm{}, input{}; + int col_rows = dim.B * OH * OW; + int col_cols = dim.IC * dim.K * dim.K; + out_g = {.shape = {(int64_t)out_elems}}; + col = {.shape = {col_rows, col_cols}}; + mm = {.shape = {col_rows, dim.OC}}; + input = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; + alloc_register(&act_g, &out_g); + alloc_register(&act_g, &col); + alloc_register(&act_g, &mm); + alloc_register(&act_g, &input); + if (alloc_create(&act_g) != cudaSuccess) { + fprintf(stderr, "alloc_create gemm acts failed\n"); + return 1; + } + + std::vector hin(in_elems); + fill_rand_host(hin.data(), in_elems, 99u); + copy_fp32_h2d(hin.data(), input.data, in_elems); + + Allocator acts{}, grads{}; + ConvActivations ca{}; + conv_reg_train(&cw, &ca, &acts, &grads, dim.B, n3_cudnn_dtype()); + if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) { + 
fprintf(stderr, "alloc_create cudnn failed\n"); + return 1; + } + if (!cudnn_save_input) ca.saved_input.data = nullptr; + + cudaStream_t stream = 0; + + gemm_conv_forward(&cw.w, &cw.b, input.data, out_g.data, col.data, mm.data, dim.B, dim.IC, dim.IH, + dim.IW, dim.OC, dim.K, dim.S, OH, OW, dim.relu, stream); + cudaDeviceSynchronize(); + + conv_forward(&cw, &ca, input.data, dim.B, stream); + cudaDeviceSynchronize(); + + std::vector hg, hc; + copy_precision_d2h(out_g.data, out_elems, &hg); + copy_precision_d2h(ca.out.data, out_elems, &hc); + + float max_abs, mean_abs; + stats_diff(hg.data(), hc.data(), out_elems, &max_abs, &mean_abs); + printf(" forward max |diff|: %.6g mean |diff|: %.6g\n", max_abs, mean_abs); + + auto run_gemm = [&](cudaStream_t s) { + gemm_conv_forward(&cw.w, &cw.b, input.data, out_g.data, col.data, mm.data, dim.B, dim.IC, + dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, dim.relu, s); + }; + auto run_cudnn = [&](cudaStream_t s) { + conv_forward(&cw, &ca, input.data, dim.B, s); + }; + + float ms_g = time_kernel_ms(stream, run_gemm, warmup, iters); + float ms_c = time_kernel_ms(stream, run_cudnn, warmup, iters); + printf(" gemm_conv_forward: %8.4f us/iter\n", ms_g * 1000.0f); + printf(" conv_forward: %8.4f us/iter (%.2fx vs gemm)\n", ms_c * 1000.0f, ms_g / ms_c); + + alloc_free(¶m_alloc); + alloc_free(&act_g); + alloc_free(&acts); + alloc_free(&grads); + return 0; +} + +// ∂L/∂W only: gemm path vs cudnnConvolutionBackwardFilter. +// Correctness: gemm_conv_backward(..., input_grad=null) vs conv_backward(..., nullptr). +// Timings: inlined nchw+im2col+mm_tn, gemm_conv_backward, and cudnn BackwardFilter. 
+static int run_filter_backward_bench(const BenchDims& dim, int warmup, int iters) { + int B = dim.B, IC = dim.IC, OC = dim.OC, K = dim.K, S = dim.S, IH = dim.IH, IW = dim.IW; + int OH = (IH - K) / S + 1; + int OW = (IW - K) / S + 1; + int col_rows = B * OH * OW; + int col_cols = IC * K * K; + int total_col = col_rows * col_cols; + int total_out = B * OC * OH * OW; + int in_elems = B * IC * IH * IW; + int spatial = OH * OW; + int w_elems = OC * col_cols; + + ConvWeights cw{}; + conv_init(&cw, IC, OC, K, S, IH, IW, false); + Allocator param_alloc{}; + conv_reg_params(&cw, ¶m_alloc); + if (alloc_create(¶m_alloc) != cudaSuccess) return 1; + uint64_t seed = 301; + conv_init_weights(&cw, &seed, 0); + cudaDeviceSynchronize(); + + Allocator act{}; + PrecisionTensor col{}, mm{}, saved_in{}, grad_out{}, wgrad{}; + col = {.shape = {col_rows, col_cols}}; + mm = {.shape = {col_rows, OC}}; + saved_in = {.shape = {B, IC, IH, IW}}; + grad_out = {.shape = {(int64_t)total_out}}; + wgrad = {.shape = {OC, col_cols}}; + alloc_register(&act, &col); + alloc_register(&act, &mm); + alloc_register(&act, &saved_in); + alloc_register(&act, &grad_out); + alloc_register(&act, &wgrad); + if (alloc_create(&act) != cudaSuccess) return 1; + + Allocator acts{}, grads{}; + ConvActivations ca{}; + conv_reg_train(&cw, &ca, &acts, &grads, B, n3_cudnn_dtype()); + if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1; + + std::vector hs(in_elems), hg(total_out); + fill_rand_host(hs.data(), in_elems, 301u); + fill_rand_host(hg.data(), total_out, 302u); + copy_fp32_h2d(hs.data(), saved_in.data, in_elems); + copy_fp32_h2d(hg.data(), grad_out.data, total_out); + cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(ca.grad.data, grad_out.data, (size_t)total_out * sizeof(precision_t), cudaMemcpyDeviceToDevice); + + PrecisionTensor mm_t = {.data = mm.data, .shape = {col_rows, OC}}; + PrecisionTensor col_t 
= {.data = col.data, .shape = {col_rows, col_cols}}; + PrecisionTensor wg_t = {.data = wgrad.data, .shape = {OC, col_cols}}; + + cudaStream_t stream = 0; + + cudaMemset(wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad.data, nullptr, col.data, mm.data, B, IC, IH, IW, + OC, K, S, OH, OW, stream); + cudaDeviceSynchronize(); + std::vector h_wg_g, h_wg_c; + copy_precision_d2h(wgrad.data, w_elems, &h_wg_g); + + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, nullptr, B, stream); + cudaDeviceSynchronize(); + copy_precision_d2h(ca.wgrad.data, w_elems, &h_wg_c); + + float f_wg_max, f_wg_mean; + stats_diff(h_wg_g.data(), h_wg_c.data(), w_elems, &f_wg_max, &f_wg_mean); + printf(" filter ∂W max |diff| (gemm_conv_backward vs cudnn): %.6g mean |diff|: %.6g\n", f_wg_max, + f_wg_mean); + + float ms_nchw = time_kernel_ms( + stream, + [&](cudaStream_t s) { + nchw_to_rows_kernel<<>>( + grad_out.data, mm.data, B, OC, spatial); + }, + warmup, iters); + + float ms_im2col = time_kernel_ms( + stream, + [&](cudaStream_t s) { + im2col_kernel<<>>( + saved_in.data, col.data, B, IC, IH, IW, K, S, OH, OW); + }, + warmup, iters); + + nchw_to_rows_kernel<<>>( + grad_out.data, mm.data, B, OC, spatial); + im2col_kernel<<>>( + saved_in.data, col.data, B, IC, IH, IW, K, S, OH, OW); + cudaStreamSynchronize(stream); + + float ms_tn = time_kernel_ms( + stream, + [&](cudaStream_t s) { puf_mm_tn(&mm_t, &col_t, &wg_t, s); }, warmup, iters); + + float ms_gemm_chain = time_kernel_ms( + stream, + [&](cudaStream_t s) { + nchw_to_rows_kernel<<>>( + grad_out.data, mm.data, B, OC, spatial); + im2col_kernel<<>>( + saved_in.data, col.data, B, IC, IH, IW, K, S, OH, OW); + puf_mm_tn(&mm_t, &col_t, &wg_t, s); + }, + warmup, iters); + + float ms_gemm_conv_bwd_wonly = time_kernel_ms( + stream, + [&](cudaStream_t s) { + cudaMemset(wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + 
gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad.data, nullptr, col.data, mm.data, B, IC, + IH, IW, OC, K, S, OH, OW, s); + }, + warmup, iters); + + float ms_cudnn_filt = time_kernel_ms( + stream, + [&](cudaStream_t s) { + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, nullptr, B, s); + }, + warmup, iters); + + float sum_iso = (ms_nchw + ms_im2col + ms_tn) * 1000.0f; + printf(" gemm: nchw_to_rows only: %8.4f us/iter\n", ms_nchw * 1000.0f); + printf(" gemm: im2col only: %8.4f us/iter\n", ms_im2col * 1000.0f); + printf(" gemm: puf_mm_tn only: %8.4f us/iter (mm/col prefilled)\n", ms_tn * 1000.0f); + printf(" gemm: nchw+im2col+mm_tn: %8.4f us/iter (chained, matches ∂W slice)\n", ms_gemm_chain * 1000.0f); + printf(" gemm: gemm_conv_backward: %8.4f us/iter (input_grad=null, same as ocean.cu ∂W-only)\n", + ms_gemm_conv_bwd_wonly * 1000.0f); + printf(" sum(3 isolated): %8.4f us/iter (vs chained)\n", sum_iso); + printf(" cudnn: BackwardFilter only: %8.4f us/iter (conv_backward, input_grad=null)\n", + ms_cudnn_filt * 1000.0f); + printf(" ratio gemm_chain/cudnn: %.2fx (>1 => cudnn faster)\n", ms_gemm_chain / ms_cudnn_filt); + printf(" ratio gemm_conv_bwd/cudnn: %.2fx\n", ms_gemm_conv_bwd_wonly / ms_cudnn_filt); + + alloc_free(¶m_alloc); + alloc_free(&act); + alloc_free(&acts); + alloc_free(&grads); + return 0; +} + +static int run_backward(const BenchDims& dim, int warmup, int iters) { + ConvWeights cw{}; + conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, false); + + Allocator param_alloc{}; + conv_reg_params(&cw, ¶m_alloc); + if (alloc_create(¶m_alloc) != cudaSuccess) return 1; + uint64_t seed = 7; + conv_init_weights(&cw, &seed, 0); + cudaDeviceSynchronize(); + + int OH = cw.OH; + int OW = cw.OW; + int out_elems = dim.B * dim.OC * OH * OW; + int in_elems = dim.B * dim.IC * dim.IH * dim.IW; + int w_elems = (int)numel(cw.w.shape); + + Allocator act_g{}; + PrecisionTensor col{}, mm{}; + int col_rows = 
dim.B * OH * OW; + int col_cols = dim.IC * dim.K * dim.K; + col = {.shape = {col_rows, col_cols}}; + mm = {.shape = {col_rows, dim.OC}}; + PrecisionTensor saved_in{}, grad_out{}, wgrad_g{}; + saved_in = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; + grad_out = {.shape = {(int64_t)out_elems}}; + wgrad_g = {.shape = {cw.w.shape[0], cw.w.shape[1]}}; + PrecisionTensor dinput_g{}; + dinput_g = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; + alloc_register(&act_g, &col); + alloc_register(&act_g, &mm); + alloc_register(&act_g, &saved_in); + alloc_register(&act_g, &grad_out); + alloc_register(&act_g, &wgrad_g); + alloc_register(&act_g, &dinput_g); + if (alloc_create(&act_g) != cudaSuccess) return 1; + + std::vector hs(in_elems), hg(out_elems); + fill_rand_host(hs.data(), in_elems, 101u); + fill_rand_host(hg.data(), out_elems, 202u); + copy_fp32_h2d(hs.data(), saved_in.data, in_elems); + copy_fp32_h2d(hg.data(), grad_out.data, out_elems); + + Allocator acts{}, grads{}; + ConvActivations ca{}; + conv_reg_train(&cw, &ca, &acts, &grads, dim.B, n3_cudnn_dtype()); + if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1; + cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); + + cudaStream_t stream = 0; + + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, + dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream); + cudaDeviceSynchronize(); + + std::vector hwg, hdi; + copy_precision_d2h(wgrad_g.data, w_elems, &hwg); + copy_precision_d2h(dinput_g.data, in_elems, &hdi); + + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, 
(size_t)in_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, dinput_g.data, dim.B, stream); + cudaDeviceSynchronize(); + + std::vector hwg_c, hdi_c; + copy_precision_d2h(ca.wgrad.data, w_elems, &hwg_c); + copy_precision_d2h(dinput_g.data, in_elems, &hdi_c); + + float mw, mdw, mdi, mdi_mean; + stats_diff(hwg.data(), hwg_c.data(), w_elems, &mw, &mdw); + stats_diff(hdi.data(), hdi_c.data(), in_elems, &mdi, &mdi_mean); + printf(" backward wgrad max |diff|: %.6g\n", mw); + printf(" backward d_input max |diff|: %.6g\n", mdi); + + auto run_gemm_b = [&](cudaStream_t s) { + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, + mm.data, dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, s); + }; + auto run_cudnn_b = [&](cudaStream_t s) { + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, dinput_g.data, dim.B, s); + }; + + float ms_g = time_kernel_ms(stream, run_gemm_b, warmup, iters); + float ms_c = time_kernel_ms(stream, run_cudnn_b, warmup, iters); + printf(" gemm_conv_backward: %8.4f us/iter\n", ms_g * 1000.0f); + printf(" conv_backward: %8.4f us/iter (%.2fx vs gemm)\n", ms_c * 1000.0f, ms_g / ms_c); + + alloc_free(¶m_alloc); + alloc_free(&act_g); + alloc_free(&acts); + alloc_free(&grads); + return 0; +} + +static int run_im2col_bench(const BenchDims& dim, int warmup, int iters) { + int B = dim.B, IC = dim.IC, IH = dim.IH, IW = dim.IW, K = dim.K, S = dim.S; + int OH = (IH - K) / S + 1; + int OW = (IW - K) / S + 1; + int total_col = B * OH * OW * IC * K * K; + int in_elems = B * IC * IH * IW; + + precision_t *d_in = nullptr, *d_col_slow = nullptr, *d_col_fast = nullptr; + if (cudaMalloc(&d_in, (size_t)in_elems * sizeof(precision_t)) != cudaSuccess) return 1; + 
if (cudaMalloc(&d_col_slow, (size_t)total_col * sizeof(precision_t)) != cudaSuccess) return 1; + if (cudaMalloc(&d_col_fast, (size_t)total_col * sizeof(precision_t)) != cudaSuccess) return 1; + + std::vector h_in((size_t)in_elems); + fill_rand_host(h_in.data(), in_elems, 401u); + copy_fp32_h2d(h_in.data(), d_in, in_elems); + cudaDeviceSynchronize(); + + const int oh_ow = OH * OW; + const int col_cols = IC * K * K; + const int total_no_batch = oh_ow * col_cols; + const int kk = K * K; + FastDivMod dm_col_w(col_cols); + FastDivMod dm_oh_ow(oh_ow); + FastDivMod dm_ow(OW); + FastDivMod dm_kk(kk); + FastDivMod dm_k(K); + + cudaStream_t stream = 0; + im2col_kernel<<>>( + d_in, d_col_slow, B, IC, IH, IW, K, S, OH, OW); + im2col_kernel_fast<<>>( + d_in, d_col_fast, B, IC, IH, IW, K, S, OH, OW, + dm_col_w, dm_oh_ow, dm_ow, dm_kk, dm_k, total_no_batch); + cudaDeviceSynchronize(); + + std::vector hs, hf; + copy_precision_d2h(d_col_slow, total_col, &hs); + copy_precision_d2h(d_col_fast, total_col, &hf); + float max_d = 0.0f, mean_d = 0.0f; + stats_diff(hs.data(), hf.data(), total_col, &max_d, &mean_d); + printf(" im2col vs im2col_fast max |diff|: %.6g mean |diff|: %.6g\n", max_d, mean_d); + + float ms_slow = time_kernel_ms( + stream, + [&](cudaStream_t s) { + im2col_kernel<<>>( + d_in, d_col_slow, B, IC, IH, IW, K, S, OH, OW); + }, + warmup, iters); + float ms_fast = time_kernel_ms( + stream, + [&](cudaStream_t s) { + im2col_kernel_fast<<>>( + d_in, d_col_fast, B, IC, IH, IW, K, S, OH, OW, + dm_col_w, dm_oh_ow, dm_ow, dm_kk, dm_k, total_no_batch); + }, + warmup, iters); + printf(" im2col_kernel: %8.4f us/iter\n", ms_slow * 1000.0f); + printf(" im2col_kernel_fast: %8.4f us/iter (%.2fx vs slow)\n", ms_fast * 1000.0f, + ms_slow / ms_fast); + + cudaFree(d_in); + cudaFree(d_col_slow); + cudaFree(d_col_fast); + return 0; +} + +// gemm_conv_forward vs gemm_conv_forward_fast; relu on/off (NMMO3 layer geometry only). 
+static int run_gemm_fast_fwd_bench(const BenchDims& dim, int layer, int warmup, int iters) { + const Im2ColFastMods& m = (layer == 1) ? kIm2ColModsC1 : kIm2ColModsC2; + + ConvWeights cw{}; + conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, dim.relu); + Allocator param_alloc{}; + conv_reg_params(&cw, ¶m_alloc); + if (alloc_create(¶m_alloc) != cudaSuccess) return 1; + uint64_t seed = 55; + conv_init_weights(&cw, &seed, 0); + cudaDeviceSynchronize(); + + int OH = cw.OH; + int OW = cw.OW; + int out_elems = dim.B * dim.OC * OH * OW; + int in_elems = dim.B * dim.IC * dim.IH * dim.IW; + int col_rows = dim.B * OH * OW; + int col_cols = dim.IC * dim.K * dim.K; + + Allocator act{}; + PrecisionTensor out_s{}, out_f{}, col{}, mm{}, input{}; + out_s = {.shape = {(int64_t)out_elems}}; + out_f = {.shape = {(int64_t)out_elems}}; + col = {.shape = {col_rows, col_cols}}; + mm = {.shape = {col_rows, dim.OC}}; + input = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; + alloc_register(&act, &out_s); + alloc_register(&act, &out_f); + alloc_register(&act, &col); + alloc_register(&act, &mm); + alloc_register(&act, &input); + if (alloc_create(&act) != cudaSuccess) return 1; + + std::vector hin(in_elems); + fill_rand_host(hin.data(), in_elems, 77u); + copy_fp32_h2d(hin.data(), input.data, in_elems); + cudaDeviceSynchronize(); + + cudaStream_t stream = 0; + + for (int ri = 0; ri < 2; ++ri) { + bool use_relu = (ri == 1); + gemm_conv_forward(&cw.w, &cw.b, input.data, out_s.data, col.data, mm.data, dim.B, dim.IC, dim.IH, + dim.IW, dim.OC, dim.K, dim.S, OH, OW, use_relu, stream); + cudaDeviceSynchronize(); + gemm_conv_forward_fast(&cw.w, &cw.b, input.data, out_f.data, col.data, mm.data, dim.B, m, use_relu, + stream); + cudaDeviceSynchronize(); + std::vector hs, hf; + copy_precision_d2h(out_s.data, out_elems, &hs); + copy_precision_d2h(out_f.data, out_elems, &hf); + float mx, mn; + stats_diff(hs.data(), hf.data(), out_elems, &mx, &mn); + printf(" relu=%d max |diff| slow vs fast: %.6g 
mean |diff|: %.6g\n", (int)use_relu, mx, mn); + + float ms_slow = time_kernel_ms( + stream, + [&](cudaStream_t s) { + gemm_conv_forward(&cw.w, &cw.b, input.data, out_s.data, col.data, mm.data, dim.B, dim.IC, + dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, use_relu, s); + }, + warmup, iters); + float ms_fast = time_kernel_ms( + stream, + [&](cudaStream_t s) { + gemm_conv_forward_fast(&cw.w, &cw.b, input.data, out_f.data, col.data, mm.data, dim.B, m, + use_relu, s); + }, + warmup, iters); + printf(" relu=%d gemm_conv_forward: %8.4f us/iter\n", (int)use_relu, ms_slow * 1000.0f); + printf(" relu=%d gemm_conv_forward_fast: %8.4f us/iter (%.2fx vs slow)\n", (int)use_relu, + ms_fast * 1000.0f, ms_slow / ms_fast); + } + + alloc_free(¶m_alloc); + alloc_free(&act); + return 0; +} + +int main(int argc, char** argv) { + int B = 1024; + int layer = 1; + int warmup = 50; + int iters = 200; + bool do_fwd = true; + bool do_bwd = true; + bool cudnn_save = true; + bool do_wgrad_breakdown = false; + bool do_im2col_bench = false; + bool do_gemm_fast_bench = false; + for (int i = 1; i < argc; ++i) { + if (strcmp(argv[i], "-B") == 0 && i + 1 < argc) B = atoi(argv[++i]); + else if (strcmp(argv[i], "--layer") == 0 && i + 1 < argc) layer = atoi(argv[++i]); + else if (strcmp(argv[i], "--warmup") == 0 && i + 1 < argc) warmup = atoi(argv[++i]); + else if (strcmp(argv[i], "--iters") == 0 && i + 1 < argc) iters = atoi(argv[++i]); + else if (strcmp(argv[i], "--forward-only") == 0) do_bwd = false; + else if (strcmp(argv[i], "--backward-only") == 0) do_fwd = false; + else if (strcmp(argv[i], "--no-cudnn-save-input") == 0) cudnn_save = false; + else if (strcmp(argv[i], "--wgrad-breakdown-only") == 0 + || strcmp(argv[i], "--filter-bwd-only") == 0) { + do_wgrad_breakdown = true; + do_fwd = false; + do_bwd = false; + } else if (strcmp(argv[i], "--wgrad-breakdown") == 0 || strcmp(argv[i], "--filter-bwd") == 0) { + do_wgrad_breakdown = true; + } else if (strcmp(argv[i], "--im2col-bench-only") == 0) { + 
do_im2col_bench = true; + do_fwd = false; + do_bwd = false; + do_wgrad_breakdown = false; + } else if (strcmp(argv[i], "--im2col-bench") == 0) { + do_im2col_bench = true; + } else if (strcmp(argv[i], "--gemm-fast-bench-only") == 0) { + do_gemm_fast_bench = true; + do_fwd = false; + do_bwd = false; + do_wgrad_breakdown = false; + do_im2col_bench = false; + } else if (strcmp(argv[i], "--gemm-fast-bench") == 0) { + do_gemm_fast_bench = true; + } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + printf("Usage: %s [options]\n", argv[0]); + printf(" -B N batch size (default 1024)\n"); + printf(" --layer 1|2 NMMO3 conv1 or conv2 sizes (default 1)\n"); + printf(" --warmup N timing warmup runs (default 50)\n"); + printf(" --iters N timed iterations (default 200)\n"); + printf(" --forward-only only forward pass\n"); + printf(" --backward-only only backward (identity activation)\n"); + printf(" --no-cudnn-save-input omit cudnn memcpy to saved_input (forward timing)\n"); + printf(" --filter-bwd / --wgrad-breakdown also bench ∂W: gemm (nchw+im2col+mm_tn) vs cudnn BackwardFilter\n"); + printf(" --filter-bwd-only / --wgrad-breakdown-only only that ∂W bench\n"); + printf(" --im2col-bench also bench im2col_kernel vs im2col_kernel_fast\n"); + printf(" --im2col-bench-only only that im2col bench\n"); + printf(" --gemm-fast-bench also bench gemm_conv_forward vs gemm_conv_forward_fast (relu 0/1)\n"); + printf(" --gemm-fast-bench-only only that bench\n"); + printf(" (script) --float / --fp32 compile fp32 (default)\n"); + printf(" (script) --bf16 / --half compile bf16 (matches default native backend)\n"); + return 0; + } + } + + BenchDims dim{}; + if (layer == 1) dims_conv1(&dim, B); + else if (layer == 2) dims_conv2(&dim, B); + else { + fprintf(stderr, "layer must be 1 or 2\n"); + return 1; + } + + int OH = (dim.IH - dim.K) / dim.S + 1; + int OW = (dim.IW - dim.K) / dim.S + 1; + printf("bench_conv_gemm_vs_cudnn B=%d layer=%d IC=%d OC=%d %dx%d K=%d S=%d -> %dx%d 
relu=%d", + dim.B, layer, dim.IC, dim.OC, dim.IH, dim.IW, dim.K, dim.S, OH, OW, (int)dim.relu); +#ifdef PRECISION_FLOAT + printf(" precision=fp32\n"); +#else + printf(" precision=bf16\n"); +#endif + + if (do_fwd) { + printf("\n--- forward (gemm baseline vs cudnn) ---\n"); + if (run_forward(dim, warmup, iters, cudnn_save)) return 1; + } + if (do_bwd) { + BenchDims bd = dim; + bd.relu = false; + printf("\n--- backward (identity conv; gemm vs cudnn) ---\n"); + printf(" (relu ignored: identity conv so cudnn bwd matches gemm without ReLU mask)\n"); + if (run_backward(bd, warmup, iters)) return 1; + } + if (do_wgrad_breakdown) { + printf("\n--- filter backward (∂W): gemm path vs cudnnConvolutionBackwardFilter ---\n"); + if (run_filter_backward_bench(dim, warmup, iters)) return 1; + } + if (do_im2col_bench) { + printf("\n--- im2col_kernel vs im2col_kernel_fast ---\n"); + if (run_im2col_bench(dim, warmup, iters)) return 1; + } + if (do_gemm_fast_bench) { + printf("\n--- gemm_conv_forward vs gemm_conv_forward_fast (relu off/on) ---\n"); + if (run_gemm_fast_fwd_bench(dim, layer, warmup, iters)) return 1; + } + return 0; +} diff --git a/tests/bench_conv_gemm_vs_cudnn.sh b/tests/bench_conv_gemm_vs_cudnn.sh new file mode 100755 index 0000000000..4e35dffaa3 --- /dev/null +++ b/tests/bench_conv_gemm_vs_cudnn.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# Build and run gemm (im2col+cuBLAS) vs cuDNN conv benchmark. +# Default: fp32. Use --bf16 / --half for bf16 (matches native backend without --float). +# +# im2col vs im2col_kernel_fast (correctness + timing), same layer sizes as conv bench: +# ./tests/bench_conv_gemm_vs_cudnn.sh --layer 1 --im2col-bench-only +# ./tests/bench_conv_gemm_vs_cudnn.sh --bf16 --layer 2 --im2col-bench +# gemm slow vs fast forward (relu 0/1), NMMO3 layer sizes: +# ./tests/bench_conv_gemm_vs_cudnn.sh --layer 1 --gemm-fast-bench-only +set -euo pipefail +ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" +cd "$ROOT" + +CUDA_HOME="${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(command -v nvcc)")")}}" +NVCC="${NVCC:-$CUDA_HOME/bin/nvcc}" +ARCH="${NVCC_ARCH:-native}" + +# Default fp32; --bf16 / --half drop -DPRECISION_FLOAT +PRECISION_FLAG="-DPRECISION_FLOAT" +USER_ARGS=() +for arg in "$@"; do + case "$arg" in + --bf16|--half) PRECISION_FLAG="" ;; + --float|--fp32) PRECISION_FLAG="-DPRECISION_FLOAT" ;; + *) USER_ARGS+=("$arg") ;; + esac +done + +CUDNN_IFLAG="" +CUDNN_LFLAG="" +for dir in "$CUDA_HOME/include" /usr/local/cuda/include /usr/include; do + if [[ -f "$dir/cudnn.h" ]]; then + CUDNN_IFLAG="-I$dir" + break + fi +done +for dir in "$CUDA_HOME/lib64" "$CUDA_HOME/lib" /usr/lib/x86_64-linux-gnu; do + if [[ -f "$dir/libcudnn.so" ]] || [[ -f "$dir/libcudnn.dylib" ]]; then + CUDNN_LFLAG="-L$dir" + break + fi +done +if [[ -z "$CUDNN_IFLAG" ]]; then + CUDNN_IFLAG=$(python3 -c "import nvidia.cudnn, os; print('-I' + os.path.join(nvidia.cudnn.__path__[0], 'include'))" 2>/dev/null || true) +fi +if [[ -z "$CUDNN_LFLAG" ]]; then + CUDNN_LFLAG=$(python3 -c "import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2>/dev/null || true) +fi + +OUT="${ROOT}/tests/bench_conv_gemm_vs_cudnn" +if [[ -n "$PRECISION_FLAG" ]]; then + echo "nvcc $ARCH $OUT (fp32)" +else + echo "nvcc $ARCH $OUT (bf16)" +fi +"$NVCC" -O2 -std=c++17 "-arch=$ARCH" \ + "-I${ROOT}/src" \ + "-I${CUDA_HOME}/include" \ + $CUDNN_IFLAG \ + $PRECISION_FLAG \ + "${ROOT}/tests/bench_conv_gemm_vs_cudnn.cu" \ + -o "$OUT" \ + $CUDNN_LFLAG \ + -L"${CUDA_HOME}/lib64" -L"${CUDA_HOME}/lib" \ + -lcublas -lcudnn -lcurand + +echo "Running: $OUT ${USER_ARGS[*]}" +exec "$OUT" "${USER_ARGS[@]}" diff --git a/tests/tune_cublas_gemm.cu b/tests/tune_cublas_gemm.cu new file mode 100644 index 0000000000..35274e3469 --- /dev/null +++ b/tests/tune_cublas_gemm.cu @@ -0,0 +1,296 @@ +// Sweep cublasGemmEx algorithms for the same layout as puf_mm_tn (see kernels.cu). 
+// Default (M,N,K) matches NMMO3 conv1 ∂W GEMM at B=1024: M=OC, N=IC*K*K, K=B*OH*OW. +// +// Build: see tests/tune_cublas_gemm.sh (-DPRECISION_FLOAT => fp32; omit => bf16) + +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifndef CUBLAS_GEMM_ALGO0 +#define CUBLAS_GEMM_ALGO0 ((cublasGemmAlgo_t)0) +#endif + +#ifdef PRECISION_FLOAT +typedef float precision_t; +static constexpr cudaDataType_t kCudaPrec = CUDA_R_32F; +static constexpr cublasComputeType_t kCompute = CUBLAS_COMPUTE_32F; +#else +typedef __nv_bfloat16 precision_t; +static constexpr cudaDataType_t kCudaPrec = CUDA_R_16BF; +static constexpr cublasComputeType_t kCompute = CUBLAS_COMPUTE_32F; +#endif + +static void check_cuda(cudaError_t e) { + if (e != cudaSuccess) { + fprintf(stderr, "cuda: %s\n", cudaGetErrorString(e)); + exit(1); + } +} + +static const char* cublas_str(cublasStatus_t s) { + switch (s) { + case CUBLAS_STATUS_SUCCESS: return "SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "NOT_SUPPORTED"; + default: return "OTHER"; + } +} + +// Same lda/ldb rules as cublasGemmExDense in kernels.cu +static inline void gemm_ex_like_puf_mm_tn(cublasHandle_t h, int M, int N, int K, const precision_t* A, + const precision_t* B, precision_t* C, cublasGemmAlgo_t algo, cudaStream_t stream) { + const float alpha = 1.0f, beta = 0.0f; + cublasOperation_t op_a = CUBLAS_OP_T; + cublasOperation_t op_b = CUBLAS_OP_N; + int lda = (op_a == CUBLAS_OP_N) ? K : M; + int ldb = (op_b == CUBLAS_OP_N) ? 
N : K; + cublasSetStream(h, stream); + cublasStatus_t st = cublasGemmEx(h, op_b, op_a, N, M, K, &alpha, B, kCudaPrec, ldb, A, kCudaPrec, lda, + &beta, C, kCudaPrec, N, kCompute, algo); + if (st != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, "cublasGemmEx failed: %s\n", cublas_str(st)); + exit(1); + } +} + +template +static float time_ms(cudaStream_t stream, F fn, int warmup, int iters) { + for (int i = 0; i < warmup; ++i) { + fn(stream); + cudaStreamSynchronize(stream); + } + cudaEvent_t e0, e1; + cudaEventCreate(&e0); + cudaEventCreate(&e1); + cudaEventRecord(e0, stream); + for (int i = 0; i < iters; ++i) fn(stream); + cudaEventRecord(e1, stream); + cudaEventSynchronize(e1); + float ms = 0.f; + cudaEventElapsedTime(&ms, e0, e1); + cudaEventDestroy(e0); + cudaEventDestroy(e1); + return ms / (float)iters; +} + +static float max_abs_diff_fp32(const float* a, const float* b, int n) { + float m = 0.f; + for (int i = 0; i < n; ++i) m = fmaxf(m, fabsf(a[i] - b[i])); + return m; +} + +int main(int argc, char** argv) { + int Bbatch = 1024; + int M = 128, N = 1475, Kdim = 12288; + int warmup = 20, iters = 100; + for (int i = 1; i < argc; ++i) { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + printf( + "Usage: %s [options]\n" + " Tune cublasGemmEx for the same call pattern as puf_mm_tn:\n" + " C = B * A with op(B)=N, op(A)=T, sizes (N,M,K) -> cublasGemmEx(..., N,M,K,...)\n" + "Options:\n" + " --layer 1|2 NMMO3 conv sizes for (M,N,K) at given -B (default --layer 1)\n" + " -B N batch (default 1024), used with --layer\n" + " -M,-N,-K override matrix dims (after --layer, if set)\n" + " --warmup N (default 20)\n" + " --iters N (default 100)\n" + " (M,N,K) default without --layer: 128 1475 12288 (conv1 ∂W @ B=1024)\n", + argv[0]); + return 0; + } else if (strcmp(argv[i], "-B") == 0 && i + 1 < argc) Bbatch = atoi(argv[++i]); + else if (strcmp(argv[i], "--warmup") == 0 && i + 1 < argc) warmup = atoi(argv[++i]); + else if (strcmp(argv[i], "--iters") == 0 && i + 
1 < argc) iters = atoi(argv[++i]); + else if (strcmp(argv[i], "--layer") == 0 && i + 1 < argc) { + int L = atoi(argv[++i]); + if (L == 1) { + const int OH = 3, OW = 4, IC = 59, OC = 128, Kk = 5; + Kdim = Bbatch * OH * OW; + N = IC * Kk * Kk; + M = OC; + } else if (L == 2) { + const int OH = 1, OW = 2, IC = 128, OC = 128, Kk = 3; + Kdim = Bbatch * OH * OW; + N = IC * Kk * Kk; + M = OC; + } else { + fprintf(stderr, "layer must be 1 or 2\n"); + return 1; + } + } else if (strcmp(argv[i], "-M") == 0 && i + 1 < argc) M = atoi(argv[++i]); + else if (strcmp(argv[i], "-N") == 0 && i + 1 < argc) N = atoi(argv[++i]); + else if (strcmp(argv[i], "-K") == 0 && i + 1 < argc) Kdim = atoi(argv[++i]); + } + + int ldc = N; + int lenA = Kdim * M; + int lenB = Kdim * N; + int lenC = M * N; + + printf("tune_cublas_gemm M=%d N=%d K=%d (puf_mm_tn logical sizes)", M, N, Kdim); +#ifdef PRECISION_FLOAT + printf(" dtype=fp32\n"); +#else + printf(" dtype=bf16 compute=CUBLAS_COMPUTE_32F\n"); +#endif + + precision_t *dA, *dB, *dC, *dRef; + check_cuda(cudaMalloc(&dA, (size_t)lenA * sizeof(precision_t))); + check_cuda(cudaMalloc(&dB, (size_t)lenB * sizeof(precision_t))); + check_cuda(cudaMalloc(&dC, (size_t)lenC * sizeof(precision_t))); + check_cuda(cudaMalloc(&dRef, (size_t)lenC * sizeof(precision_t))); + + std::vector hAf(lenA), hBf(lenB); + unsigned seed = 12345; + for (int i = 0; i < lenA; ++i) { + seed = seed * 1103515245u + 12345u; + hAf[i] = (((seed >> 16) & 0x7fff) / 16384.0f - 1.0f) * 0.25f; + } + for (int i = 0; i < lenB; ++i) { + seed = seed * 1103515245u + 12345u; + hBf[i] = (((seed >> 16) & 0x7fff) / 16384.0f - 1.0f) * 0.25f; + } +#ifdef PRECISION_FLOAT + check_cuda(cudaMemcpy(dA, hAf.data(), (size_t)lenA * sizeof(float), cudaMemcpyHostToDevice)); + check_cuda(cudaMemcpy(dB, hBf.data(), (size_t)lenB * sizeof(float), cudaMemcpyHostToDevice)); +#else + std::vector hA(lenA), hB(lenB); + for (int i = 0; i < lenA; ++i) hA[i] = __float2bfloat16(hAf[i]); + for (int i = 0; i < lenB; ++i) 
hB[i] = __float2bfloat16(hBf[i]); + check_cuda(cudaMemcpy(dA, hA.data(), (size_t)lenA * sizeof(precision_t), cudaMemcpyHostToDevice)); + check_cuda(cudaMemcpy(dB, hB.data(), (size_t)lenB * sizeof(precision_t), cudaMemcpyHostToDevice)); +#endif + + cublasHandle_t handle; + cublasCreate(&handle); + + std::vector ref_host(lenC); + cudaStream_t stream = 0; + + auto run_ref = [&]() { + cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH); + gemm_ex_like_puf_mm_tn(handle, M, N, Kdim, dA, dB, dRef, CUBLAS_GEMM_DEFAULT, stream); + }; + run_ref(); + cudaDeviceSynchronize(); +#ifdef PRECISION_FLOAT + check_cuda(cudaMemcpy(ref_host.data(), dRef, (size_t)lenC * sizeof(float), cudaMemcpyDeviceToHost)); +#else + std::vector ref_bf(lenC); + check_cuda(cudaMemcpy(ref_bf.data(), dRef, (size_t)lenC * sizeof(precision_t), cudaMemcpyDeviceToHost)); + for (int i = 0; i < lenC; ++i) ref_host[i] = __bfloat162float(ref_bf[i]); +#endif + + struct Row { + cublasGemmAlgo_t algo; + cublasMath_t math; + float ms; + float max_diff; + bool ok; + }; + std::vector rows; + + const cublasMath_t math_modes[] = {CUBLAS_DEFAULT_MATH, CUBLAS_TENSOR_OP_MATH}; + const char* math_names[] = {"DEFAULT_MATH", "TENSOR_OP_MATH"}; + + std::vector algos; + algos.push_back(CUBLAS_GEMM_DEFAULT); +#ifdef CUBLAS_GEMM_DEFAULT_TENSOR_OP + algos.push_back(CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + for (int a = 0; a <= 23; ++a) + algos.push_back((cublasGemmAlgo_t)((int)CUBLAS_GEMM_ALGO0 + a)); + + for (size_t mi = 0; mi < sizeof(math_modes) / sizeof(math_modes[0]); ++mi) { + cublasSetMathMode(handle, math_modes[mi]); + for (cublasGemmAlgo_t algo : algos) { + cublasStatus_t st = cublasSetStream(handle, stream); + (void)st; + const float alpha = 1.f, beta = 0.f; + cublasOperation_t op_a = CUBLAS_OP_T; + cublasOperation_t op_b = CUBLAS_OP_N; + int lda = (op_a == CUBLAS_OP_N) ? Kdim : M; + int ldb = (op_b == CUBLAS_OP_N) ? 
N : Kdim; + st = cublasGemmEx(handle, op_b, op_a, N, M, Kdim, &alpha, dB, kCudaPrec, ldb, dA, kCudaPrec, lda, + &beta, dC, kCudaPrec, N, kCompute, algo); + + if (st != CUBLAS_STATUS_SUCCESS) { + rows.push_back({algo, math_modes[mi], 0.f, 0.f, false}); + continue; + } + cudaDeviceSynchronize(); + + float ms = time_ms( + stream, + [&](cudaStream_t s) { + cublasGemmEx(handle, op_b, op_a, N, M, Kdim, &alpha, dB, kCudaPrec, ldb, dA, kCudaPrec, lda, + &beta, dC, kCudaPrec, N, kCompute, algo); + }, + warmup, iters); + +#ifdef PRECISION_FLOAT + std::vector hC(lenC); + check_cuda(cudaMemcpy(hC.data(), dC, (size_t)lenC * sizeof(float), cudaMemcpyDeviceToHost)); + float md = max_abs_diff_fp32(ref_host.data(), hC.data(), lenC); +#else + std::vector hCb(lenC); + check_cuda(cudaMemcpy(hCb.data(), dC, (size_t)lenC * sizeof(precision_t), cudaMemcpyDeviceToHost)); + std::vector hC(lenC); + for (int i = 0; i < lenC; ++i) hC[i] = __bfloat162float(hCb[i]); + float md = max_abs_diff_fp32(ref_host.data(), hC.data(), lenC); +#endif + rows.push_back({algo, math_modes[mi], ms, md, true}); + } + } + + cublasDestroy(handle); + cudaFree(dA); + cudaFree(dB); + cudaFree(dC); + cudaFree(dRef); + + float best_ms = 1e30f; + int best_i = -1; + for (size_t i = 0; i < rows.size(); ++i) { + if (!rows[i].ok) continue; + if (rows[i].ms < best_ms) { + best_ms = rows[i].ms; + best_i = (int)i; + } + } + + printf("\n%-6s %-22s %-8s %10s %12s %s\n", "algo", "math", "ok", "us/iter", "max|diff|", "note"); + printf( + "------ ---------------------- -------- ---------- ------------ ----\n"); + for (size_t i = 0; i < rows.size(); ++i) { + const Row& r = rows[i]; + const char* mn = (r.math == CUBLAS_DEFAULT_MATH) ? math_names[0] : math_names[1]; + if (!r.ok) { + printf("%-6d %-22s %-8s\n", (int)r.algo, mn, "no"); + continue; + } + const char* tag = ((int)i == best_i) ? 
"best" : ""; + printf("%-6d %-22s %-8s %10.4f %12.5g %s\n", (int)r.algo, mn, "yes", r.ms * 1000.0f, r.max_diff, tag); + } + if (best_i >= 0) { + printf("\nFastest OK: algo=%d math=%s (%.4f us/iter) max|diff|=%g vs DEFAULT/DEFAULT_MATH " + "reference\n", + (int)rows[best_i].algo, + rows[best_i].math == CUBLAS_DEFAULT_MATH ? math_names[0] : math_names[1], best_ms * 1000.0f, + rows[best_i].max_diff); + } + return 0; +} diff --git a/tests/tune_cublas_gemm.sh b/tests/tune_cublas_gemm.sh new file mode 100755 index 0000000000..ff419ba353 --- /dev/null +++ b/tests/tune_cublas_gemm.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Build and run cublasGemmEx tuner (same layout as puf_mm_tn in kernels.cu). +# Default: fp32. Use --bf16 / --half for bf16. +set -euo pipefail +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT" + +CUDA_HOME="${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(command -v nvcc)")")}}" +NVCC="${NVCC:-$CUDA_HOME/bin/nvcc}" +ARCH="${NVCC_ARCH:-native}" + +PRECISION_FLAG="-DPRECISION_FLOAT" +USER_ARGS=() +for arg in "$@"; do + case "$arg" in + --bf16|--half) PRECISION_FLAG="" ;; + --float|--fp32) PRECISION_FLAG="-DPRECISION_FLOAT" ;; + *) USER_ARGS+=("$arg") ;; + esac +done + +OUT="${ROOT}/tests/tune_cublas_gemm" +if [[ -n "$PRECISION_FLAG" ]]; then + echo "nvcc $ARCH $OUT (fp32)" +else + echo "nvcc $ARCH $OUT (bf16)" +fi +"$NVCC" -O2 -std=c++17 "-arch=$ARCH" \ + $PRECISION_FLAG \ + "${ROOT}/tests/tune_cublas_gemm.cu" \ + -o "$OUT" \ + -L"${CUDA_HOME}/lib64" -L"${CUDA_HOME}/lib" \ + -lcublas + +echo "Running: $OUT ${USER_ARGS[*]}" +exec "$OUT" "${USER_ARGS[@]}" From a42c9bec2a6524b47a7fd5b7ed8d3877f46015bb Mon Sep 17 00:00:00 2001 From: jonah Date: Thu, 9 Apr 2026 07:40:50 -0700 Subject: [PATCH 2/9] fwd + bwd --- src/ocean.cu | 88 +++++++++ tests/bench_conv_gemm_vs_cudnn.cu | 310 ++++++++++++++++++++++++++++++ tests/bench_conv_gemm_vs_cudnn.sh | 4 + 3 files changed, 402 insertions(+) diff --git a/src/ocean.cu b/src/ocean.cu index abe0a19208..3d5fadf599 100644 
--- a/src/ocean.cu +++ b/src/ocean.cu @@ -271,6 +271,10 @@ struct Im2ColFastMods { FastDivMod dm_kk; FastDivMod dm_k; FastDivMod dm_oc; + FastDivMod dm_iw; + FastDivMod dm_ih; + FastDivMod dm_ic; + FastDivMod dm_s; int total_no_batch; int oh_ow; int oc_spatial; @@ -279,6 +283,7 @@ struct Im2ColFastMods { __host__ Im2ColFastMods(int ic, int ih, int iw, int oc, int k, int s, int oh, int ow) : dm_col_w(ic * k * k), dm_oh_ow(oh * ow), dm_ow(ow), dm_kk(k * k), dm_k(k), dm_oc(oc), + dm_iw(iw), dm_ih(ih), dm_ic(ic), dm_s(s), total_no_batch((oh * ow) * (ic * k * k)), oh_ow(oh * ow), col_cols(ic * k * k), oc_spatial(oc * oh * ow), IC(ic), IH(ih), IW(iw), OC(oc), K(k), S(s), OH(oh), OW(ow) {} @@ -317,6 +322,43 @@ __global__ void im2col_kernel_fast( // Backward: col2im — input-centric gather to avoid atomics. // Each thread owns one (b, ic, ih, iw) element and sums contributions from all // (oh, ow, kh, kw) patches that map to it. +// col2im fast path: dm_iw/dm_ih/dm_ic/dm_s and col_cols/oh_ow from Im2ColFastMods. 
+__global__ void col2im_kernel_fast( + const precision_t* __restrict__ col, precision_t* __restrict__ grad_input, + int B, int IC, int IH, int IW, int K, int OH, int OW, + const FastDivMod dm_iw, const FastDivMod dm_ih, const FastDivMod dm_ic, + const FastDivMod dm_s, int col_cols, int oh_ow +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = B * IC * IH * IW; + if (idx >= total) return; + int q0, iw, q1, ih, b, ic; + dm_iw.divmod(idx, q0, iw); + dm_ih.divmod(q0, q1, ih); + dm_ic.divmod(q1, b, ic); + int bohow_ickk = b * oh_ow * col_cols + ic * (K * K); + float sum = 0.0f; + for (int kh = 0; kh < K; kh++) { + int ih_off = ih - kh; + if (ih_off < 0) continue; + int oh, ih_rem; + dm_s.divmod(ih_off, oh, ih_rem); + if (ih_rem != 0 || oh >= OH) continue; + int ohowcc_khk = oh * OW * col_cols + kh * K; + int inner_value = bohow_ickk + ohowcc_khk; + for (int kw = 0; kw < K; kw++) { + int iw_off = iw - kw; + if (iw_off < 0) continue; + int ow, iw_rem; + dm_s.divmod(iw_off, ow, iw_rem); + if (iw_rem != 0 || ow >= OW) continue; + int col_idx = inner_value + ow * col_cols + kw; + sum += to_float(col[col_idx]); + } + } + grad_input[idx] = from_float(sum); +} + __global__ void col2im_kernel( const precision_t* __restrict__ col, precision_t* __restrict__ grad_input, int B, int IC, int IH, int IW, int K, int S, int OH, int OW @@ -347,6 +389,20 @@ __global__ void col2im_kernel( } // Transpose (B, OC, OH, OW) -> (B*OH*OW, OC) [NCHW to row-major spatial-first] +// Same idx layout as rows_to_nchw_kernel_fused; dm_oh_ow = spatial, dm_oc = OC. 
+__global__ void nchw_to_rows_kernel_fast( + const precision_t* __restrict__ src, precision_t* __restrict__ dst, + int B, int OC, int spatial, + const FastDivMod dm_oh_ow, const FastDivMod dm_oc +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = B * OC * spatial; + if (idx >= total) return; + int q, s, b, oc; + dm_oh_ow.divmod(idx, q, s); + dm_oc.divmod(q, b, oc); + dst[(b * spatial + s) * OC + oc] = src[idx]; +} __global__ void nchw_to_rows_kernel( const precision_t* __restrict__ src, precision_t* __restrict__ dst, int B, int OC, int spatial @@ -501,6 +557,38 @@ static void gemm_conv_backward( } } +// NMMO3 conv1/conv2 only: pass kIm2ColModsC1 or kIm2ColModsC2 (same as gemm_conv_forward_fast). +static void gemm_conv_backward_fast( + PrecisionTensor* weight, + precision_t* saved_input, precision_t* grad_output, + precision_t* wgrad, precision_t* input_grad, + precision_t* col_buf, precision_t* mm_buf, + int B, const Im2ColFastMods& m, cudaStream_t stream +) { + int col_rows = B * m.oh_ow; + int total_col = col_rows * m.col_cols; + int total_out = B * m.OC * m.oh_ow; + + nchw_to_rows_kernel_fast<<>>( + grad_output, mm_buf, B, m.OC, m.oh_ow, m.dm_oh_ow, m.dm_oc); + + im2col_kernel_fast<<>>( + saved_input, col_buf, B, m.IC, m.IH, m.IW, m.K, m.S, m.OH, m.OW, + m.dm_col_w, m.dm_oh_ow, m.dm_ow, m.dm_kk, m.dm_k, m.total_no_batch); + + PrecisionTensor mm_t = {.data = mm_buf, .shape = {col_rows, m.OC}}; + PrecisionTensor col_t = {.data = col_buf, .shape = {col_rows, m.col_cols}}; + PrecisionTensor wg_t = {.data = wgrad, .shape = {m.OC, m.col_cols}}; + puf_mm_tn(&mm_t, &col_t, &wg_t, stream); + + if (input_grad) { + puf_mm_nn(&mm_t, weight, &col_t, stream); + col2im_kernel_fast<<>>( + col_buf, input_grad, B, m.IC, m.IH, m.IW, m.K, m.OH, m.OW, + m.dm_iw, m.dm_ih, m.dm_ic, m.dm_s, m.col_cols, m.oh_ow); + } +} + // ---- NMMO3 encoder structs ---- struct NMMO3EncoderWeights { diff --git a/tests/bench_conv_gemm_vs_cudnn.cu b/tests/bench_conv_gemm_vs_cudnn.cu 
index babf763468..cd05977ee2 100755 --- a/tests/bench_conv_gemm_vs_cudnn.cu +++ b/tests/bench_conv_gemm_vs_cudnn.cu @@ -439,6 +439,139 @@ static int run_backward(const BenchDims& dim, int warmup, int iters) { return 0; } +// ∂W-only (input_grad=null) vs full backward (+ col2im / cudnn BackwardData); gemm vs cudnn. +static int run_backward_input_grad_bench(const BenchDims& dim, int warmup, int iters) { + ConvWeights cw{}; + conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, false); + + Allocator param_alloc{}; + conv_reg_params(&cw, ¶m_alloc); + if (alloc_create(¶m_alloc) != cudaSuccess) return 1; + uint64_t seed = 7; + conv_init_weights(&cw, &seed, 0); + cudaDeviceSynchronize(); + + int OH = cw.OH; + int OW = cw.OW; + int out_elems = dim.B * dim.OC * OH * OW; + int in_elems = dim.B * dim.IC * dim.IH * dim.IW; + int w_elems = (int)numel(cw.w.shape); + + Allocator act_g{}; + PrecisionTensor col{}, mm{}; + int col_rows = dim.B * OH * OW; + int col_cols = dim.IC * dim.K * dim.K; + col = {.shape = {col_rows, col_cols}}; + mm = {.shape = {col_rows, dim.OC}}; + PrecisionTensor saved_in{}, grad_out{}, wgrad_g{}; + saved_in = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; + grad_out = {.shape = {(int64_t)out_elems}}; + wgrad_g = {.shape = {cw.w.shape[0], cw.w.shape[1]}}; + PrecisionTensor dinput_g{}; + dinput_g = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; + alloc_register(&act_g, &col); + alloc_register(&act_g, &mm); + alloc_register(&act_g, &saved_in); + alloc_register(&act_g, &grad_out); + alloc_register(&act_g, &wgrad_g); + alloc_register(&act_g, &dinput_g); + if (alloc_create(&act_g) != cudaSuccess) return 1; + + std::vector hs(in_elems), hg(out_elems); + fill_rand_host(hs.data(), in_elems, 101u); + fill_rand_host(hg.data(), out_elems, 202u); + copy_fp32_h2d(hs.data(), saved_in.data, in_elems); + copy_fp32_h2d(hg.data(), grad_out.data, out_elems); + + Allocator acts{}, grads{}; + ConvActivations ca{}; + conv_reg_train(&cw, &ca, &acts, &grads, dim.B, 
n3_cudnn_dtype()); + if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1; + cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); + + cudaStream_t stream = 0; + + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, nullptr, col.data, mm.data, dim.B, + dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream); + cudaDeviceSynchronize(); + std::vector hwg_wnull_g; + copy_precision_d2h(wgrad_g.data, w_elems, &hwg_wnull_g); + + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, nullptr, dim.B, stream); + cudaDeviceSynchronize(); + std::vector hwg_wnull_c; + copy_precision_d2h(ca.wgrad.data, w_elems, &hwg_wnull_c); + + float d_wg_wnull, m_wg_wnull; + stats_diff(hwg_wnull_g.data(), hwg_wnull_c.data(), w_elems, &d_wg_wnull, &m_wg_wnull); + printf(" ∂W only wgrad max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_wg_wnull, m_wg_wnull); + + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, + dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream); + cudaDeviceSynchronize(); + std::vector hwg_full_g, hdi_full_g; + copy_precision_d2h(wgrad_g.data, w_elems, &hwg_full_g); + copy_precision_d2h(dinput_g.data, in_elems, &hdi_full_g); + + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, dinput_g.data, dim.B, stream); + cudaDeviceSynchronize(); + std::vector hwg_full_c, hdi_full_c; + copy_precision_d2h(ca.wgrad.data, 
w_elems, &hwg_full_c); + copy_precision_d2h(dinput_g.data, in_elems, &hdi_full_c); + + float d_wg_full, m_wg_full, d_di, m_di; + stats_diff(hwg_full_g.data(), hwg_full_c.data(), w_elems, &d_wg_full, &m_wg_full); + stats_diff(hdi_full_g.data(), hdi_full_c.data(), in_elems, &d_di, &m_di); + printf(" full wgrad max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_wg_full, m_wg_full); + printf(" full d_input max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_di, m_di); + + auto run_gemm_wnull = [&](cudaStream_t s) { + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, nullptr, col.data, mm.data, + dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, s); + }; + auto run_gemm_full = [&](cudaStream_t s) { + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, + mm.data, dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, s); + }; + auto run_cudnn_wnull = [&](cudaStream_t s) { + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, nullptr, dim.B, s); + }; + auto run_cudnn_full = [&](cudaStream_t s) { + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, dinput_g.data, dim.B, s); + }; + + float ms_g_w = time_kernel_ms(stream, run_gemm_wnull, warmup, iters); + float ms_g_f = time_kernel_ms(stream, run_gemm_full, warmup, iters); + float ms_c_w = time_kernel_ms(stream, run_cudnn_wnull, warmup, iters); + float ms_c_f = time_kernel_ms(stream, run_cudnn_full, warmup, iters); + + printf(" gemm ∂W only (d_input=null): %8.4f us/iter\n", ms_g_w * 1000.0f); + printf(" gemm full (+d_input): %8.4f us/iter (+%.4f us d_input slice)\n", + 
ms_g_f * 1000.0f, (ms_g_f - ms_g_w) * 1000.0f); + printf(" cudnn ∂W only (no BwdData): %8.4f us/iter\n", ms_c_w * 1000.0f); + printf(" cudnn full (+BwdData): %8.4f us/iter (+%.4f us BwdData slice)\n", + ms_c_f * 1000.0f, (ms_c_f - ms_c_w) * 1000.0f); + + alloc_free(¶m_alloc); + alloc_free(&act_g); + alloc_free(&acts); + alloc_free(&grads); + return 0; +} + static int run_im2col_bench(const BenchDims& dim, int warmup, int iters) { int B = dim.B, IC = dim.IC, IH = dim.IH, IW = dim.IW, K = dim.K, S = dim.S; int OH = (IH - K) / S + 1; @@ -506,6 +639,138 @@ static int run_im2col_bench(const BenchDims& dim, int warmup, int iters) { return 0; } +// Full backward: gemm_conv_backward vs gemm_conv_backward_fast vs cudnn (NMMO3 layer geometry only). +static int run_backward_gemm_fast_bench(const BenchDims& dim, int layer, int warmup, int iters) { + const Im2ColFastMods& m = (layer == 1) ? kIm2ColModsC1 : kIm2ColModsC2; + int OH = (dim.IH - dim.K) / dim.S + 1; + int OW = (dim.IW - dim.K) / dim.S + 1; + if (OH != m.OH || OW != m.OW || dim.IC != m.IC || dim.IH != m.IH || dim.IW != m.IW || dim.OC != m.OC + || dim.K != m.K || dim.S != m.S) { + fprintf(stderr, "backward gemm-fast bench: dimensions must match NMMO3 layer %d (use --layer %d)\n", + layer, layer); + return 1; + } + + ConvWeights cw{}; + conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, false); + + Allocator param_alloc{}; + conv_reg_params(&cw, ¶m_alloc); + if (alloc_create(¶m_alloc) != cudaSuccess) return 1; + uint64_t seed = 23; + conv_init_weights(&cw, &seed, 0); + cudaDeviceSynchronize(); + + int out_elems = dim.B * dim.OC * OH * OW; + int in_elems = dim.B * dim.IC * dim.IH * dim.IW; + int w_elems = (int)numel(cw.w.shape); + + Allocator act_g{}; + PrecisionTensor col{}, mm{}; + int col_rows = dim.B * OH * OW; + int col_cols = dim.IC * dim.K * dim.K; + col = {.shape = {col_rows, col_cols}}; + mm = {.shape = {col_rows, dim.OC}}; + PrecisionTensor saved_in{}, grad_out{}, wgrad_g{}; + saved_in = {.shape = 
{dim.B, dim.IC, dim.IH, dim.IW}}; + grad_out = {.shape = {(int64_t)out_elems}}; + wgrad_g = {.shape = {cw.w.shape[0], cw.w.shape[1]}}; + PrecisionTensor dinput_g{}; + dinput_g = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; + alloc_register(&act_g, &col); + alloc_register(&act_g, &mm); + alloc_register(&act_g, &saved_in); + alloc_register(&act_g, &grad_out); + alloc_register(&act_g, &wgrad_g); + alloc_register(&act_g, &dinput_g); + if (alloc_create(&act_g) != cudaSuccess) return 1; + + std::vector hs(in_elems), hg(out_elems); + fill_rand_host(hs.data(), in_elems, 101u); + fill_rand_host(hg.data(), out_elems, 202u); + copy_fp32_h2d(hs.data(), saved_in.data, in_elems); + copy_fp32_h2d(hg.data(), grad_out.data, out_elems); + + Allocator acts{}, grads{}; + ConvActivations ca{}; + conv_reg_train(&cw, &ca, &acts, &grads, dim.B, n3_cudnn_dtype()); + if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1; + cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); + + cudaStream_t stream = 0; + + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, + dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream); + cudaDeviceSynchronize(); + std::vector hwg_slow, hdi_slow; + copy_precision_d2h(wgrad_g.data, w_elems, &hwg_slow); + copy_precision_d2h(dinput_g.data, in_elems, &hdi_slow); + + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward_fast(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, + dim.B, m, stream); + cudaDeviceSynchronize(); + 
std::vector hwg_fast, hdi_fast; + copy_precision_d2h(wgrad_g.data, w_elems, &hwg_fast); + copy_precision_d2h(dinput_g.data, in_elems, &hdi_fast); + + float d_w_sg, m_w_sg, d_i_sg, m_i_sg; + stats_diff(hwg_slow.data(), hwg_fast.data(), w_elems, &d_w_sg, &m_w_sg); + stats_diff(hdi_slow.data(), hdi_fast.data(), in_elems, &d_i_sg, &m_i_sg); + printf(" wgrad max |diff| (gemm vs gemm_fast): %.6g mean |diff|: %.6g\n", d_w_sg, m_w_sg); + printf(" d_input max |diff| (gemm vs gemm_fast): %.6g mean |diff|: %.6g\n", d_i_sg, m_i_sg); + + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, dinput_g.data, dim.B, stream); + cudaDeviceSynchronize(); + std::vector hwg_c, hdi_c; + copy_precision_d2h(ca.wgrad.data, w_elems, &hwg_c); + copy_precision_d2h(dinput_g.data, in_elems, &hdi_c); + float d_w_sc, m_w_sc, d_i_sc, m_i_sc; + stats_diff(hwg_slow.data(), hwg_c.data(), w_elems, &d_w_sc, &m_w_sc); + stats_diff(hdi_slow.data(), hdi_c.data(), in_elems, &d_i_sc, &m_i_sc); + printf(" wgrad max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_w_sc, m_w_sc); + printf(" d_input max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_i_sc, m_i_sc); + + auto run_gemm = [&](cudaStream_t s) { + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, + dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, s); + }; + auto run_gemm_fast = [&](cudaStream_t s) { + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward_fast(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, + dim.B, m, s); + }; + auto run_cudnn = [&](cudaStream_t s) { + 
cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, dinput_g.data, dim.B, s); + }; + + float ms_g = time_kernel_ms(stream, run_gemm, warmup, iters); + float ms_gf = time_kernel_ms(stream, run_gemm_fast, warmup, iters); + float ms_c = time_kernel_ms(stream, run_cudnn, warmup, iters); + printf(" gemm_conv_backward: %8.4f us/iter\n", ms_g * 1000.0f); + printf(" gemm_conv_backward_fast: %8.4f us/iter (%.2fx vs gemm_conv_backward)\n", ms_gf * 1000.0f, + ms_g / ms_gf); + printf(" conv_backward (cudnn): %8.4f us/iter (%.2fx vs gemm, %.2fx vs gemm_fast)\n", ms_c * 1000.0f, + ms_g / ms_c, ms_gf / ms_c); + + alloc_free(¶m_alloc); + alloc_free(&act_g); + alloc_free(&acts); + alloc_free(&grads); + return 0; +} + // gemm_conv_forward vs gemm_conv_forward_fast; relu on/off (NMMO3 layer geometry only). static int run_gemm_fast_fwd_bench(const BenchDims& dim, int layer, int warmup, int iters) { const Im2ColFastMods& m = (layer == 1) ? 
kIm2ColModsC1 : kIm2ColModsC2; @@ -597,6 +862,8 @@ int main(int argc, char** argv) { bool do_wgrad_breakdown = false; bool do_im2col_bench = false; bool do_gemm_fast_bench = false; + bool do_bwd_dinput_bench = false; + bool do_gemm_bwd_fast_bench = false; for (int i = 1; i < argc; ++i) { if (strcmp(argv[i], "-B") == 0 && i + 1 < argc) B = atoi(argv[++i]); else if (strcmp(argv[i], "--layer") == 0 && i + 1 < argc) layer = atoi(argv[++i]); @@ -610,6 +877,8 @@ int main(int argc, char** argv) { do_wgrad_breakdown = true; do_fwd = false; do_bwd = false; + do_bwd_dinput_bench = false; + do_gemm_bwd_fast_bench = false; } else if (strcmp(argv[i], "--wgrad-breakdown") == 0 || strcmp(argv[i], "--filter-bwd") == 0) { do_wgrad_breakdown = true; } else if (strcmp(argv[i], "--im2col-bench-only") == 0) { @@ -617,6 +886,8 @@ int main(int argc, char** argv) { do_fwd = false; do_bwd = false; do_wgrad_breakdown = false; + do_bwd_dinput_bench = false; + do_gemm_bwd_fast_bench = false; } else if (strcmp(argv[i], "--im2col-bench") == 0) { do_im2col_bench = true; } else if (strcmp(argv[i], "--gemm-fast-bench-only") == 0) { @@ -625,8 +896,30 @@ int main(int argc, char** argv) { do_bwd = false; do_wgrad_breakdown = false; do_im2col_bench = false; + do_bwd_dinput_bench = false; + do_gemm_bwd_fast_bench = false; } else if (strcmp(argv[i], "--gemm-fast-bench") == 0) { do_gemm_fast_bench = true; + } else if (strcmp(argv[i], "--gemm-bwd-fast-bench-only") == 0) { + do_gemm_bwd_fast_bench = true; + do_fwd = false; + do_bwd = false; + do_wgrad_breakdown = false; + do_im2col_bench = false; + do_gemm_fast_bench = false; + do_bwd_dinput_bench = false; + } else if (strcmp(argv[i], "--gemm-bwd-fast-bench") == 0) { + do_gemm_bwd_fast_bench = true; + } else if (strcmp(argv[i], "--bwd-dinput-bench-only") == 0) { + do_bwd_dinput_bench = true; + do_fwd = false; + do_bwd = false; + do_wgrad_breakdown = false; + do_im2col_bench = false; + do_gemm_fast_bench = false; + do_gemm_bwd_fast_bench = false; + } else 
if (strcmp(argv[i], "--bwd-dinput-bench") == 0) { + do_bwd_dinput_bench = true; } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { printf("Usage: %s [options]\n", argv[0]); printf(" -B N batch size (default 1024)\n"); @@ -642,6 +935,10 @@ int main(int argc, char** argv) { printf(" --im2col-bench-only only that im2col bench\n"); printf(" --gemm-fast-bench also bench gemm_conv_forward vs gemm_conv_forward_fast (relu 0/1)\n"); printf(" --gemm-fast-bench-only only that bench\n"); + printf(" --bwd-dinput-bench also bench ∂W-only vs full bwd (gemm vs cudnn)\n"); + printf(" --bwd-dinput-bench-only only that bench\n"); + printf(" --gemm-bwd-fast-bench also bench full bwd: gemm vs gemm_fast vs cudnn (NMMO3 layer)\n"); + printf(" --gemm-bwd-fast-bench-only only that bench\n"); printf(" (script) --float / --fp32 compile fp32 (default)\n"); printf(" (script) --bf16 / --half compile bf16 (matches default native backend)\n"); return 0; @@ -689,5 +986,18 @@ int main(int argc, char** argv) { printf("\n--- gemm_conv_forward vs gemm_conv_forward_fast (relu off/on) ---\n"); if (run_gemm_fast_fwd_bench(dim, layer, warmup, iters)) return 1; } + if (do_bwd_dinput_bench) { + BenchDims bd = dim; + bd.relu = false; + printf("\n--- backward: ∂W-only (d_input=null) vs full (+d_input); gemm vs cudnn ---\n"); + if (run_backward_input_grad_bench(bd, warmup, iters)) return 1; + } + if (do_gemm_bwd_fast_bench) { + BenchDims bd = dim; + bd.relu = false; + printf("\n--- backward: gemm_conv_backward vs gemm_conv_backward_fast vs conv_backward ---\n"); + printf(" (NMMO3 geometry; use --layer 1 or 2)\n"); + if (run_backward_gemm_fast_bench(bd, layer, warmup, iters)) return 1; + } return 0; } diff --git a/tests/bench_conv_gemm_vs_cudnn.sh b/tests/bench_conv_gemm_vs_cudnn.sh index 4e35dffaa3..91293f3ffe 100755 --- a/tests/bench_conv_gemm_vs_cudnn.sh +++ b/tests/bench_conv_gemm_vs_cudnn.sh @@ -7,6 +7,10 @@ # ./tests/bench_conv_gemm_vs_cudnn.sh --bf16 --layer 2 --im2col-bench # 
gemm slow vs fast forward (relu 0/1), NMMO3 layer sizes: # ./tests/bench_conv_gemm_vs_cudnn.sh --layer 1 --gemm-fast-bench-only +# ∂W-only vs full backward (gemm vs cudnn): +# ./tests/bench_conv_gemm_vs_cudnn.sh --layer 1 --bwd-dinput-bench-only +# gemm vs gemm_fast vs cudnn full backward: +# ./tests/bench_conv_gemm_vs_cudnn.sh --layer 1 --gemm-bwd-fast-bench-only set -euo pipefail ROOT="$(cd "$(dirname "$0")/.." && pwd)" cd "$ROOT" From 8ebefd57f486c59ea1e6d964ffb010a0b936fb80 Mon Sep 17 00:00:00 2001 From: jonah Date: Thu, 9 Apr 2026 08:04:41 -0700 Subject: [PATCH 3/9] better test --- tests/bench_gemm_conv_end2end.cu | 333 +++++++++++++++++++++++++++++++ tests/bench_gemm_conv_end2end.sh | 65 ++++++ 2 files changed, 398 insertions(+) create mode 100644 tests/bench_gemm_conv_end2end.cu create mode 100755 tests/bench_gemm_conv_end2end.sh diff --git a/tests/bench_gemm_conv_end2end.cu b/tests/bench_gemm_conv_end2end.cu new file mode 100644 index 0000000000..d017c2f660 --- /dev/null +++ b/tests/bench_gemm_conv_end2end.cu @@ -0,0 +1,333 @@ +// End-to-end: gemm conv (slow) vs gemm_fast vs cudnn — forward & backward timed separately, layers 1 & 2 (NMMO3). +// Build/run: tests/bench_gemm_conv_end2end.sh [--float|--bf16] + +#include + +#include "models.cu" +#include "ocean.cu" + +#ifndef PRECISION_FLOAT +#include +#endif +#include +#include +#include +#include +#include +#include + +static void fill_rand_host(float* h, int n, unsigned s) { + for (int i = 0; i < n; ++i) { + s = s * 1103515245u + 12345u; + h[i] = ((s >> 16) & 0x7fff) / 16384.0f - 1.0f; + } +} + +static void stats_diff(const float* a, const float* b, int n, float* max_abs, float* mean_abs) { + float mx = 0.0f, sum = 0.0f; + for (int i = 0; i < n; ++i) { + float d = fabsf(a[i] - b[i]); + if (d > mx) mx = d; + sum += d; + } + *max_abs = mx; + *mean_abs = sum / (float)std::max(1, n); +} + +// max_i |a-b| / max(|a|, eps) — scale-free compare vs reference `a`. 
+static float stats_rel_max(const float* a, const float* b, int n, float eps) { + float mx = 0.0f; + for (int i = 0; i < n; ++i) { + float den = fmaxf(fabsf(a[i]), eps); + float r = fabsf(a[i] - b[i]) / den; + if (r > mx) mx = r; + } + return mx; +} + +static void copy_precision_d2h(const precision_t* d, int n, std::vector* hf) { + hf->resize((size_t)n); +#ifdef PRECISION_FLOAT + cudaMemcpy(hf->data(), d, (size_t)n * sizeof(float), cudaMemcpyDeviceToHost); +#else + std::vector h((size_t)n); + cudaMemcpy(h.data(), d, (size_t)n * sizeof(precision_t), cudaMemcpyDeviceToHost); + for (int i = 0; i < n; ++i) (*hf)[i] = __bfloat162float(h[i]); +#endif +} + +static void copy_fp32_h2d(const float* h, precision_t* d, int n) { +#ifdef PRECISION_FLOAT + cudaMemcpy(d, h, (size_t)n * sizeof(float), cudaMemcpyHostToDevice); +#else + std::vector hb((size_t)n); + for (int i = 0; i < n; ++i) hb[i] = __float2bfloat16(h[i]); + cudaMemcpy(d, hb.data(), (size_t)n * sizeof(precision_t), cudaMemcpyHostToDevice); +#endif +} + +template +static float time_kernel_ms(cudaStream_t stream, F fn, int warmup, int iters) { + for (int i = 0; i < warmup; ++i) { + fn(stream); + cudaStreamSynchronize(stream); + } + cudaEvent_t ev0, ev1; + cudaEventCreate(&ev0); + cudaEventCreate(&ev1); + cudaEventRecord(ev0, stream); + for (int i = 0; i < iters; ++i) fn(stream); + cudaEventRecord(ev1, stream); + cudaEventSynchronize(ev1); + float ms = 0.0f; + cudaEventElapsedTime(&ms, ev0, ev1); + cudaEventDestroy(ev0); + cudaEventDestroy(ev1); + return ms / (float)iters; +} + +struct BenchDims { + int B; + int IC, OC, K, S, IH, IW; + bool relu; +}; + +static void dims_conv1(BenchDims* d, int B) { + d->B = B; + d->IC = N3_C1_IC; + d->OC = N3_C1_OC; + d->K = N3_C1_K; + d->S = N3_C1_S; + d->IH = N3_MAP_H; + d->IW = N3_MAP_W; + d->relu = true; +} + +static void dims_conv2(BenchDims* d, int B) { + d->B = B; + d->IC = N3_C2_IC; + d->OC = N3_C2_OC; + d->K = N3_C2_K; + d->S = N3_C2_S; + d->IH = N3_C1_OH; + d->IW = N3_C1_OW; 
+ d->relu = false;
+}
+
+static int run_layer(int layer, int warmup, int iters) {
+ BenchDims dim{};
+ if (layer == 1) dims_conv1(&dim, 1024);
+ else if (layer == 2) dims_conv2(&dim, 1024);
+ else return 1;
+
+ const Im2ColFastMods& m = (layer == 1) ? kIm2ColModsC1 : kIm2ColModsC2;
+ int OH = (dim.IH - dim.K) / dim.S + 1;
+ int OW = (dim.IW - dim.K) / dim.S + 1;
+ if (OH != m.OH || OW != m.OW || dim.IC != m.IC || dim.IH != m.IH || dim.IW != m.IW || dim.OC != m.OC
+ || dim.K != m.K || dim.S != m.S) {
+ fprintf(stderr, "layer %d: dim mismatch vs Im2ColFastMods\n", layer);
+ return 1;
+ }
+
+ ConvWeights cw{};
+ conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, dim.relu);
+
+ Allocator param_alloc{};
+ conv_reg_params(&cw, &param_alloc);
+ if (alloc_create(&param_alloc) != cudaSuccess) return 1;
+ uint64_t seed = 1000u + (unsigned)layer;
+ conv_init_weights(&cw, &seed, 0);
+ cudaDeviceSynchronize();
+
+ int B = dim.B;
+ int out_elems = B * dim.OC * OH * OW;
+ int in_elems = B * dim.IC * dim.IH * dim.IW;
+ int w_elems = (int)numel(cw.w.shape);
+ int col_rows = B * OH * OW;
+ int col_cols = dim.IC * dim.K * dim.K;
+
+ Allocator act_g{};
+ PrecisionTensor out_gemm{}, out_fast{}, col{}, mm{};
+ PrecisionTensor saved_in{}, grad_out{}, wgrad_g{}, dinput_g{};
+ out_gemm = {.shape = {(int64_t)out_elems}};
+ out_fast = {.shape = {(int64_t)out_elems}};
+ col = {.shape = {col_rows, col_cols}};
+ mm = {.shape = {col_rows, dim.OC}};
+ saved_in = {.shape = {B, dim.IC, dim.IH, dim.IW}};
+ grad_out = {.shape = {(int64_t)out_elems}};
+ wgrad_g = {.shape = {cw.w.shape[0], cw.w.shape[1]}};
+ dinput_g = {.shape = {B, dim.IC, dim.IH, dim.IW}};
+ alloc_register(&act_g, &out_gemm);
+ alloc_register(&act_g, &out_fast);
+ alloc_register(&act_g, &col);
+ alloc_register(&act_g, &mm);
+ alloc_register(&act_g, &saved_in);
+ alloc_register(&act_g, &grad_out);
+ alloc_register(&act_g, &wgrad_g);
+ alloc_register(&act_g, &dinput_g);
+ if (alloc_create(&act_g) != cudaSuccess) return 1;
+
+ 
Allocator acts{}, grads{}; + ConvActivations ca{}; + conv_reg_train(&cw, &ca, &acts, &grads, B, n3_cudnn_dtype()); + if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1; + + std::vector hin(in_elems), hg_up(out_elems); + fill_rand_host(hin.data(), in_elems, 201u + (unsigned)layer); + fill_rand_host(hg_up.data(), out_elems, 303u + (unsigned)layer); + copy_fp32_h2d(hin.data(), saved_in.data, in_elems); + cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); + + cudaStream_t stream = 0; + const float rel_eps = 1e-5f; + + printf("layer %d B=%d IC=%d OC=%d %dx%d K=%d S=%d -> %dx%d relu=%d\n", layer, B, dim.IC, dim.OC, dim.IH, + dim.IW, dim.K, dim.S, OH, OW, (int)dim.relu); + printf(" --- correctness (reference = gemm_conv forward/backward) ---\n"); + + gemm_conv_forward(&cw.w, &cw.b, saved_in.data, out_gemm.data, col.data, mm.data, B, dim.IC, dim.IH, dim.IW, + dim.OC, dim.K, dim.S, OH, OW, dim.relu, stream); + cudaDeviceSynchronize(); + gemm_conv_forward_fast(&cw.w, &cw.b, saved_in.data, out_fast.data, col.data, mm.data, B, m, dim.relu, stream); + cudaDeviceSynchronize(); + conv_forward(&cw, &ca, saved_in.data, B, stream); + cudaDeviceSynchronize(); + + std::vector h_gemm_o, h_fast_o, h_cdnn_o; + copy_precision_d2h(out_gemm.data, out_elems, &h_gemm_o); + copy_precision_d2h(out_fast.data, out_elems, &h_fast_o); + copy_precision_d2h(ca.out.data, out_elems, &h_cdnn_o); + float mx, mn, rel; + stats_diff(h_gemm_o.data(), h_fast_o.data(), out_elems, &mx, &mn); + rel = stats_rel_max(h_gemm_o.data(), h_fast_o.data(), out_elems, rel_eps); + printf(" forward gemm_fast vs gemm: max|diff| %.6g mean|diff| %.6g max rel err %.6g\n", mx, mn, rel); + stats_diff(h_gemm_o.data(), h_cdnn_o.data(), out_elems, &mx, &mn); + rel = stats_rel_max(h_gemm_o.data(), h_cdnn_o.data(), out_elems, rel_eps); + printf(" forward cudnn vs gemm: max|diff| %.6g mean|diff| %.6g max rel err %.6g\n", mx, mn, rel); + 
+ copy_fp32_h2d(hg_up.data(), grad_out.data, out_elems); + cudaMemcpy(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); + + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, B, + dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream); + cudaDeviceSynchronize(); + std::vector hwg_ref, hdi_ref; + copy_precision_d2h(wgrad_g.data, w_elems, &hwg_ref); + copy_precision_d2h(dinput_g.data, in_elems, &hdi_ref); + + cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + gemm_conv_backward_fast(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, B, + m, stream); + cudaDeviceSynchronize(); + std::vector hwg_f, hdi_f; + copy_precision_d2h(wgrad_g.data, w_elems, &hwg_f); + copy_precision_d2h(dinput_g.data, in_elems, &hdi_f); + stats_diff(hwg_ref.data(), hwg_f.data(), w_elems, &mx, &mn); + rel = stats_rel_max(hwg_ref.data(), hwg_f.data(), w_elems, rel_eps); + printf(" backward wgrad gemm_fast vs gemm: max|diff| %.6g mean|diff| %.6g max rel err %.6g\n", mx, mn, rel); + stats_diff(hdi_ref.data(), hdi_f.data(), in_elems, &mx, &mn); + rel = stats_rel_max(hdi_ref.data(), hdi_f.data(), in_elems, rel_eps); + printf(" backward d_input gemm_fast vs gemm: max|diff| %.6g mean|diff| %.6g max rel err %.6g\n", mx, mn, rel); + + cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); + cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); + conv_backward(&cw, &ca, dinput_g.data, B, stream); + cudaDeviceSynchronize(); + std::vector hwg_c, hdi_c; + copy_precision_d2h(ca.wgrad.data, w_elems, &hwg_c); + copy_precision_d2h(dinput_g.data, in_elems, &hdi_c); + stats_diff(hwg_ref.data(), hwg_c.data(), w_elems, 
&mx, &mn); + rel = stats_rel_max(hwg_ref.data(), hwg_c.data(), w_elems, rel_eps); + printf(" backward wgrad cudnn vs gemm: max|diff| %.6g mean|diff| %.6g max rel err %.6g\n", mx, mn, rel); + stats_diff(hdi_ref.data(), hdi_c.data(), in_elems, &mx, &mn); + rel = stats_rel_max(hdi_ref.data(), hdi_c.data(), in_elems, rel_eps); + printf(" backward d_input cudnn vs gemm: max|diff| %.6g mean|diff| %.6g max rel err %.6g\n", mx, mn, rel); + + printf(" --- timing (%d warmup / %d iters): forward and backward measured separately ---\n", warmup, iters); + + auto run_gemm_fwd = [&](cudaStream_t s) { + gemm_conv_forward(&cw.w, &cw.b, saved_in.data, out_gemm.data, col.data, mm.data, B, dim.IC, dim.IH, dim.IW, + dim.OC, dim.K, dim.S, OH, OW, dim.relu, s); + }; + auto run_gemm_fast_fwd = [&](cudaStream_t s) { + gemm_conv_forward_fast(&cw.w, &cw.b, saved_in.data, out_fast.data, col.data, mm.data, B, m, dim.relu, s); + }; + auto run_cudnn_fwd = [&](cudaStream_t s) { conv_forward(&cw, &ca, saved_in.data, B, s); }; + + auto run_gemm_bwd = [&](cudaStream_t s) { + cudaMemsetAsync(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t), s); + cudaMemsetAsync(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t), s); + gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, B, + dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, s); + }; + auto run_gemm_fast_bwd = [&](cudaStream_t s) { + cudaMemsetAsync(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t), s); + cudaMemsetAsync(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t), s); + gemm_conv_backward_fast(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, + B, m, s); + }; + auto run_cudnn_bwd = [&](cudaStream_t s) { + cudaMemcpyAsync(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), + cudaMemcpyDeviceToDevice, s); + cudaMemcpyAsync(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), 
cudaMemcpyDeviceToDevice,
+ s);
+ cudaMemsetAsync(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t), s);
+ cudaMemsetAsync(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t), s);
+ conv_backward(&cw, &ca, dinput_g.data, B, s);
+ };
+
+ float ms_gf = time_kernel_ms(stream, run_gemm_fwd, warmup, iters);
+ float ms_ff = time_kernel_ms(stream, run_gemm_fast_fwd, warmup, iters);
+ float ms_cf = time_kernel_ms(stream, run_cudnn_fwd, warmup, iters);
+ float ms_gb = time_kernel_ms(stream, run_gemm_bwd, warmup, iters);
+ float ms_fb = time_kernel_ms(stream, run_gemm_fast_bwd, warmup, iters);
+ float ms_cb = time_kernel_ms(stream, run_cudnn_bwd, warmup, iters);
+
+ printf(" forward:\n");
+ printf(" gemm (slow): %8.4f us/iter\n", ms_gf * 1000.0f);
+ printf(" gemm_fast: %8.4f us/iter (%.2fx vs gemm)\n", ms_ff * 1000.0f, ms_gf / ms_ff);
+ printf(" cudnn: %8.4f us/iter (%.2fx vs gemm, %.2fx vs gemm_fast)\n", ms_cf * 1000.0f, ms_gf / ms_cf,
+ ms_ff / ms_cf);
+ printf(" backward:\n");
+ printf(" gemm (slow): %8.4f us/iter\n", ms_gb * 1000.0f);
+ printf(" gemm_fast: %8.4f us/iter (%.2fx vs gemm)\n", ms_fb * 1000.0f, ms_gb / ms_fb);
+ printf(" cudnn: %8.4f us/iter (%.2fx vs gemm, %.2fx vs gemm_fast)\n", ms_cb * 1000.0f, ms_gb / ms_cb,
+ ms_fb / ms_cb);
+ printf("\n");
+
+ alloc_free(&param_alloc);
+ alloc_free(&act_g);
+ alloc_free(&acts);
+ alloc_free(&grads);
+ return 0;
+}
+
+int main(int argc, char** argv) {
+ const int warmup = 50;
+ const int iters = 200;
+ for (int i = 1; i < argc; ++i) {
+ if (strcmp(argv[i], "--float") == 0 || strcmp(argv[i], "--fp32") == 0 || strcmp(argv[i], "--bf16") == 0
+ || strcmp(argv[i], "--half") == 0) {
+ continue;
+ }
+ if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
+ printf("Usage: %s (precision: compile with tests/bench_gemm_conv_end2end.sh --float|--bf16)\n", argv[0]);
+ return 0;
+ }
+ fprintf(stderr, "Unknown arg: %s\n", argv[i]);
+ return 1;
+ }
+
+#ifdef PRECISION_FLOAT
+ printf("bench_gemm_conv_end2end
precision=fp32 warmup=%d iters=%d\n\n", warmup, iters); +#else + printf("bench_gemm_conv_end2end precision=bf16 warmup=%d iters=%d\n\n", warmup, iters); +#endif + + if (run_layer(1, warmup, iters)) return 1; + if (run_layer(2, warmup, iters)) return 1; + return 0; +} diff --git a/tests/bench_gemm_conv_end2end.sh b/tests/bench_gemm_conv_end2end.sh new file mode 100755 index 0000000000..4c3ba31442 --- /dev/null +++ b/tests/bench_gemm_conv_end2end.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +# Build and run end-to-end conv benchmark: gemm vs gemm_fast vs cudnn (fwd+bwd), layers 1 & 2. +# Args: --float | --fp32 (default) or --bf16 | --half +# +# ./tests/bench_gemm_conv_end2end.sh +# ./tests/bench_gemm_conv_end2end.sh --bf16 +set -euo pipefail +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT" + +CUDA_HOME="${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(command -v nvcc)")")}}" +NVCC="${NVCC:-$CUDA_HOME/bin/nvcc}" +ARCH="${NVCC_ARCH:-native}" + +PRECISION_FLAG="-DPRECISION_FLOAT" +for arg in "$@"; do + case "$arg" in + --bf16|--half) PRECISION_FLAG="" ;; + --float|--fp32) PRECISION_FLAG="-DPRECISION_FLOAT" ;; + *) + echo "Unknown argument: $arg (use --float or --bf16)" >&2 + exit 1 + ;; + esac +done + +CUDNN_IFLAG="" +CUDNN_LFLAG="" +for dir in "$CUDA_HOME/include" /usr/local/cuda/include /usr/include; do + if [[ -f "$dir/cudnn.h" ]]; then + CUDNN_IFLAG="-I$dir" + break + fi +done +for dir in "$CUDA_HOME/lib64" "$CUDA_HOME/lib" /usr/lib/x86_64-linux-gnu; do + if [[ -f "$dir/libcudnn.so" ]] || [[ -f "$dir/libcudnn.dylib" ]]; then + CUDNN_LFLAG="-L$dir" + break + fi +done +if [[ -z "$CUDNN_IFLAG" ]]; then + CUDNN_IFLAG=$(python3 -c "import nvidia.cudnn, os; print('-I' + os.path.join(nvidia.cudnn.__path__[0], 'include'))" 2>/dev/null || true) +fi +if [[ -z "$CUDNN_LFLAG" ]]; then + CUDNN_LFLAG=$(python3 -c "import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2>/dev/null || true) +fi + +OUT="${ROOT}/tests/bench_gemm_conv_end2end" +if [[ 
-n "$PRECISION_FLAG" ]]; then + echo "nvcc $ARCH $OUT (fp32)" +else + echo "nvcc $ARCH $OUT (bf16)" +fi +"$NVCC" -O2 -std=c++17 "-arch=$ARCH" \ + "-I${ROOT}/src" \ + "-I${CUDA_HOME}/include" \ + $CUDNN_IFLAG \ + $PRECISION_FLAG \ + "${ROOT}/tests/bench_gemm_conv_end2end.cu" \ + -o "$OUT" \ + $CUDNN_LFLAG \ + -L"${CUDA_HOME}/lib64" -L"${CUDA_HOME}/lib" \ + -lcublas -lcudnn -lcurand + +exec "$OUT" From fabbe9b45b4acfcbc56915ddf4d4036e212a6d95 Mon Sep 17 00:00:00 2001 From: jonah Date: Thu, 9 Apr 2026 08:05:11 -0700 Subject: [PATCH 4/9] remove --- tests/bench_conv_gemm_vs_cudnn.cu | 1003 ----------------------------- tests/bench_conv_gemm_vs_cudnn.sh | 72 --- tests/tune_cublas_gemm.cu | 296 --------- tests/tune_cublas_gemm.sh | 36 -- 4 files changed, 1407 deletions(-) delete mode 100755 tests/bench_conv_gemm_vs_cudnn.cu delete mode 100755 tests/bench_conv_gemm_vs_cudnn.sh delete mode 100644 tests/tune_cublas_gemm.cu delete mode 100755 tests/tune_cublas_gemm.sh diff --git a/tests/bench_conv_gemm_vs_cudnn.cu b/tests/bench_conv_gemm_vs_cudnn.cu deleted file mode 100755 index cd05977ee2..0000000000 --- a/tests/bench_conv_gemm_vs_cudnn.cu +++ /dev/null @@ -1,1003 +0,0 @@ -// Benchmark: gemm_conv_* (im2col + cuBLAS) vs conv_* (cuDNN), same weights/inputs. -// Baseline for correctness: gemm path (per project convention). 
-// -// Build: see tests/bench_conv_gemm_vs_cudnn.sh (-DPRECISION_FLOAT => fp32; omit => bf16) - -#include -#include "models.cu" -#include "ocean.cu" - -#ifndef PRECISION_FLOAT -#include -#endif -#include -#include -#include -#include -#include -#include - -static void fill_rand_host(float* h, int n, unsigned s) { - for (int i = 0; i < n; ++i) { - s = s * 1103515245u + 12345u; - h[i] = ((s >> 16) & 0x7fff) / 16384.0f - 1.0f; - } -} - -static void stats_diff(const float* a, const float* b, int n, float* max_abs, float* mean_abs) { - float mx = 0.0f, sum = 0.0f; - for (int i = 0; i < n; ++i) { - float d = fabsf(a[i] - b[i]); - if (d > mx) mx = d; - sum += d; - } - *max_abs = mx; - *mean_abs = sum / (float)std::max(1, n); -} - -static void copy_precision_d2h(const precision_t* d, int n, std::vector* hf) { - hf->resize((size_t)n); -#ifdef PRECISION_FLOAT - cudaMemcpy(hf->data(), d, (size_t)n * sizeof(float), cudaMemcpyDeviceToHost); -#else - std::vector h((size_t)n); - cudaMemcpy(h.data(), d, (size_t)n * sizeof(precision_t), cudaMemcpyDeviceToHost); - for (int i = 0; i < n; ++i) (*hf)[i] = __bfloat162float(h[i]); -#endif -} - -static void copy_fp32_h2d(const float* h, precision_t* d, int n) { -#ifdef PRECISION_FLOAT - cudaMemcpy(d, h, (size_t)n * sizeof(float), cudaMemcpyHostToDevice); -#else - std::vector hb((size_t)n); - for (int i = 0; i < n; ++i) hb[i] = __float2bfloat16(h[i]); - cudaMemcpy(d, hb.data(), (size_t)n * sizeof(precision_t), cudaMemcpyHostToDevice); -#endif -} - -template -static float time_kernel_ms(cudaStream_t stream, F fn, int warmup, int iters) { - for (int i = 0; i < warmup; ++i) { - fn(stream); - cudaStreamSynchronize(stream); - } - cudaEvent_t ev0, ev1; - cudaEventCreate(&ev0); - cudaEventCreate(&ev1); - cudaEventRecord(ev0, stream); - for (int i = 0; i < iters; ++i) fn(stream); - cudaEventRecord(ev1, stream); - cudaEventSynchronize(ev1); - float ms = 0.0f; - cudaEventElapsedTime(&ms, ev0, ev1); - cudaEventDestroy(ev0); - cudaEventDestroy(ev1); 
- return ms / (float)iters; -} - -struct BenchDims { - int B; - int IC, OC, K, S, IH, IW; - bool relu; -}; - -static void dims_conv1(BenchDims* d, int B) { - d->B = B; - d->IC = N3_C1_IC; - d->OC = N3_C1_OC; - d->K = N3_C1_K; - d->S = N3_C1_S; - d->IH = N3_MAP_H; - d->IW = N3_MAP_W; - d->relu = true; -} - -static void dims_conv2(BenchDims* d, int B) { - d->B = B; - d->IC = N3_C2_IC; - d->OC = N3_C2_OC; - d->K = N3_C2_K; - d->S = N3_C2_S; - d->IH = N3_C1_OH; - d->IW = N3_C1_OW; - d->relu = false; -} - -static int run_forward(const BenchDims& dim, int warmup, int iters, bool cudnn_save_input) { - ConvWeights cw{}; - conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, dim.relu); - - Allocator param_alloc{}; - conv_reg_params(&cw, ¶m_alloc); - if (alloc_create(¶m_alloc) != cudaSuccess) { - fprintf(stderr, "alloc_create params failed\n"); - return 1; - } - uint64_t seed = 42; - conv_init_weights(&cw, &seed, 0); - cudaDeviceSynchronize(); - - int OH = cw.OH; - int OW = cw.OW; - int out_elems = dim.B * dim.OC * OH * OW; - int in_elems = dim.B * dim.IC * dim.IH * dim.IW; - - Allocator act_g{}; - PrecisionTensor out_g{}, col{}, mm{}, input{}; - int col_rows = dim.B * OH * OW; - int col_cols = dim.IC * dim.K * dim.K; - out_g = {.shape = {(int64_t)out_elems}}; - col = {.shape = {col_rows, col_cols}}; - mm = {.shape = {col_rows, dim.OC}}; - input = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; - alloc_register(&act_g, &out_g); - alloc_register(&act_g, &col); - alloc_register(&act_g, &mm); - alloc_register(&act_g, &input); - if (alloc_create(&act_g) != cudaSuccess) { - fprintf(stderr, "alloc_create gemm acts failed\n"); - return 1; - } - - std::vector hin(in_elems); - fill_rand_host(hin.data(), in_elems, 99u); - copy_fp32_h2d(hin.data(), input.data, in_elems); - - Allocator acts{}, grads{}; - ConvActivations ca{}; - conv_reg_train(&cw, &ca, &acts, &grads, dim.B, n3_cudnn_dtype()); - if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) { - 
fprintf(stderr, "alloc_create cudnn failed\n"); - return 1; - } - if (!cudnn_save_input) ca.saved_input.data = nullptr; - - cudaStream_t stream = 0; - - gemm_conv_forward(&cw.w, &cw.b, input.data, out_g.data, col.data, mm.data, dim.B, dim.IC, dim.IH, - dim.IW, dim.OC, dim.K, dim.S, OH, OW, dim.relu, stream); - cudaDeviceSynchronize(); - - conv_forward(&cw, &ca, input.data, dim.B, stream); - cudaDeviceSynchronize(); - - std::vector hg, hc; - copy_precision_d2h(out_g.data, out_elems, &hg); - copy_precision_d2h(ca.out.data, out_elems, &hc); - - float max_abs, mean_abs; - stats_diff(hg.data(), hc.data(), out_elems, &max_abs, &mean_abs); - printf(" forward max |diff|: %.6g mean |diff|: %.6g\n", max_abs, mean_abs); - - auto run_gemm = [&](cudaStream_t s) { - gemm_conv_forward(&cw.w, &cw.b, input.data, out_g.data, col.data, mm.data, dim.B, dim.IC, - dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, dim.relu, s); - }; - auto run_cudnn = [&](cudaStream_t s) { - conv_forward(&cw, &ca, input.data, dim.B, s); - }; - - float ms_g = time_kernel_ms(stream, run_gemm, warmup, iters); - float ms_c = time_kernel_ms(stream, run_cudnn, warmup, iters); - printf(" gemm_conv_forward: %8.4f us/iter\n", ms_g * 1000.0f); - printf(" conv_forward: %8.4f us/iter (%.2fx vs gemm)\n", ms_c * 1000.0f, ms_g / ms_c); - - alloc_free(¶m_alloc); - alloc_free(&act_g); - alloc_free(&acts); - alloc_free(&grads); - return 0; -} - -// ∂L/∂W only: gemm path vs cudnnConvolutionBackwardFilter. -// Correctness: gemm_conv_backward(..., input_grad=null) vs conv_backward(..., nullptr). -// Timings: inlined nchw+im2col+mm_tn, gemm_conv_backward, and cudnn BackwardFilter. 
-static int run_filter_backward_bench(const BenchDims& dim, int warmup, int iters) { - int B = dim.B, IC = dim.IC, OC = dim.OC, K = dim.K, S = dim.S, IH = dim.IH, IW = dim.IW; - int OH = (IH - K) / S + 1; - int OW = (IW - K) / S + 1; - int col_rows = B * OH * OW; - int col_cols = IC * K * K; - int total_col = col_rows * col_cols; - int total_out = B * OC * OH * OW; - int in_elems = B * IC * IH * IW; - int spatial = OH * OW; - int w_elems = OC * col_cols; - - ConvWeights cw{}; - conv_init(&cw, IC, OC, K, S, IH, IW, false); - Allocator param_alloc{}; - conv_reg_params(&cw, ¶m_alloc); - if (alloc_create(¶m_alloc) != cudaSuccess) return 1; - uint64_t seed = 301; - conv_init_weights(&cw, &seed, 0); - cudaDeviceSynchronize(); - - Allocator act{}; - PrecisionTensor col{}, mm{}, saved_in{}, grad_out{}, wgrad{}; - col = {.shape = {col_rows, col_cols}}; - mm = {.shape = {col_rows, OC}}; - saved_in = {.shape = {B, IC, IH, IW}}; - grad_out = {.shape = {(int64_t)total_out}}; - wgrad = {.shape = {OC, col_cols}}; - alloc_register(&act, &col); - alloc_register(&act, &mm); - alloc_register(&act, &saved_in); - alloc_register(&act, &grad_out); - alloc_register(&act, &wgrad); - if (alloc_create(&act) != cudaSuccess) return 1; - - Allocator acts{}, grads{}; - ConvActivations ca{}; - conv_reg_train(&cw, &ca, &acts, &grads, B, n3_cudnn_dtype()); - if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1; - - std::vector hs(in_elems), hg(total_out); - fill_rand_host(hs.data(), in_elems, 301u); - fill_rand_host(hg.data(), total_out, 302u); - copy_fp32_h2d(hs.data(), saved_in.data, in_elems); - copy_fp32_h2d(hg.data(), grad_out.data, total_out); - cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); - cudaMemcpy(ca.grad.data, grad_out.data, (size_t)total_out * sizeof(precision_t), cudaMemcpyDeviceToDevice); - - PrecisionTensor mm_t = {.data = mm.data, .shape = {col_rows, OC}}; - PrecisionTensor col_t 
= {.data = col.data, .shape = {col_rows, col_cols}}; - PrecisionTensor wg_t = {.data = wgrad.data, .shape = {OC, col_cols}}; - - cudaStream_t stream = 0; - - cudaMemset(wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad.data, nullptr, col.data, mm.data, B, IC, IH, IW, - OC, K, S, OH, OW, stream); - cudaDeviceSynchronize(); - std::vector h_wg_g, h_wg_c; - copy_precision_d2h(wgrad.data, w_elems, &h_wg_g); - - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, nullptr, B, stream); - cudaDeviceSynchronize(); - copy_precision_d2h(ca.wgrad.data, w_elems, &h_wg_c); - - float f_wg_max, f_wg_mean; - stats_diff(h_wg_g.data(), h_wg_c.data(), w_elems, &f_wg_max, &f_wg_mean); - printf(" filter ∂W max |diff| (gemm_conv_backward vs cudnn): %.6g mean |diff|: %.6g\n", f_wg_max, - f_wg_mean); - - float ms_nchw = time_kernel_ms( - stream, - [&](cudaStream_t s) { - nchw_to_rows_kernel<<>>( - grad_out.data, mm.data, B, OC, spatial); - }, - warmup, iters); - - float ms_im2col = time_kernel_ms( - stream, - [&](cudaStream_t s) { - im2col_kernel<<>>( - saved_in.data, col.data, B, IC, IH, IW, K, S, OH, OW); - }, - warmup, iters); - - nchw_to_rows_kernel<<>>( - grad_out.data, mm.data, B, OC, spatial); - im2col_kernel<<>>( - saved_in.data, col.data, B, IC, IH, IW, K, S, OH, OW); - cudaStreamSynchronize(stream); - - float ms_tn = time_kernel_ms( - stream, - [&](cudaStream_t s) { puf_mm_tn(&mm_t, &col_t, &wg_t, s); }, warmup, iters); - - float ms_gemm_chain = time_kernel_ms( - stream, - [&](cudaStream_t s) { - nchw_to_rows_kernel<<>>( - grad_out.data, mm.data, B, OC, spatial); - im2col_kernel<<>>( - saved_in.data, col.data, B, IC, IH, IW, K, S, OH, OW); - puf_mm_tn(&mm_t, &col_t, &wg_t, s); - }, - warmup, iters); - - float ms_gemm_conv_bwd_wonly = time_kernel_ms( - stream, - [&](cudaStream_t s) { - cudaMemset(wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - 
gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad.data, nullptr, col.data, mm.data, B, IC, - IH, IW, OC, K, S, OH, OW, s); - }, - warmup, iters); - - float ms_cudnn_filt = time_kernel_ms( - stream, - [&](cudaStream_t s) { - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, nullptr, B, s); - }, - warmup, iters); - - float sum_iso = (ms_nchw + ms_im2col + ms_tn) * 1000.0f; - printf(" gemm: nchw_to_rows only: %8.4f us/iter\n", ms_nchw * 1000.0f); - printf(" gemm: im2col only: %8.4f us/iter\n", ms_im2col * 1000.0f); - printf(" gemm: puf_mm_tn only: %8.4f us/iter (mm/col prefilled)\n", ms_tn * 1000.0f); - printf(" gemm: nchw+im2col+mm_tn: %8.4f us/iter (chained, matches ∂W slice)\n", ms_gemm_chain * 1000.0f); - printf(" gemm: gemm_conv_backward: %8.4f us/iter (input_grad=null, same as ocean.cu ∂W-only)\n", - ms_gemm_conv_bwd_wonly * 1000.0f); - printf(" sum(3 isolated): %8.4f us/iter (vs chained)\n", sum_iso); - printf(" cudnn: BackwardFilter only: %8.4f us/iter (conv_backward, input_grad=null)\n", - ms_cudnn_filt * 1000.0f); - printf(" ratio gemm_chain/cudnn: %.2fx (>1 => cudnn faster)\n", ms_gemm_chain / ms_cudnn_filt); - printf(" ratio gemm_conv_bwd/cudnn: %.2fx\n", ms_gemm_conv_bwd_wonly / ms_cudnn_filt); - - alloc_free(¶m_alloc); - alloc_free(&act); - alloc_free(&acts); - alloc_free(&grads); - return 0; -} - -static int run_backward(const BenchDims& dim, int warmup, int iters) { - ConvWeights cw{}; - conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, false); - - Allocator param_alloc{}; - conv_reg_params(&cw, ¶m_alloc); - if (alloc_create(¶m_alloc) != cudaSuccess) return 1; - uint64_t seed = 7; - conv_init_weights(&cw, &seed, 0); - cudaDeviceSynchronize(); - - int OH = cw.OH; - int OW = cw.OW; - int out_elems = dim.B * dim.OC * OH * OW; - int in_elems = dim.B * dim.IC * dim.IH * dim.IW; - int w_elems = (int)numel(cw.w.shape); - - Allocator act_g{}; - PrecisionTensor col{}, mm{}; - int col_rows = 
dim.B * OH * OW; - int col_cols = dim.IC * dim.K * dim.K; - col = {.shape = {col_rows, col_cols}}; - mm = {.shape = {col_rows, dim.OC}}; - PrecisionTensor saved_in{}, grad_out{}, wgrad_g{}; - saved_in = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; - grad_out = {.shape = {(int64_t)out_elems}}; - wgrad_g = {.shape = {cw.w.shape[0], cw.w.shape[1]}}; - PrecisionTensor dinput_g{}; - dinput_g = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; - alloc_register(&act_g, &col); - alloc_register(&act_g, &mm); - alloc_register(&act_g, &saved_in); - alloc_register(&act_g, &grad_out); - alloc_register(&act_g, &wgrad_g); - alloc_register(&act_g, &dinput_g); - if (alloc_create(&act_g) != cudaSuccess) return 1; - - std::vector hs(in_elems), hg(out_elems); - fill_rand_host(hs.data(), in_elems, 101u); - fill_rand_host(hg.data(), out_elems, 202u); - copy_fp32_h2d(hs.data(), saved_in.data, in_elems); - copy_fp32_h2d(hg.data(), grad_out.data, out_elems); - - Allocator acts{}, grads{}; - ConvActivations ca{}; - conv_reg_train(&cw, &ca, &acts, &grads, dim.B, n3_cudnn_dtype()); - if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1; - cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); - cudaMemcpy(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); - - cudaStream_t stream = 0; - - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, - dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream); - cudaDeviceSynchronize(); - - std::vector hwg, hdi; - copy_precision_d2h(wgrad_g.data, w_elems, &hwg); - copy_precision_d2h(dinput_g.data, in_elems, &hdi); - - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, 
(size_t)in_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, dinput_g.data, dim.B, stream); - cudaDeviceSynchronize(); - - std::vector hwg_c, hdi_c; - copy_precision_d2h(ca.wgrad.data, w_elems, &hwg_c); - copy_precision_d2h(dinput_g.data, in_elems, &hdi_c); - - float mw, mdw, mdi, mdi_mean; - stats_diff(hwg.data(), hwg_c.data(), w_elems, &mw, &mdw); - stats_diff(hdi.data(), hdi_c.data(), in_elems, &mdi, &mdi_mean); - printf(" backward wgrad max |diff|: %.6g\n", mw); - printf(" backward d_input max |diff|: %.6g\n", mdi); - - auto run_gemm_b = [&](cudaStream_t s) { - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, - mm.data, dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, s); - }; - auto run_cudnn_b = [&](cudaStream_t s) { - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, dinput_g.data, dim.B, s); - }; - - float ms_g = time_kernel_ms(stream, run_gemm_b, warmup, iters); - float ms_c = time_kernel_ms(stream, run_cudnn_b, warmup, iters); - printf(" gemm_conv_backward: %8.4f us/iter\n", ms_g * 1000.0f); - printf(" conv_backward: %8.4f us/iter (%.2fx vs gemm)\n", ms_c * 1000.0f, ms_g / ms_c); - - alloc_free(¶m_alloc); - alloc_free(&act_g); - alloc_free(&acts); - alloc_free(&grads); - return 0; -} - -// ∂W-only (input_grad=null) vs full backward (+ col2im / cudnn BackwardData); gemm vs cudnn. 
-static int run_backward_input_grad_bench(const BenchDims& dim, int warmup, int iters) { - ConvWeights cw{}; - conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, false); - - Allocator param_alloc{}; - conv_reg_params(&cw, ¶m_alloc); - if (alloc_create(¶m_alloc) != cudaSuccess) return 1; - uint64_t seed = 7; - conv_init_weights(&cw, &seed, 0); - cudaDeviceSynchronize(); - - int OH = cw.OH; - int OW = cw.OW; - int out_elems = dim.B * dim.OC * OH * OW; - int in_elems = dim.B * dim.IC * dim.IH * dim.IW; - int w_elems = (int)numel(cw.w.shape); - - Allocator act_g{}; - PrecisionTensor col{}, mm{}; - int col_rows = dim.B * OH * OW; - int col_cols = dim.IC * dim.K * dim.K; - col = {.shape = {col_rows, col_cols}}; - mm = {.shape = {col_rows, dim.OC}}; - PrecisionTensor saved_in{}, grad_out{}, wgrad_g{}; - saved_in = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; - grad_out = {.shape = {(int64_t)out_elems}}; - wgrad_g = {.shape = {cw.w.shape[0], cw.w.shape[1]}}; - PrecisionTensor dinput_g{}; - dinput_g = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; - alloc_register(&act_g, &col); - alloc_register(&act_g, &mm); - alloc_register(&act_g, &saved_in); - alloc_register(&act_g, &grad_out); - alloc_register(&act_g, &wgrad_g); - alloc_register(&act_g, &dinput_g); - if (alloc_create(&act_g) != cudaSuccess) return 1; - - std::vector hs(in_elems), hg(out_elems); - fill_rand_host(hs.data(), in_elems, 101u); - fill_rand_host(hg.data(), out_elems, 202u); - copy_fp32_h2d(hs.data(), saved_in.data, in_elems); - copy_fp32_h2d(hg.data(), grad_out.data, out_elems); - - Allocator acts{}, grads{}; - ConvActivations ca{}; - conv_reg_train(&cw, &ca, &acts, &grads, dim.B, n3_cudnn_dtype()); - if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1; - cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); - cudaMemcpy(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), 
cudaMemcpyDeviceToDevice); - - cudaStream_t stream = 0; - - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, nullptr, col.data, mm.data, dim.B, - dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream); - cudaDeviceSynchronize(); - std::vector hwg_wnull_g; - copy_precision_d2h(wgrad_g.data, w_elems, &hwg_wnull_g); - - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, nullptr, dim.B, stream); - cudaDeviceSynchronize(); - std::vector hwg_wnull_c; - copy_precision_d2h(ca.wgrad.data, w_elems, &hwg_wnull_c); - - float d_wg_wnull, m_wg_wnull; - stats_diff(hwg_wnull_g.data(), hwg_wnull_c.data(), w_elems, &d_wg_wnull, &m_wg_wnull); - printf(" ∂W only wgrad max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_wg_wnull, m_wg_wnull); - - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, - dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream); - cudaDeviceSynchronize(); - std::vector hwg_full_g, hdi_full_g; - copy_precision_d2h(wgrad_g.data, w_elems, &hwg_full_g); - copy_precision_d2h(dinput_g.data, in_elems, &hdi_full_g); - - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, dinput_g.data, dim.B, stream); - cudaDeviceSynchronize(); - std::vector hwg_full_c, hdi_full_c; - copy_precision_d2h(ca.wgrad.data, w_elems, &hwg_full_c); - copy_precision_d2h(dinput_g.data, in_elems, &hdi_full_c); - - float d_wg_full, m_wg_full, d_di, m_di; - stats_diff(hwg_full_g.data(), hwg_full_c.data(), w_elems, &d_wg_full, &m_wg_full); - stats_diff(hdi_full_g.data(), hdi_full_c.data(), in_elems, &d_di, &m_di); - printf(" full wgrad 
max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_wg_full, m_wg_full); - printf(" full d_input max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_di, m_di); - - auto run_gemm_wnull = [&](cudaStream_t s) { - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, nullptr, col.data, mm.data, - dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, s); - }; - auto run_gemm_full = [&](cudaStream_t s) { - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, - mm.data, dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, s); - }; - auto run_cudnn_wnull = [&](cudaStream_t s) { - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, nullptr, dim.B, s); - }; - auto run_cudnn_full = [&](cudaStream_t s) { - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, dinput_g.data, dim.B, s); - }; - - float ms_g_w = time_kernel_ms(stream, run_gemm_wnull, warmup, iters); - float ms_g_f = time_kernel_ms(stream, run_gemm_full, warmup, iters); - float ms_c_w = time_kernel_ms(stream, run_cudnn_wnull, warmup, iters); - float ms_c_f = time_kernel_ms(stream, run_cudnn_full, warmup, iters); - - printf(" gemm ∂W only (d_input=null): %8.4f us/iter\n", ms_g_w * 1000.0f); - printf(" gemm full (+d_input): %8.4f us/iter (+%.4f us d_input slice)\n", - ms_g_f * 1000.0f, (ms_g_f - ms_g_w) * 1000.0f); - printf(" cudnn ∂W only (no BwdData): %8.4f us/iter\n", ms_c_w * 1000.0f); - printf(" cudnn full (+BwdData): %8.4f us/iter (+%.4f us BwdData slice)\n", - ms_c_f * 1000.0f, (ms_c_f - ms_c_w) * 1000.0f); - - alloc_free(¶m_alloc); - alloc_free(&act_g); - 
alloc_free(&acts); - alloc_free(&grads); - return 0; -} - -static int run_im2col_bench(const BenchDims& dim, int warmup, int iters) { - int B = dim.B, IC = dim.IC, IH = dim.IH, IW = dim.IW, K = dim.K, S = dim.S; - int OH = (IH - K) / S + 1; - int OW = (IW - K) / S + 1; - int total_col = B * OH * OW * IC * K * K; - int in_elems = B * IC * IH * IW; - - precision_t *d_in = nullptr, *d_col_slow = nullptr, *d_col_fast = nullptr; - if (cudaMalloc(&d_in, (size_t)in_elems * sizeof(precision_t)) != cudaSuccess) return 1; - if (cudaMalloc(&d_col_slow, (size_t)total_col * sizeof(precision_t)) != cudaSuccess) return 1; - if (cudaMalloc(&d_col_fast, (size_t)total_col * sizeof(precision_t)) != cudaSuccess) return 1; - - std::vector h_in((size_t)in_elems); - fill_rand_host(h_in.data(), in_elems, 401u); - copy_fp32_h2d(h_in.data(), d_in, in_elems); - cudaDeviceSynchronize(); - - const int oh_ow = OH * OW; - const int col_cols = IC * K * K; - const int total_no_batch = oh_ow * col_cols; - const int kk = K * K; - FastDivMod dm_col_w(col_cols); - FastDivMod dm_oh_ow(oh_ow); - FastDivMod dm_ow(OW); - FastDivMod dm_kk(kk); - FastDivMod dm_k(K); - - cudaStream_t stream = 0; - im2col_kernel<<>>( - d_in, d_col_slow, B, IC, IH, IW, K, S, OH, OW); - im2col_kernel_fast<<>>( - d_in, d_col_fast, B, IC, IH, IW, K, S, OH, OW, - dm_col_w, dm_oh_ow, dm_ow, dm_kk, dm_k, total_no_batch); - cudaDeviceSynchronize(); - - std::vector hs, hf; - copy_precision_d2h(d_col_slow, total_col, &hs); - copy_precision_d2h(d_col_fast, total_col, &hf); - float max_d = 0.0f, mean_d = 0.0f; - stats_diff(hs.data(), hf.data(), total_col, &max_d, &mean_d); - printf(" im2col vs im2col_fast max |diff|: %.6g mean |diff|: %.6g\n", max_d, mean_d); - - float ms_slow = time_kernel_ms( - stream, - [&](cudaStream_t s) { - im2col_kernel<<>>( - d_in, d_col_slow, B, IC, IH, IW, K, S, OH, OW); - }, - warmup, iters); - float ms_fast = time_kernel_ms( - stream, - [&](cudaStream_t s) { - im2col_kernel_fast<<>>( - d_in, d_col_fast, B, 
IC, IH, IW, K, S, OH, OW, - dm_col_w, dm_oh_ow, dm_ow, dm_kk, dm_k, total_no_batch); - }, - warmup, iters); - printf(" im2col_kernel: %8.4f us/iter\n", ms_slow * 1000.0f); - printf(" im2col_kernel_fast: %8.4f us/iter (%.2fx vs slow)\n", ms_fast * 1000.0f, - ms_slow / ms_fast); - - cudaFree(d_in); - cudaFree(d_col_slow); - cudaFree(d_col_fast); - return 0; -} - -// Full backward: gemm_conv_backward vs gemm_conv_backward_fast vs cudnn (NMMO3 layer geometry only). -static int run_backward_gemm_fast_bench(const BenchDims& dim, int layer, int warmup, int iters) { - const Im2ColFastMods& m = (layer == 1) ? kIm2ColModsC1 : kIm2ColModsC2; - int OH = (dim.IH - dim.K) / dim.S + 1; - int OW = (dim.IW - dim.K) / dim.S + 1; - if (OH != m.OH || OW != m.OW || dim.IC != m.IC || dim.IH != m.IH || dim.IW != m.IW || dim.OC != m.OC - || dim.K != m.K || dim.S != m.S) { - fprintf(stderr, "backward gemm-fast bench: dimensions must match NMMO3 layer %d (use --layer %d)\n", - layer, layer); - return 1; - } - - ConvWeights cw{}; - conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, false); - - Allocator param_alloc{}; - conv_reg_params(&cw, ¶m_alloc); - if (alloc_create(¶m_alloc) != cudaSuccess) return 1; - uint64_t seed = 23; - conv_init_weights(&cw, &seed, 0); - cudaDeviceSynchronize(); - - int out_elems = dim.B * dim.OC * OH * OW; - int in_elems = dim.B * dim.IC * dim.IH * dim.IW; - int w_elems = (int)numel(cw.w.shape); - - Allocator act_g{}; - PrecisionTensor col{}, mm{}; - int col_rows = dim.B * OH * OW; - int col_cols = dim.IC * dim.K * dim.K; - col = {.shape = {col_rows, col_cols}}; - mm = {.shape = {col_rows, dim.OC}}; - PrecisionTensor saved_in{}, grad_out{}, wgrad_g{}; - saved_in = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; - grad_out = {.shape = {(int64_t)out_elems}}; - wgrad_g = {.shape = {cw.w.shape[0], cw.w.shape[1]}}; - PrecisionTensor dinput_g{}; - dinput_g = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; - alloc_register(&act_g, &col); - alloc_register(&act_g, &mm); 
- alloc_register(&act_g, &saved_in); - alloc_register(&act_g, &grad_out); - alloc_register(&act_g, &wgrad_g); - alloc_register(&act_g, &dinput_g); - if (alloc_create(&act_g) != cudaSuccess) return 1; - - std::vector hs(in_elems), hg(out_elems); - fill_rand_host(hs.data(), in_elems, 101u); - fill_rand_host(hg.data(), out_elems, 202u); - copy_fp32_h2d(hs.data(), saved_in.data, in_elems); - copy_fp32_h2d(hg.data(), grad_out.data, out_elems); - - Allocator acts{}, grads{}; - ConvActivations ca{}; - conv_reg_train(&cw, &ca, &acts, &grads, dim.B, n3_cudnn_dtype()); - if (alloc_create(&acts) != cudaSuccess || alloc_create(&grads) != cudaSuccess) return 1; - cudaMemcpy(ca.saved_input.data, saved_in.data, (size_t)in_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); - cudaMemcpy(ca.grad.data, grad_out.data, (size_t)out_elems * sizeof(precision_t), cudaMemcpyDeviceToDevice); - - cudaStream_t stream = 0; - - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, - dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, stream); - cudaDeviceSynchronize(); - std::vector hwg_slow, hdi_slow; - copy_precision_d2h(wgrad_g.data, w_elems, &hwg_slow); - copy_precision_d2h(dinput_g.data, in_elems, &hdi_slow); - - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - gemm_conv_backward_fast(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, - dim.B, m, stream); - cudaDeviceSynchronize(); - std::vector hwg_fast, hdi_fast; - copy_precision_d2h(wgrad_g.data, w_elems, &hwg_fast); - copy_precision_d2h(dinput_g.data, in_elems, &hdi_fast); - - float d_w_sg, m_w_sg, d_i_sg, m_i_sg; - stats_diff(hwg_slow.data(), hwg_fast.data(), w_elems, &d_w_sg, &m_w_sg); - 
stats_diff(hdi_slow.data(), hdi_fast.data(), in_elems, &d_i_sg, &m_i_sg); - printf(" wgrad max |diff| (gemm vs gemm_fast): %.6g mean |diff|: %.6g\n", d_w_sg, m_w_sg); - printf(" d_input max |diff| (gemm vs gemm_fast): %.6g mean |diff|: %.6g\n", d_i_sg, m_i_sg); - - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, dinput_g.data, dim.B, stream); - cudaDeviceSynchronize(); - std::vector hwg_c, hdi_c; - copy_precision_d2h(ca.wgrad.data, w_elems, &hwg_c); - copy_precision_d2h(dinput_g.data, in_elems, &hdi_c); - float d_w_sc, m_w_sc, d_i_sc, m_i_sc; - stats_diff(hwg_slow.data(), hwg_c.data(), w_elems, &d_w_sc, &m_w_sc); - stats_diff(hdi_slow.data(), hdi_c.data(), in_elems, &d_i_sc, &m_i_sc); - printf(" wgrad max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_w_sc, m_w_sc); - printf(" d_input max |diff| (gemm vs cudnn): %.6g mean |diff|: %.6g\n", d_i_sc, m_i_sc); - - auto run_gemm = [&](cudaStream_t s) { - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - gemm_conv_backward(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, - dim.B, dim.IC, dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, s); - }; - auto run_gemm_fast = [&](cudaStream_t s) { - cudaMemset(wgrad_g.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - gemm_conv_backward_fast(&cw.w, saved_in.data, grad_out.data, wgrad_g.data, dinput_g.data, col.data, mm.data, - dim.B, m, s); - }; - auto run_cudnn = [&](cudaStream_t s) { - cudaMemset(ca.wgrad.data, 0, (size_t)w_elems * sizeof(precision_t)); - cudaMemset(dinput_g.data, 0, (size_t)in_elems * sizeof(precision_t)); - conv_backward(&cw, &ca, dinput_g.data, dim.B, s); - }; - - float ms_g = time_kernel_ms(stream, run_gemm, warmup, iters); - float ms_gf = 
time_kernel_ms(stream, run_gemm_fast, warmup, iters); - float ms_c = time_kernel_ms(stream, run_cudnn, warmup, iters); - printf(" gemm_conv_backward: %8.4f us/iter\n", ms_g * 1000.0f); - printf(" gemm_conv_backward_fast: %8.4f us/iter (%.2fx vs gemm_conv_backward)\n", ms_gf * 1000.0f, - ms_g / ms_gf); - printf(" conv_backward (cudnn): %8.4f us/iter (%.2fx vs gemm, %.2fx vs gemm_fast)\n", ms_c * 1000.0f, - ms_g / ms_c, ms_gf / ms_c); - - alloc_free(¶m_alloc); - alloc_free(&act_g); - alloc_free(&acts); - alloc_free(&grads); - return 0; -} - -// gemm_conv_forward vs gemm_conv_forward_fast; relu on/off (NMMO3 layer geometry only). -static int run_gemm_fast_fwd_bench(const BenchDims& dim, int layer, int warmup, int iters) { - const Im2ColFastMods& m = (layer == 1) ? kIm2ColModsC1 : kIm2ColModsC2; - - ConvWeights cw{}; - conv_init(&cw, dim.IC, dim.OC, dim.K, dim.S, dim.IH, dim.IW, dim.relu); - Allocator param_alloc{}; - conv_reg_params(&cw, ¶m_alloc); - if (alloc_create(¶m_alloc) != cudaSuccess) return 1; - uint64_t seed = 55; - conv_init_weights(&cw, &seed, 0); - cudaDeviceSynchronize(); - - int OH = cw.OH; - int OW = cw.OW; - int out_elems = dim.B * dim.OC * OH * OW; - int in_elems = dim.B * dim.IC * dim.IH * dim.IW; - int col_rows = dim.B * OH * OW; - int col_cols = dim.IC * dim.K * dim.K; - - Allocator act{}; - PrecisionTensor out_s{}, out_f{}, col{}, mm{}, input{}; - out_s = {.shape = {(int64_t)out_elems}}; - out_f = {.shape = {(int64_t)out_elems}}; - col = {.shape = {col_rows, col_cols}}; - mm = {.shape = {col_rows, dim.OC}}; - input = {.shape = {dim.B, dim.IC, dim.IH, dim.IW}}; - alloc_register(&act, &out_s); - alloc_register(&act, &out_f); - alloc_register(&act, &col); - alloc_register(&act, &mm); - alloc_register(&act, &input); - if (alloc_create(&act) != cudaSuccess) return 1; - - std::vector hin(in_elems); - fill_rand_host(hin.data(), in_elems, 77u); - copy_fp32_h2d(hin.data(), input.data, in_elems); - cudaDeviceSynchronize(); - - cudaStream_t stream = 0; - - 
for (int ri = 0; ri < 2; ++ri) { - bool use_relu = (ri == 1); - gemm_conv_forward(&cw.w, &cw.b, input.data, out_s.data, col.data, mm.data, dim.B, dim.IC, dim.IH, - dim.IW, dim.OC, dim.K, dim.S, OH, OW, use_relu, stream); - cudaDeviceSynchronize(); - gemm_conv_forward_fast(&cw.w, &cw.b, input.data, out_f.data, col.data, mm.data, dim.B, m, use_relu, - stream); - cudaDeviceSynchronize(); - std::vector hs, hf; - copy_precision_d2h(out_s.data, out_elems, &hs); - copy_precision_d2h(out_f.data, out_elems, &hf); - float mx, mn; - stats_diff(hs.data(), hf.data(), out_elems, &mx, &mn); - printf(" relu=%d max |diff| slow vs fast: %.6g mean |diff|: %.6g\n", (int)use_relu, mx, mn); - - float ms_slow = time_kernel_ms( - stream, - [&](cudaStream_t s) { - gemm_conv_forward(&cw.w, &cw.b, input.data, out_s.data, col.data, mm.data, dim.B, dim.IC, - dim.IH, dim.IW, dim.OC, dim.K, dim.S, OH, OW, use_relu, s); - }, - warmup, iters); - float ms_fast = time_kernel_ms( - stream, - [&](cudaStream_t s) { - gemm_conv_forward_fast(&cw.w, &cw.b, input.data, out_f.data, col.data, mm.data, dim.B, m, - use_relu, s); - }, - warmup, iters); - printf(" relu=%d gemm_conv_forward: %8.4f us/iter\n", (int)use_relu, ms_slow * 1000.0f); - printf(" relu=%d gemm_conv_forward_fast: %8.4f us/iter (%.2fx vs slow)\n", (int)use_relu, - ms_fast * 1000.0f, ms_slow / ms_fast); - } - - alloc_free(¶m_alloc); - alloc_free(&act); - return 0; -} - -int main(int argc, char** argv) { - int B = 1024; - int layer = 1; - int warmup = 50; - int iters = 200; - bool do_fwd = true; - bool do_bwd = true; - bool cudnn_save = true; - bool do_wgrad_breakdown = false; - bool do_im2col_bench = false; - bool do_gemm_fast_bench = false; - bool do_bwd_dinput_bench = false; - bool do_gemm_bwd_fast_bench = false; - for (int i = 1; i < argc; ++i) { - if (strcmp(argv[i], "-B") == 0 && i + 1 < argc) B = atoi(argv[++i]); - else if (strcmp(argv[i], "--layer") == 0 && i + 1 < argc) layer = atoi(argv[++i]); - else if (strcmp(argv[i], "--warmup") 
== 0 && i + 1 < argc) warmup = atoi(argv[++i]); - else if (strcmp(argv[i], "--iters") == 0 && i + 1 < argc) iters = atoi(argv[++i]); - else if (strcmp(argv[i], "--forward-only") == 0) do_bwd = false; - else if (strcmp(argv[i], "--backward-only") == 0) do_fwd = false; - else if (strcmp(argv[i], "--no-cudnn-save-input") == 0) cudnn_save = false; - else if (strcmp(argv[i], "--wgrad-breakdown-only") == 0 - || strcmp(argv[i], "--filter-bwd-only") == 0) { - do_wgrad_breakdown = true; - do_fwd = false; - do_bwd = false; - do_bwd_dinput_bench = false; - do_gemm_bwd_fast_bench = false; - } else if (strcmp(argv[i], "--wgrad-breakdown") == 0 || strcmp(argv[i], "--filter-bwd") == 0) { - do_wgrad_breakdown = true; - } else if (strcmp(argv[i], "--im2col-bench-only") == 0) { - do_im2col_bench = true; - do_fwd = false; - do_bwd = false; - do_wgrad_breakdown = false; - do_bwd_dinput_bench = false; - do_gemm_bwd_fast_bench = false; - } else if (strcmp(argv[i], "--im2col-bench") == 0) { - do_im2col_bench = true; - } else if (strcmp(argv[i], "--gemm-fast-bench-only") == 0) { - do_gemm_fast_bench = true; - do_fwd = false; - do_bwd = false; - do_wgrad_breakdown = false; - do_im2col_bench = false; - do_bwd_dinput_bench = false; - do_gemm_bwd_fast_bench = false; - } else if (strcmp(argv[i], "--gemm-fast-bench") == 0) { - do_gemm_fast_bench = true; - } else if (strcmp(argv[i], "--gemm-bwd-fast-bench-only") == 0) { - do_gemm_bwd_fast_bench = true; - do_fwd = false; - do_bwd = false; - do_wgrad_breakdown = false; - do_im2col_bench = false; - do_gemm_fast_bench = false; - do_bwd_dinput_bench = false; - } else if (strcmp(argv[i], "--gemm-bwd-fast-bench") == 0) { - do_gemm_bwd_fast_bench = true; - } else if (strcmp(argv[i], "--bwd-dinput-bench-only") == 0) { - do_bwd_dinput_bench = true; - do_fwd = false; - do_bwd = false; - do_wgrad_breakdown = false; - do_im2col_bench = false; - do_gemm_fast_bench = false; - do_gemm_bwd_fast_bench = false; - } else if (strcmp(argv[i], "--bwd-dinput-bench") == 
0) { - do_bwd_dinput_bench = true; - } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { - printf("Usage: %s [options]\n", argv[0]); - printf(" -B N batch size (default 1024)\n"); - printf(" --layer 1|2 NMMO3 conv1 or conv2 sizes (default 1)\n"); - printf(" --warmup N timing warmup runs (default 50)\n"); - printf(" --iters N timed iterations (default 200)\n"); - printf(" --forward-only only forward pass\n"); - printf(" --backward-only only backward (identity activation)\n"); - printf(" --no-cudnn-save-input omit cudnn memcpy to saved_input (forward timing)\n"); - printf(" --filter-bwd / --wgrad-breakdown also bench ∂W: gemm (nchw+im2col+mm_tn) vs cudnn BackwardFilter\n"); - printf(" --filter-bwd-only / --wgrad-breakdown-only only that ∂W bench\n"); - printf(" --im2col-bench also bench im2col_kernel vs im2col_kernel_fast\n"); - printf(" --im2col-bench-only only that im2col bench\n"); - printf(" --gemm-fast-bench also bench gemm_conv_forward vs gemm_conv_forward_fast (relu 0/1)\n"); - printf(" --gemm-fast-bench-only only that bench\n"); - printf(" --bwd-dinput-bench also bench ∂W-only vs full bwd (gemm vs cudnn)\n"); - printf(" --bwd-dinput-bench-only only that bench\n"); - printf(" --gemm-bwd-fast-bench also bench full bwd: gemm vs gemm_fast vs cudnn (NMMO3 layer)\n"); - printf(" --gemm-bwd-fast-bench-only only that bench\n"); - printf(" (script) --float / --fp32 compile fp32 (default)\n"); - printf(" (script) --bf16 / --half compile bf16 (matches default native backend)\n"); - return 0; - } - } - - BenchDims dim{}; - if (layer == 1) dims_conv1(&dim, B); - else if (layer == 2) dims_conv2(&dim, B); - else { - fprintf(stderr, "layer must be 1 or 2\n"); - return 1; - } - - int OH = (dim.IH - dim.K) / dim.S + 1; - int OW = (dim.IW - dim.K) / dim.S + 1; - printf("bench_conv_gemm_vs_cudnn B=%d layer=%d IC=%d OC=%d %dx%d K=%d S=%d -> %dx%d relu=%d", - dim.B, layer, dim.IC, dim.OC, dim.IH, dim.IW, dim.K, dim.S, OH, OW, (int)dim.relu); -#ifdef 
PRECISION_FLOAT - printf(" precision=fp32\n"); -#else - printf(" precision=bf16\n"); -#endif - - if (do_fwd) { - printf("\n--- forward (gemm baseline vs cudnn) ---\n"); - if (run_forward(dim, warmup, iters, cudnn_save)) return 1; - } - if (do_bwd) { - BenchDims bd = dim; - bd.relu = false; - printf("\n--- backward (identity conv; gemm vs cudnn) ---\n"); - printf(" (relu ignored: identity conv so cudnn bwd matches gemm without ReLU mask)\n"); - if (run_backward(bd, warmup, iters)) return 1; - } - if (do_wgrad_breakdown) { - printf("\n--- filter backward (∂W): gemm path vs cudnnConvolutionBackwardFilter ---\n"); - if (run_filter_backward_bench(dim, warmup, iters)) return 1; - } - if (do_im2col_bench) { - printf("\n--- im2col_kernel vs im2col_kernel_fast ---\n"); - if (run_im2col_bench(dim, warmup, iters)) return 1; - } - if (do_gemm_fast_bench) { - printf("\n--- gemm_conv_forward vs gemm_conv_forward_fast (relu off/on) ---\n"); - if (run_gemm_fast_fwd_bench(dim, layer, warmup, iters)) return 1; - } - if (do_bwd_dinput_bench) { - BenchDims bd = dim; - bd.relu = false; - printf("\n--- backward: ∂W-only (d_input=null) vs full (+d_input); gemm vs cudnn ---\n"); - if (run_backward_input_grad_bench(bd, warmup, iters)) return 1; - } - if (do_gemm_bwd_fast_bench) { - BenchDims bd = dim; - bd.relu = false; - printf("\n--- backward: gemm_conv_backward vs gemm_conv_backward_fast vs conv_backward ---\n"); - printf(" (NMMO3 geometry; use --layer 1 or 2)\n"); - if (run_backward_gemm_fast_bench(bd, layer, warmup, iters)) return 1; - } - return 0; -} diff --git a/tests/bench_conv_gemm_vs_cudnn.sh b/tests/bench_conv_gemm_vs_cudnn.sh deleted file mode 100755 index 91293f3ffe..0000000000 --- a/tests/bench_conv_gemm_vs_cudnn.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env bash -# Build and run gemm (im2col+cuBLAS) vs cuDNN conv benchmark. -# Default: fp32. Use --bf16 / --half for bf16 (matches native backend without --float). 
-# -# im2col vs im2col_kernel_fast (correctness + timing), same layer sizes as conv bench: -# ./tests/bench_conv_gemm_vs_cudnn.sh --layer 1 --im2col-bench-only -# ./tests/bench_conv_gemm_vs_cudnn.sh --bf16 --layer 2 --im2col-bench -# gemm slow vs fast forward (relu 0/1), NMMO3 layer sizes: -# ./tests/bench_conv_gemm_vs_cudnn.sh --layer 1 --gemm-fast-bench-only -# ∂W-only vs full backward (gemm vs cudnn): -# ./tests/bench_conv_gemm_vs_cudnn.sh --layer 1 --bwd-dinput-bench-only -# gemm vs gemm_fast vs cudnn full backward: -# ./tests/bench_conv_gemm_vs_cudnn.sh --layer 1 --gemm-bwd-fast-bench-only -set -euo pipefail -ROOT="$(cd "$(dirname "$0")/.." && pwd)" -cd "$ROOT" - -CUDA_HOME="${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(command -v nvcc)")")}}" -NVCC="${NVCC:-$CUDA_HOME/bin/nvcc}" -ARCH="${NVCC_ARCH:-native}" - -# Default fp32; --bf16 / --half drop -DPRECISION_FLOAT -PRECISION_FLAG="-DPRECISION_FLOAT" -USER_ARGS=() -for arg in "$@"; do - case "$arg" in - --bf16|--half) PRECISION_FLAG="" ;; - --float|--fp32) PRECISION_FLAG="-DPRECISION_FLOAT" ;; - *) USER_ARGS+=("$arg") ;; - esac -done - -CUDNN_IFLAG="" -CUDNN_LFLAG="" -for dir in "$CUDA_HOME/include" /usr/local/cuda/include /usr/include; do - if [[ -f "$dir/cudnn.h" ]]; then - CUDNN_IFLAG="-I$dir" - break - fi -done -for dir in "$CUDA_HOME/lib64" "$CUDA_HOME/lib" /usr/lib/x86_64-linux-gnu; do - if [[ -f "$dir/libcudnn.so" ]] || [[ -f "$dir/libcudnn.dylib" ]]; then - CUDNN_LFLAG="-L$dir" - break - fi -done -if [[ -z "$CUDNN_IFLAG" ]]; then - CUDNN_IFLAG=$(python3 -c "import nvidia.cudnn, os; print('-I' + os.path.join(nvidia.cudnn.__path__[0], 'include'))" 2>/dev/null || true) -fi -if [[ -z "$CUDNN_LFLAG" ]]; then - CUDNN_LFLAG=$(python3 -c "import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2>/dev/null || true) -fi - -OUT="${ROOT}/tests/bench_conv_gemm_vs_cudnn" -if [[ -n "$PRECISION_FLAG" ]]; then - echo "nvcc $ARCH $OUT (fp32)" -else - echo "nvcc $ARCH $OUT (bf16)" -fi 
-"$NVCC" -O2 -std=c++17 "-arch=$ARCH" \ - "-I${ROOT}/src" \ - "-I${CUDA_HOME}/include" \ - $CUDNN_IFLAG \ - $PRECISION_FLAG \ - "${ROOT}/tests/bench_conv_gemm_vs_cudnn.cu" \ - -o "$OUT" \ - $CUDNN_LFLAG \ - -L"${CUDA_HOME}/lib64" -L"${CUDA_HOME}/lib" \ - -lcublas -lcudnn -lcurand - -echo "Running: $OUT ${USER_ARGS[*]}" -exec "$OUT" "${USER_ARGS[@]}" diff --git a/tests/tune_cublas_gemm.cu b/tests/tune_cublas_gemm.cu deleted file mode 100644 index 35274e3469..0000000000 --- a/tests/tune_cublas_gemm.cu +++ /dev/null @@ -1,296 +0,0 @@ -// Sweep cublasGemmEx algorithms for the same layout as puf_mm_tn (see kernels.cu). -// Default (M,N,K) matches NMMO3 conv1 ∂W GEMM at B=1024: M=OC, N=IC*K*K, K=B*OH*OW. -// -// Build: see tests/tune_cublas_gemm.sh (-DPRECISION_FLOAT => fp32; omit => bf16) - -#include -#include -#include - -#include -#include -#include -#include -#include - -#ifndef CUBLAS_GEMM_ALGO0 -#define CUBLAS_GEMM_ALGO0 ((cublasGemmAlgo_t)0) -#endif - -#ifdef PRECISION_FLOAT -typedef float precision_t; -static constexpr cudaDataType_t kCudaPrec = CUDA_R_32F; -static constexpr cublasComputeType_t kCompute = CUBLAS_COMPUTE_32F; -#else -typedef __nv_bfloat16 precision_t; -static constexpr cudaDataType_t kCudaPrec = CUDA_R_16BF; -static constexpr cublasComputeType_t kCompute = CUBLAS_COMPUTE_32F; -#endif - -static void check_cuda(cudaError_t e) { - if (e != cudaSuccess) { - fprintf(stderr, "cuda: %s\n", cudaGetErrorString(e)); - exit(1); - } -} - -static const char* cublas_str(cublasStatus_t s) { - switch (s) { - case CUBLAS_STATUS_SUCCESS: return "SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: return "NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: return "ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: return "INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: return "ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: return "MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: return "EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: return 
"INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: return "NOT_SUPPORTED"; - default: return "OTHER"; - } -} - -// Same lda/ldb rules as cublasGemmExDense in kernels.cu -static inline void gemm_ex_like_puf_mm_tn(cublasHandle_t h, int M, int N, int K, const precision_t* A, - const precision_t* B, precision_t* C, cublasGemmAlgo_t algo, cudaStream_t stream) { - const float alpha = 1.0f, beta = 0.0f; - cublasOperation_t op_a = CUBLAS_OP_T; - cublasOperation_t op_b = CUBLAS_OP_N; - int lda = (op_a == CUBLAS_OP_N) ? K : M; - int ldb = (op_b == CUBLAS_OP_N) ? N : K; - cublasSetStream(h, stream); - cublasStatus_t st = cublasGemmEx(h, op_b, op_a, N, M, K, &alpha, B, kCudaPrec, ldb, A, kCudaPrec, lda, - &beta, C, kCudaPrec, N, kCompute, algo); - if (st != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, "cublasGemmEx failed: %s\n", cublas_str(st)); - exit(1); - } -} - -template -static float time_ms(cudaStream_t stream, F fn, int warmup, int iters) { - for (int i = 0; i < warmup; ++i) { - fn(stream); - cudaStreamSynchronize(stream); - } - cudaEvent_t e0, e1; - cudaEventCreate(&e0); - cudaEventCreate(&e1); - cudaEventRecord(e0, stream); - for (int i = 0; i < iters; ++i) fn(stream); - cudaEventRecord(e1, stream); - cudaEventSynchronize(e1); - float ms = 0.f; - cudaEventElapsedTime(&ms, e0, e1); - cudaEventDestroy(e0); - cudaEventDestroy(e1); - return ms / (float)iters; -} - -static float max_abs_diff_fp32(const float* a, const float* b, int n) { - float m = 0.f; - for (int i = 0; i < n; ++i) m = fmaxf(m, fabsf(a[i] - b[i])); - return m; -} - -int main(int argc, char** argv) { - int Bbatch = 1024; - int M = 128, N = 1475, Kdim = 12288; - int warmup = 20, iters = 100; - for (int i = 1; i < argc; ++i) { - if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { - printf( - "Usage: %s [options]\n" - " Tune cublasGemmEx for the same call pattern as puf_mm_tn:\n" - " C = B * A with op(B)=N, op(A)=T, sizes (N,M,K) -> cublasGemmEx(..., N,M,K,...)\n" - "Options:\n" - " --layer 
1|2 NMMO3 conv sizes for (M,N,K) at given -B (default --layer 1)\n" - " -B N batch (default 1024), used with --layer\n" - " -M,-N,-K override matrix dims (after --layer, if set)\n" - " --warmup N (default 20)\n" - " --iters N (default 100)\n" - " (M,N,K) default without --layer: 128 1475 12288 (conv1 ∂W @ B=1024)\n", - argv[0]); - return 0; - } else if (strcmp(argv[i], "-B") == 0 && i + 1 < argc) Bbatch = atoi(argv[++i]); - else if (strcmp(argv[i], "--warmup") == 0 && i + 1 < argc) warmup = atoi(argv[++i]); - else if (strcmp(argv[i], "--iters") == 0 && i + 1 < argc) iters = atoi(argv[++i]); - else if (strcmp(argv[i], "--layer") == 0 && i + 1 < argc) { - int L = atoi(argv[++i]); - if (L == 1) { - const int OH = 3, OW = 4, IC = 59, OC = 128, Kk = 5; - Kdim = Bbatch * OH * OW; - N = IC * Kk * Kk; - M = OC; - } else if (L == 2) { - const int OH = 1, OW = 2, IC = 128, OC = 128, Kk = 3; - Kdim = Bbatch * OH * OW; - N = IC * Kk * Kk; - M = OC; - } else { - fprintf(stderr, "layer must be 1 or 2\n"); - return 1; - } - } else if (strcmp(argv[i], "-M") == 0 && i + 1 < argc) M = atoi(argv[++i]); - else if (strcmp(argv[i], "-N") == 0 && i + 1 < argc) N = atoi(argv[++i]); - else if (strcmp(argv[i], "-K") == 0 && i + 1 < argc) Kdim = atoi(argv[++i]); - } - - int ldc = N; - int lenA = Kdim * M; - int lenB = Kdim * N; - int lenC = M * N; - - printf("tune_cublas_gemm M=%d N=%d K=%d (puf_mm_tn logical sizes)", M, N, Kdim); -#ifdef PRECISION_FLOAT - printf(" dtype=fp32\n"); -#else - printf(" dtype=bf16 compute=CUBLAS_COMPUTE_32F\n"); -#endif - - precision_t *dA, *dB, *dC, *dRef; - check_cuda(cudaMalloc(&dA, (size_t)lenA * sizeof(precision_t))); - check_cuda(cudaMalloc(&dB, (size_t)lenB * sizeof(precision_t))); - check_cuda(cudaMalloc(&dC, (size_t)lenC * sizeof(precision_t))); - check_cuda(cudaMalloc(&dRef, (size_t)lenC * sizeof(precision_t))); - - std::vector hAf(lenA), hBf(lenB); - unsigned seed = 12345; - for (int i = 0; i < lenA; ++i) { - seed = seed * 1103515245u + 12345u; - 
hAf[i] = (((seed >> 16) & 0x7fff) / 16384.0f - 1.0f) * 0.25f; - } - for (int i = 0; i < lenB; ++i) { - seed = seed * 1103515245u + 12345u; - hBf[i] = (((seed >> 16) & 0x7fff) / 16384.0f - 1.0f) * 0.25f; - } -#ifdef PRECISION_FLOAT - check_cuda(cudaMemcpy(dA, hAf.data(), (size_t)lenA * sizeof(float), cudaMemcpyHostToDevice)); - check_cuda(cudaMemcpy(dB, hBf.data(), (size_t)lenB * sizeof(float), cudaMemcpyHostToDevice)); -#else - std::vector hA(lenA), hB(lenB); - for (int i = 0; i < lenA; ++i) hA[i] = __float2bfloat16(hAf[i]); - for (int i = 0; i < lenB; ++i) hB[i] = __float2bfloat16(hBf[i]); - check_cuda(cudaMemcpy(dA, hA.data(), (size_t)lenA * sizeof(precision_t), cudaMemcpyHostToDevice)); - check_cuda(cudaMemcpy(dB, hB.data(), (size_t)lenB * sizeof(precision_t), cudaMemcpyHostToDevice)); -#endif - - cublasHandle_t handle; - cublasCreate(&handle); - - std::vector ref_host(lenC); - cudaStream_t stream = 0; - - auto run_ref = [&]() { - cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH); - gemm_ex_like_puf_mm_tn(handle, M, N, Kdim, dA, dB, dRef, CUBLAS_GEMM_DEFAULT, stream); - }; - run_ref(); - cudaDeviceSynchronize(); -#ifdef PRECISION_FLOAT - check_cuda(cudaMemcpy(ref_host.data(), dRef, (size_t)lenC * sizeof(float), cudaMemcpyDeviceToHost)); -#else - std::vector ref_bf(lenC); - check_cuda(cudaMemcpy(ref_bf.data(), dRef, (size_t)lenC * sizeof(precision_t), cudaMemcpyDeviceToHost)); - for (int i = 0; i < lenC; ++i) ref_host[i] = __bfloat162float(ref_bf[i]); -#endif - - struct Row { - cublasGemmAlgo_t algo; - cublasMath_t math; - float ms; - float max_diff; - bool ok; - }; - std::vector rows; - - const cublasMath_t math_modes[] = {CUBLAS_DEFAULT_MATH, CUBLAS_TENSOR_OP_MATH}; - const char* math_names[] = {"DEFAULT_MATH", "TENSOR_OP_MATH"}; - - std::vector algos; - algos.push_back(CUBLAS_GEMM_DEFAULT); -#ifdef CUBLAS_GEMM_DEFAULT_TENSOR_OP - algos.push_back(CUBLAS_GEMM_DEFAULT_TENSOR_OP); -#endif - for (int a = 0; a <= 23; ++a) - 
algos.push_back((cublasGemmAlgo_t)((int)CUBLAS_GEMM_ALGO0 + a)); - - for (size_t mi = 0; mi < sizeof(math_modes) / sizeof(math_modes[0]); ++mi) { - cublasSetMathMode(handle, math_modes[mi]); - for (cublasGemmAlgo_t algo : algos) { - cublasStatus_t st = cublasSetStream(handle, stream); - (void)st; - const float alpha = 1.f, beta = 0.f; - cublasOperation_t op_a = CUBLAS_OP_T; - cublasOperation_t op_b = CUBLAS_OP_N; - int lda = (op_a == CUBLAS_OP_N) ? Kdim : M; - int ldb = (op_b == CUBLAS_OP_N) ? N : Kdim; - st = cublasGemmEx(handle, op_b, op_a, N, M, Kdim, &alpha, dB, kCudaPrec, ldb, dA, kCudaPrec, lda, - &beta, dC, kCudaPrec, N, kCompute, algo); - - if (st != CUBLAS_STATUS_SUCCESS) { - rows.push_back({algo, math_modes[mi], 0.f, 0.f, false}); - continue; - } - cudaDeviceSynchronize(); - - float ms = time_ms( - stream, - [&](cudaStream_t s) { - cublasGemmEx(handle, op_b, op_a, N, M, Kdim, &alpha, dB, kCudaPrec, ldb, dA, kCudaPrec, lda, - &beta, dC, kCudaPrec, N, kCompute, algo); - }, - warmup, iters); - -#ifdef PRECISION_FLOAT - std::vector hC(lenC); - check_cuda(cudaMemcpy(hC.data(), dC, (size_t)lenC * sizeof(float), cudaMemcpyDeviceToHost)); - float md = max_abs_diff_fp32(ref_host.data(), hC.data(), lenC); -#else - std::vector hCb(lenC); - check_cuda(cudaMemcpy(hCb.data(), dC, (size_t)lenC * sizeof(precision_t), cudaMemcpyDeviceToHost)); - std::vector hC(lenC); - for (int i = 0; i < lenC; ++i) hC[i] = __bfloat162float(hCb[i]); - float md = max_abs_diff_fp32(ref_host.data(), hC.data(), lenC); -#endif - rows.push_back({algo, math_modes[mi], ms, md, true}); - } - } - - cublasDestroy(handle); - cudaFree(dA); - cudaFree(dB); - cudaFree(dC); - cudaFree(dRef); - - float best_ms = 1e30f; - int best_i = -1; - for (size_t i = 0; i < rows.size(); ++i) { - if (!rows[i].ok) continue; - if (rows[i].ms < best_ms) { - best_ms = rows[i].ms; - best_i = (int)i; - } - } - - printf("\n%-6s %-22s %-8s %10s %12s %s\n", "algo", "math", "ok", "us/iter", "max|diff|", "note"); - printf( - 
"------ ---------------------- -------- ---------- ------------ ----\n"); - for (size_t i = 0; i < rows.size(); ++i) { - const Row& r = rows[i]; - const char* mn = (r.math == CUBLAS_DEFAULT_MATH) ? math_names[0] : math_names[1]; - if (!r.ok) { - printf("%-6d %-22s %-8s\n", (int)r.algo, mn, "no"); - continue; - } - const char* tag = ((int)i == best_i) ? "best" : ""; - printf("%-6d %-22s %-8s %10.4f %12.5g %s\n", (int)r.algo, mn, "yes", r.ms * 1000.0f, r.max_diff, tag); - } - if (best_i >= 0) { - printf("\nFastest OK: algo=%d math=%s (%.4f us/iter) max|diff|=%g vs DEFAULT/DEFAULT_MATH " - "reference\n", - (int)rows[best_i].algo, - rows[best_i].math == CUBLAS_DEFAULT_MATH ? math_names[0] : math_names[1], best_ms * 1000.0f, - rows[best_i].max_diff); - } - return 0; -} diff --git a/tests/tune_cublas_gemm.sh b/tests/tune_cublas_gemm.sh deleted file mode 100755 index ff419ba353..0000000000 --- a/tests/tune_cublas_gemm.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env bash -# Build and run cublasGemmEx tuner (same layout as puf_mm_tn in kernels.cu). -# Default: fp32. Use --bf16 / --half for bf16. -set -euo pipefail -ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" -cd "$ROOT" - -CUDA_HOME="${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(command -v nvcc)")")}}" -NVCC="${NVCC:-$CUDA_HOME/bin/nvcc}" -ARCH="${NVCC_ARCH:-native}" - -PRECISION_FLAG="-DPRECISION_FLOAT" -USER_ARGS=() -for arg in "$@"; do - case "$arg" in - --bf16|--half) PRECISION_FLAG="" ;; - --float|--fp32) PRECISION_FLAG="-DPRECISION_FLOAT" ;; - *) USER_ARGS+=("$arg") ;; - esac -done - -OUT="${ROOT}/tests/tune_cublas_gemm" -if [[ -n "$PRECISION_FLAG" ]]; then - echo "nvcc $ARCH $OUT (fp32)" -else - echo "nvcc $ARCH $OUT (bf16)" -fi -"$NVCC" -O2 -std=c++17 "-arch=$ARCH" \ - $PRECISION_FLAG \ - "${ROOT}/tests/tune_cublas_gemm.cu" \ - -o "$OUT" \ - -L"${CUDA_HOME}/lib64" -L"${CUDA_HOME}/lib" \ - -lcublas - -echo "Running: $OUT ${USER_ARGS[*]}" -exec "$OUT" "${USER_ARGS[@]}" From 91b2454b6c9b8df522e80a2604d6dff80759abad Mon Sep 17 00:00:00 2001 From: jonah Date: Thu, 9 Apr 2026 08:13:15 -0700 Subject: [PATCH 5/9] revert --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f9754941ee..d8c18a0158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,6 @@ dependencies = [ "pybind11", ] -[[tool.uv.index]] -url = "https://download.pytorch.org/whl/cu128" - [project.scripts] puffer = "pufferlib.pufferl:main" From 694b7e734a82a11e234042f1cceffc79f3d2e88e Mon Sep 17 00:00:00 2001 From: jonah Date: Thu, 9 Apr 2026 08:23:46 -0700 Subject: [PATCH 6/9] wire in fast kernels --- src/ocean.cu | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/ocean.cu b/src/ocean.cu index 3d5fadf599..a122fd1e00 100644 --- a/src/ocean.cu +++ b/src/ocean.cu @@ -626,15 +626,13 @@ static PrecisionTensor nmmo3_encoder_forward(void* w, void* activations, Precisi n3_multihot_kernel<<>>( a->multihot.data, input.data, B, ew->obs_size); - gemm_conv_forward(&ew->conv1.w, &ew->conv1.b, a->multihot.data, a->conv1.out.data, - a->col1.data, a->mm1.data, B, N3_C1_IC, N3_MAP_H, N3_MAP_W, - N3_C1_OC, 
N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW, true, stream); + gemm_conv_forward_fast(&ew->conv1.w, &ew->conv1.b, a->multihot.data, a->conv1.out.data, + a->col1.data, a->mm1.data, B, kIm2ColModsC1, true, stream); if (a->conv1.saved_input.data) cudaMemcpyAsync(a->conv1.saved_input.data, a->multihot.data, (int64_t)B * N3_C1_IC * N3_MAP_H * N3_MAP_W * sizeof(precision_t), cudaMemcpyDeviceToDevice, stream); - gemm_conv_forward(&ew->conv2.w, &ew->conv2.b, a->conv1.out.data, a->conv2.out.data, - a->col2.data, a->mm2.data, B, N3_C2_IC, N3_C1_OH, N3_C1_OW, - N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW, false, stream); + gemm_conv_forward_fast(&ew->conv2.w, &ew->conv2.b, a->conv1.out.data, a->conv2.out.data, + a->col2.data, a->mm2.data, B, kIm2ColModsC2, false, stream); if (a->conv2.saved_input.data) cudaMemcpyAsync(a->conv2.saved_input.data, a->conv1.out.data, (int64_t)B * N3_C2_IC * N3_C1_OH * N3_C1_OW * sizeof(precision_t), cudaMemcpyDeviceToDevice, stream); @@ -670,10 +668,9 @@ static void nmmo3_encoder_backward(void* w, void* activations, PrecisionTensor g n3_conv_bias_grad_nchw<<conv2.OC, 256, 0, stream>>>( a->conv2.bgrad.data, a->conv2.grad.data, B, ew->conv2.OC, ew->conv2.OH * ew->conv2.OW); - gemm_conv_backward(&ew->conv2.w, a->conv2.saved_input.data, a->conv2.grad.data, + gemm_conv_backward_fast(&ew->conv2.w, a->conv2.saved_input.data, a->conv2.grad.data, a->conv2.wgrad.data, a->conv1.grad.data, - a->col2.data, a->mm2.data, B, N3_C2_IC, N3_C1_OH, N3_C1_OW, - N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW, stream); + a->col2.data, a->mm2.data, B, kIm2ColModsC2, stream); n3_relu_backward_kernel<<conv1.OC * ew->conv1.OH * ew->conv1.OW), BLOCK_SIZE, 0, stream>>>( a->conv1.grad.data, a->conv1.out.data, @@ -681,10 +678,9 @@ static void nmmo3_encoder_backward(void* w, void* activations, PrecisionTensor g n3_conv_bias_grad_nchw<<conv1.OC, 256, 0, stream>>>( a->conv1.bgrad.data, a->conv1.grad.data, B, ew->conv1.OC, ew->conv1.OH * ew->conv1.OW); - 
gemm_conv_backward(&ew->conv1.w, a->conv1.saved_input.data, a->conv1.grad.data, + gemm_conv_backward_fast(&ew->conv1.w, a->conv1.saved_input.data, a->conv1.grad.data, a->conv1.wgrad.data, NULL, - a->col1.data, a->mm1.data, B, N3_C1_IC, N3_MAP_H, N3_MAP_W, - N3_C1_OC, N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW, stream); + a->col1.data, a->mm1.data, B, kIm2ColModsC1, stream); // Embedding backward: scatter-add from concat gradient into float buffer, then cast int embed_n = N3_EMBED_VOCAB * N3_EMBED_DIM; From f2a07e6e1f2c0f577c05973e2fdcfc7897e3a323 Mon Sep 17 00:00:00 2001 From: jonah Date: Fri, 10 Apr 2026 07:21:31 -0700 Subject: [PATCH 7/9] multihot --- src/ocean.cu | 155 ++++++++++++++++++------------- tests/bench_gemm_conv_end2end.cu | 115 ++++++++++++++++++++++- tests/bench_gemm_conv_end2end.sh | 15 ++- 3 files changed, 215 insertions(+), 70 deletions(-) diff --git a/src/ocean.cu b/src/ocean.cu index a122fd1e00..9578f66997 100644 --- a/src/ocean.cu +++ b/src/ocean.cu @@ -24,8 +24,98 @@ static cudnnDataType_t n3_cudnn_dtype() { return (PRECISION_SIZE == 2) ? CUDNN_DATA_BFLOAT16 : CUDNN_DATA_FLOAT; } +struct FastDivMod { + uint32_t d_; + uint32_t M_; + uint32_t l_; + + __host__ FastDivMod(int d) { + d_ = d <= 0 ? 
1u : (uint32_t)d; + uint32_t l = 0; + for (; l < 32; ++l) + if ((1u << l) >= d_) break; + l_ = l; + const uint64_t one = 1; + uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1; + M_ = (uint32_t)m; + } + + __device__ __forceinline__ int div(int n) const { + uint32_t u = (uint32_t)n; // n must be >= 0 + uint32_t t = __umulhi(M_, u); + return (int)((t + u) >> l_); + } + + __device__ __forceinline__ int mod(int n) const { + return n - div(n) * (int)d_; + } + + __device__ __forceinline__ void divmod(int n, int& q, int& r) const { + q = div(n); + r = n - q * (int)d_; + } +}; + +struct Im2ColFastMods { + FastDivMod dm_col_w; + FastDivMod dm_oh_ow; + FastDivMod dm_ow; + FastDivMod dm_kk; + FastDivMod dm_k; + FastDivMod dm_oc; + FastDivMod dm_iw; + FastDivMod dm_ih; + FastDivMod dm_ic; + FastDivMod dm_s; + FastDivMod dm_n3_hw; + FastDivMod dm_n3_hwf; + FastDivMod dm_n3_w; + int total_no_batch; + int oh_ow; + int oc_spatial; + int col_cols; + int n3_hw; + int n3_hwf; + int n3_multihot_plane; + int IC, IH, IW, OC, K, S, OH, OW; + + __host__ Im2ColFastMods(int ic, int ih, int iw, int oc, int k, int s, int oh, int ow) + : dm_col_w(ic * k * k), dm_oh_ow(oh * ow), dm_ow(ow), dm_kk(k * k), dm_k(k), dm_oc(oc), + dm_iw(iw), dm_ih(ih), dm_ic(ic), dm_s(s), + dm_n3_hw(N3_MAP_H * N3_MAP_W), + dm_n3_hwf(N3_MAP_H * N3_MAP_W * N3_NFEAT), + dm_n3_w(N3_MAP_W), + total_no_batch((oh * ow) * (ic * k * k)), oh_ow(oh * ow), col_cols(ic * k * k), + oc_spatial(oc * oh * ow), n3_hw(N3_MAP_H * N3_MAP_W), n3_hwf(N3_MAP_H * N3_MAP_W * N3_NFEAT), + n3_multihot_plane(N3_MULTIHOT * N3_MAP_H * N3_MAP_W), + IC(ic), IH(ih), IW(iw), OC(oc), K(k), S(s), OH(oh), OW(ow) {} +}; + +static const Im2ColFastMods kIm2ColModsC1( + N3_C1_IC, N3_MAP_H, N3_MAP_W, N3_C1_OC, N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW); +static const Im2ColFastMods kIm2ColModsC2( + N3_C2_IC, N3_C1_OH, N3_C1_OW, N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW); + // ---- NMMO3 kernels ---- +// One thread per (b, f, h, w): dm_n3_hwf splits idx 
-> (b, rem_hwf), dm_n3_hw -> (f, rem_sp), dm_n3_w -> (h, w). +__global__ void n3_multihot_kernel_fast( + precision_t* __restrict__ out, const precision_t* __restrict__ obs, int B, int obs_size, + const FastDivMod dm_n3_hwf, const FastDivMod dm_n3_hw, const FastDivMod dm_n3_w, int n3_hw, int n3_hwf, + int n3_multihot_plane) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * n3_hwf) return; + int b, rem_hwf; + dm_n3_hwf.divmod(idx, b, rem_hwf); + int f, rem_sp; + dm_n3_hw.divmod(rem_hwf, f, rem_sp); + int h, w; + dm_n3_w.divmod(rem_sp, h, w); + const precision_t* src = obs + (int64_t)b * obs_size + (int64_t)(h * N3_MAP_W + w) * N3_NFEAT; + precision_t* dst = out + (int64_t)b * n3_multihot_plane; + dst[(N3_OFFSETS[f] + (int)to_float(src[f])) * n3_hw + h * N3_MAP_W + w] = from_float(1.0f); +} + __global__ void n3_multihot_kernel( precision_t* __restrict__ out, const precision_t* __restrict__ obs, int B, int obs_size) { int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -232,67 +322,7 @@ __global__ void im2col_kernel( col[idx] = input[b * IC * IH * IW + ic * IH * IW + ih * IW + iw]; } -struct FastDivMod { - uint32_t d_; - uint32_t M_; - uint32_t l_; - __host__ FastDivMod(int d) { - d_ = d <= 0 ? 
1u : (uint32_t)d; - uint32_t l = 0; - for (; l < 32; ++l) - if ((1u << l) >= d_) break; - l_ = l; - const uint64_t one = 1; - uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1; - M_ = (uint32_t)m; - } - - __device__ __forceinline__ int div(int n) const { - uint32_t u = (uint32_t)n; // n must be >= 0 - uint32_t t = __umulhi(M_, u); - return (int)((t + u) >> l_); - } - - __device__ __forceinline__ int mod(int n) const { - return n - div(n) * (int)d_; - } - - __device__ __forceinline__ void divmod(int n, int& q, int& r) const { - q = div(n); - r = n - q * (int)d_; - } -}; - -struct Im2ColFastMods { - FastDivMod dm_col_w; - FastDivMod dm_oh_ow; - FastDivMod dm_ow; - FastDivMod dm_kk; - FastDivMod dm_k; - FastDivMod dm_oc; - FastDivMod dm_iw; - FastDivMod dm_ih; - FastDivMod dm_ic; - FastDivMod dm_s; - int total_no_batch; - int oh_ow; - int oc_spatial; - int col_cols; - int IC, IH, IW, OC, K, S, OH, OW; - - __host__ Im2ColFastMods(int ic, int ih, int iw, int oc, int k, int s, int oh, int ow) - : dm_col_w(ic * k * k), dm_oh_ow(oh * ow), dm_ow(ow), dm_kk(k * k), dm_k(k), dm_oc(oc), - dm_iw(iw), dm_ih(ih), dm_ic(ic), dm_s(s), - total_no_batch((oh * ow) * (ic * k * k)), oh_ow(oh * ow), col_cols(ic * k * k), - oc_spatial(oc * oh * ow), - IC(ic), IH(ih), IW(iw), OC(oc), K(k), S(s), OH(oh), OW(ow) {} -}; - -static const Im2ColFastMods kIm2ColModsC1( - N3_C1_IC, N3_MAP_H, N3_MAP_W, N3_C1_OC, N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW); -static const Im2ColFastMods kIm2ColModsC2( - N3_C2_IC, N3_C1_OH, N3_C1_OW, N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW); __global__ void im2col_kernel_fast( const precision_t* __restrict__ input, precision_t* __restrict__ col, @@ -623,8 +653,9 @@ static PrecisionTensor nmmo3_encoder_forward(void* w, void* activations, Precisi if (a->saved_obs.data) puf_copy(&a->saved_obs, &input, stream); cudaMemsetAsync(a->multihot.data, 0, (int64_t)B * N3_MULTIHOT * N3_MAP_H * N3_MAP_W * sizeof(precision_t), stream); - n3_multihot_kernel<<>>( - 
a->multihot.data, input.data, B, ew->obs_size); + n3_multihot_kernel_fast<<>>(a->multihot.data, input.data, B, + ew->obs_size, kIm2ColModsC1.dm_n3_hwf, kIm2ColModsC1.dm_n3_hw, kIm2ColModsC1.dm_n3_w, kIm2ColModsC1.n3_hw, + kIm2ColModsC1.n3_hwf, kIm2ColModsC1.n3_multihot_plane); gemm_conv_forward_fast(&ew->conv1.w, &ew->conv1.b, a->multihot.data, a->conv1.out.data, a->col1.data, a->mm1.data, B, kIm2ColModsC1, true, stream); diff --git a/tests/bench_gemm_conv_end2end.cu b/tests/bench_gemm_conv_end2end.cu index d017c2f660..ad620885f9 100644 --- a/tests/bench_gemm_conv_end2end.cu +++ b/tests/bench_gemm_conv_end2end.cu @@ -1,5 +1,7 @@ // End-to-end: gemm conv (slow) vs gemm_fast vs cudnn — forward & backward timed separately, layers 1 & 2 (NMMO3). -// Build/run: tests/bench_gemm_conv_end2end.sh [--float|--bf16] +// Multihot microbench: n3_multihot_kernel (reference, not fast) vs n3_multihot_kernel_fast — correctness + timing, +// B in {1024..32768} powers of two (run_n3_multihot_bench_B). +// Build/run: tests/bench_gemm_conv_end2end.sh [--float|--bf16] [--conv-only|--multihot-only] #include @@ -305,16 +307,116 @@ static int run_layer(int layer, int warmup, int iters) { return 0; } +// Host copy of N3_OFFSETS (device __constant__ in ocean.cu) for valid multihot index ranges. +static const int kN3OffsetsHost[10] = {0, 4, 8, 25, 30, 33, 38, 43, 48, 55}; + +static int n3_feat_max_v(int f) { + int next = (f + 1 < 10) ? 
kN3OffsetsHost[f + 1] : N3_MULTIHOT; + return next - kN3OffsetsHost[f] - 1; +} + +static int run_n3_multihot_bench_B(int B, int warmup, int iters) { + const int obs_size = N3_MAP_SIZE + N3_PLAYER + N3_REWARD; + const int n3_hw = N3_MAP_H * N3_MAP_W; + const int multihot_elems = B * N3_MULTIHOT * n3_hw; + const Im2ColFastMods& m = kIm2ColModsC1; + + Allocator alloc{}; + PrecisionTensor d_obs{}, d_out_ref{}, d_out_fast{}; + d_obs = {.shape = {B, obs_size}}; + d_out_ref = {.shape = {(int64_t)multihot_elems}}; + d_out_fast = {.shape = {(int64_t)multihot_elems}}; + alloc_register(&alloc, &d_obs); + alloc_register(&alloc, &d_out_ref); + alloc_register(&alloc, &d_out_fast); + if (alloc_create(&alloc) != cudaSuccess) return 1; + + std::vector h_obs((size_t)B * (size_t)obs_size, 0.0f); + unsigned seed = 4242u; + for (int b = 0; b < B; ++b) { + for (int rem = 0; rem < n3_hw; ++rem) { + for (int f = 0; f < N3_NFEAT; ++f) { + int mv = n3_feat_max_v(f); + seed = seed * 1103515245u + 12345u; + int v = (int)((seed >> 16) % (unsigned)(mv + 1)); + h_obs[(size_t)b * (size_t)obs_size + (size_t)rem * (size_t)N3_NFEAT + (size_t)f] = (float)v; + } + } + } + copy_fp32_h2d(h_obs.data(), d_obs.data, B * obs_size); + + cudaStream_t stream = 0; + const float rel_eps = 1e-5f; + + auto launch_ref = [&](cudaStream_t s) { + n3_multihot_kernel<<>>(d_out_ref.data, d_obs.data, B, obs_size); + }; + auto launch_fast = [&](cudaStream_t s) { + n3_multihot_kernel_fast<<>>(d_out_fast.data, d_obs.data, B, obs_size, + m.dm_n3_hwf, m.dm_n3_hw, m.dm_n3_w, m.n3_hw, m.n3_hwf, m.n3_multihot_plane); + }; + + const size_t multihot_bytes = (size_t)multihot_elems * sizeof(precision_t); + cudaMemsetAsync(d_out_ref.data, 0, multihot_bytes, stream); + cudaMemsetAsync(d_out_fast.data, 0, multihot_bytes, stream); + cudaStreamSynchronize(stream); + launch_ref(stream); + cudaDeviceSynchronize(); + cudaMemsetAsync(d_out_fast.data, 0, multihot_bytes, stream); + cudaStreamSynchronize(stream); + launch_fast(stream); + 
cudaDeviceSynchronize(); + + std::vector h_ref, h_fast; + copy_precision_d2h(d_out_ref.data, multihot_elems, &h_ref); + copy_precision_d2h(d_out_fast.data, multihot_elems, &h_fast); + float mx, mn; + stats_diff(h_ref.data(), h_fast.data(), multihot_elems, &mx, &mn); + float rel = stats_rel_max(h_ref.data(), h_fast.data(), multihot_elems, rel_eps); + + printf("n3_multihot B=%d obs_size=%d multihot_elems=%d\n", B, obs_size, multihot_elems); + printf(" reference=n3_multihot_kernel fast=n3_multihot_kernel_fast (dm_n3_hwf+dm_n3_hw+dm_n3_w, n3_hwf=%d)\n", + m.n3_hwf); + printf(" correctness fast vs reference: max|diff| %.6g mean|diff| %.6g max rel err %.6g\n", mx, mn, rel); + printf(" timing (%d warmup / %d iters), kernel only (correctness above uses cudaMemsetAsync like encoder):\n", warmup, iters); + + float ms_ref = time_kernel_ms(stream, launch_ref, warmup, iters); + float ms_f = time_kernel_ms(stream, launch_fast, warmup, iters); + printf(" n3_multihot_kernel (ref): %8.4f us/iter\n", ms_ref * 1000.0f); + printf(" n3_multihot_kernel_fast: %8.4f us/iter (%.2fx vs ref)\n", ms_f * 1000.0f, ms_ref / ms_f); + printf("\n"); + + alloc_free(&alloc); + return 0; +} + +static int run_n3_multihot_bench(int warmup, int iters) { + for (int B = 1024; B <= 32768; B *= 2) { + if (run_n3_multihot_bench_B(B, warmup, iters)) return 1; + } + return 0; +} + int main(int argc, char** argv) { const int warmup = 50; const int iters = 200; + bool run_conv = true, run_multihot = true; for (int i = 1; i < argc; ++i) { if (strcmp(argv[i], "--float") == 0 || strcmp(argv[i], "--fp32") == 0 || strcmp(argv[i], "--bf16") == 0 || strcmp(argv[i], "--half") == 0) { continue; } + if (strcmp(argv[i], "--conv-only") == 0) { + run_multihot = false; + continue; + } + if (strcmp(argv[i], "--multihot-only") == 0) { + run_conv = false; + continue; + } if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { - printf("Usage: %s (precision: compile with tests/bench_gemm_conv_end2end.sh 
--float|--bf16)\n", argv[0]); + printf("Usage: %s [--conv-only] [--multihot-only]\n", argv[0]); + printf(" Precision: compile via tests/bench_gemm_conv_end2end.sh --float|--bf16\n"); return 0; } fprintf(stderr, "Unknown arg: %s\n", argv[i]); @@ -327,7 +429,12 @@ int main(int argc, char** argv) { printf("bench_gemm_conv_end2end precision=bf16 warmup=%d iters=%d\n\n", warmup, iters); #endif - if (run_layer(1, warmup, iters)) return 1; - if (run_layer(2, warmup, iters)) return 1; + if (run_conv) { + if (run_layer(1, warmup, iters)) return 1; + if (run_layer(2, warmup, iters)) return 1; + } + if (run_multihot) { + if (run_n3_multihot_bench(warmup, iters)) return 1; + } return 0; } diff --git a/tests/bench_gemm_conv_end2end.sh b/tests/bench_gemm_conv_end2end.sh index 4c3ba31442..1aa51a6b83 100755 --- a/tests/bench_gemm_conv_end2end.sh +++ b/tests/bench_gemm_conv_end2end.sh @@ -1,9 +1,14 @@ #!/usr/bin/env bash -# Build and run end-to-end conv benchmark: gemm vs gemm_fast vs cudnn (fwd+bwd), layers 1 & 2. -# Args: --float | --fp32 (default) or --bf16 | --half +# Build and run end-to-end conv benchmark: gemm vs gemm_fast vs cudnn (fwd+bwd), layers 1 & 2, +# plus multihot: n3_multihot_kernel (reference) vs n3_multihot_kernel_fast (Im2ColFastMods). +# Args: +# --float | --fp32 (default) or --bf16 | --half +# --conv-only skip multihot microbench +# --multihot-only only multihot (skip conv layers / cudnn) # # ./tests/bench_gemm_conv_end2end.sh # ./tests/bench_gemm_conv_end2end.sh --bf16 +# ./tests/bench_gemm_conv_end2end.sh --multihot-only set -euo pipefail ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" cd "$ROOT" @@ -13,12 +18,14 @@ NVCC="${NVCC:-$CUDA_HOME/bin/nvcc}" ARCH="${NVCC_ARCH:-native}" PRECISION_FLAG="-DPRECISION_FLOAT" +EXTRA_ARGS=() for arg in "$@"; do case "$arg" in --bf16|--half) PRECISION_FLAG="" ;; --float|--fp32) PRECISION_FLAG="-DPRECISION_FLOAT" ;; + --conv-only|--multihot-only) EXTRA_ARGS+=("$arg") ;; *) - echo "Unknown argument: $arg (use --float or --bf16)" >&2 + echo "Unknown argument: $arg (use --float, --bf16, --conv-only, or --multihot-only)" >&2 exit 1 ;; esac @@ -62,4 +69,4 @@ fi -L"${CUDA_HOME}/lib64" -L"${CUDA_HOME}/lib" \ -lcublas -lcudnn -lcurand -exec "$OUT" +exec "$OUT" "${EXTRA_ARGS[@]}" From 9a2382cae32fda1592e161e73ac5894e05b7465b Mon Sep 17 00:00:00 2001 From: jonah Date: Fri, 10 Apr 2026 08:00:01 -0700 Subject: [PATCH 8/9] embedding kernel --- src/ocean.cu | 28 +++++++- tests/bench_gemm_conv_end2end.cu | 106 +++++++++++++++++++++++++++++-- tests/bench_gemm_conv_end2end.sh | 11 ++-- 3 files changed, 132 insertions(+), 13 deletions(-) diff --git a/src/ocean.cu b/src/ocean.cu index 9578f66997..2387e09f29 100644 --- a/src/ocean.cu +++ b/src/ocean.cu @@ -95,6 +95,7 @@ static const Im2ColFastMods kIm2ColModsC1( N3_C1_IC, N3_MAP_H, N3_MAP_W, N3_C1_OC, N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW); static const Im2ColFastMods kIm2ColModsC2( N3_C2_IC, N3_C1_OH, N3_C1_OW, N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW); +static const FastDivMod kDmN3Player(N3_PLAYER); // ---- NMMO3 kernels ---- @@ -128,6 +129,29 @@ __global__ void n3_multihot_kernel( dst[(N3_OFFSETS[f] + (int)to_float(src[f])) * N3_MAP_H * N3_MAP_W + h * N3_MAP_W + w] = from_float(1.0f); } +__global__ void n3_embedding_kernel_fast( + precision_t* __restrict__ out, const precision_t* __restrict__ obs, + const precision_t* __restrict__ embed_w, int B, int obs_size, const FastDivMod dm_n3_player) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * N3_PLAYER) return; + int b, f; + dm_n3_player.divmod(idx, b, f); + int val = (int)to_float(obs[b * 
obs_size + N3_MAP_SIZE + f]); + const precision_t* src = embed_w + val * N3_EMBED_DIM; + precision_t* dst = out + b * N3_PLAYER_EMBED + f * N3_EMBED_DIM; +#ifdef PRECISION_FLOAT + const float4* src4 = reinterpret_cast(src); + float4* dst4 = reinterpret_cast(dst); +#pragma unroll + for (int i = 0; i < N3_EMBED_DIM / 4; i++) dst4[i] = src4[i]; +#else + const uint4* src4 = reinterpret_cast(src); + uint4* dst4 = reinterpret_cast(dst); +#pragma unroll + for (int i = 0; i < N3_EMBED_DIM / 8; i++) dst4[i] = src4[i]; +#endif +} + __global__ void n3_embedding_kernel( precision_t* __restrict__ out, const precision_t* __restrict__ obs, const precision_t* __restrict__ embed_w, int B, int obs_size) { @@ -668,8 +692,8 @@ static PrecisionTensor nmmo3_encoder_forward(void* w, void* activations, Precisi cudaMemcpyAsync(a->conv2.saved_input.data, a->conv1.out.data, (int64_t)B * N3_C2_IC * N3_C1_OH * N3_C1_OW * sizeof(precision_t), cudaMemcpyDeviceToDevice, stream); - n3_embedding_kernel<<>>( - a->embed_out.data, input.data, ew->embed_w.data, B, ew->obs_size); + n3_embedding_kernel_fast<<>>( + a->embed_out.data, input.data, ew->embed_w.data, B, ew->obs_size, kDmN3Player); n3_concat_kernel<<>>( a->concat.data, a->conv2.out.data, a->embed_out.data, input.data, B, ew->obs_size); diff --git a/tests/bench_gemm_conv_end2end.cu b/tests/bench_gemm_conv_end2end.cu index ad620885f9..3d3de1e6c1 100644 --- a/tests/bench_gemm_conv_end2end.cu +++ b/tests/bench_gemm_conv_end2end.cu @@ -1,7 +1,8 @@ // End-to-end: gemm conv (slow) vs gemm_fast vs cudnn — forward & backward timed separately, layers 1 & 2 (NMMO3). // Multihot microbench: n3_multihot_kernel (reference, not fast) vs n3_multihot_kernel_fast — correctness + timing, // B in {1024..32768} powers of two (run_n3_multihot_bench_B). -// Build/run: tests/bench_gemm_conv_end2end.sh [--float|--bf16] [--conv-only|--multihot-only] +// Embedding microbench: n3_embedding_kernel vs n3_embedding_kernel_fast — own CLI flag --embedding-only (same B grid). 
+// Build/run: tests/bench_gemm_conv_end2end.sh [--float|--bf16] [--multihot-only|--embedding-only] #include @@ -390,6 +391,82 @@ static int run_n3_multihot_bench_B(int B, int warmup, int iters) { return 0; } +static int run_n3_embedding_bench_B(int B, int warmup, int iters) { + const int obs_size = N3_MAP_SIZE + N3_PLAYER + N3_REWARD; + const int out_elems = B * N3_PLAYER_EMBED; + const int embed_elems = N3_EMBED_VOCAB * N3_EMBED_DIM; + + Allocator alloc{}; + PrecisionTensor d_obs{}, d_embed{}, d_out_ref{}, d_out_fast{}; + d_obs = {.shape = {B, obs_size}}; + d_embed = {.shape = {N3_EMBED_VOCAB, N3_EMBED_DIM}}; + d_out_ref = {.shape = {B, N3_PLAYER_EMBED}}; + d_out_fast = {.shape = {B, N3_PLAYER_EMBED}}; + alloc_register(&alloc, &d_obs); + alloc_register(&alloc, &d_embed); + alloc_register(&alloc, &d_out_ref); + alloc_register(&alloc, &d_out_fast); + if (alloc_create(&alloc) != cudaSuccess) return 1; + + std::vector h_obs((size_t)B * (size_t)obs_size, 0.0f); + std::vector h_embed((size_t)embed_elems); + fill_rand_host(h_embed.data(), embed_elems, 9191u); + unsigned seed = 5150u; + for (int b = 0; b < B; ++b) { + for (int f = 0; f < N3_PLAYER; ++f) { + seed = seed * 1103515245u + 12345u; + int v = (int)((seed >> 16) % (unsigned)N3_EMBED_VOCAB); + h_obs[(size_t)b * (size_t)obs_size + (size_t)N3_MAP_SIZE + (size_t)f] = (float)v; + } + } + copy_fp32_h2d(h_obs.data(), d_obs.data, B * obs_size); + copy_fp32_h2d(h_embed.data(), d_embed.data, embed_elems); + + cudaStream_t stream = 0; + const float rel_eps = 1e-5f; + + auto launch_ref = [&](cudaStream_t s) { + n3_embedding_kernel<<>>( + d_out_ref.data, d_obs.data, d_embed.data, B, obs_size); + }; + auto launch_fast = [&](cudaStream_t s) { + n3_embedding_kernel_fast<<>>( + d_out_fast.data, d_obs.data, d_embed.data, B, obs_size, kDmN3Player); + }; + + const size_t out_bytes = (size_t)out_elems * sizeof(precision_t); + cudaMemsetAsync(d_out_ref.data, 0, out_bytes, stream); + cudaMemsetAsync(d_out_fast.data, 0, out_bytes, 
stream); + cudaStreamSynchronize(stream); + launch_ref(stream); + cudaDeviceSynchronize(); + cudaMemsetAsync(d_out_fast.data, 0, out_bytes, stream); + cudaStreamSynchronize(stream); + launch_fast(stream); + cudaDeviceSynchronize(); + + std::vector h_ref, h_fast; + copy_precision_d2h(d_out_ref.data, out_elems, &h_ref); + copy_precision_d2h(d_out_fast.data, out_elems, &h_fast); + float mx, mn; + stats_diff(h_ref.data(), h_fast.data(), out_elems, &mx, &mn); + float rel = stats_rel_max(h_ref.data(), h_fast.data(), out_elems, rel_eps); + + printf("n3_embedding B=%d obs_size=%d out_elems=%d embed_elems=%d\n", B, obs_size, out_elems, embed_elems); + printf(" reference=n3_embedding_kernel fast=n3_embedding_kernel_fast (kDmN3Player + vec copy)\n"); + printf(" correctness fast vs reference: max|diff| %.6g mean|diff| %.6g max rel err %.6g\n", mx, mn, rel); + printf(" timing (%d warmup / %d iters), kernel only:\n", warmup, iters); + + float ms_ref = time_kernel_ms(stream, launch_ref, warmup, iters); + float ms_f = time_kernel_ms(stream, launch_fast, warmup, iters); + printf(" n3_embedding_kernel (ref): %8.4f us/iter\n", ms_ref * 1000.0f); + printf(" n3_embedding_kernel_fast: %8.4f us/iter (%.2fx vs ref)\n", ms_f * 1000.0f, ms_ref / ms_f); + printf("\n"); + + alloc_free(&alloc); + return 0; +} + static int run_n3_multihot_bench(int warmup, int iters) { for (int B = 1024; B <= 32768; B *= 2) { if (run_n3_multihot_bench_B(B, warmup, iters)) return 1; @@ -397,26 +474,40 @@ static int run_n3_multihot_bench(int warmup, int iters) { return 0; } +static int run_n3_embedding_bench(int warmup, int iters) { + for (int B = 1024; B <= 32768; B *= 2) { + if (run_n3_embedding_bench_B(B, warmup, iters)) return 1; + } + return 0; +} + int main(int argc, char** argv) { const int warmup = 50; const int iters = 200; - bool run_conv = true, run_multihot = true; + bool run_conv = true, run_multihot = true, run_embedding = true; for (int i = 1; i < argc; ++i) { if (strcmp(argv[i], "--float") == 0 || 
strcmp(argv[i], "--fp32") == 0 || strcmp(argv[i], "--bf16") == 0 || strcmp(argv[i], "--half") == 0) { continue; } - if (strcmp(argv[i], "--conv-only") == 0) { - run_multihot = false; + if (strcmp(argv[i], "--multihot-only") == 0) { + run_conv = false; + run_multihot = true; + run_embedding = false; continue; } - if (strcmp(argv[i], "--multihot-only") == 0) { + if (strcmp(argv[i], "--embedding-only") == 0) { run_conv = false; + run_multihot = false; + run_embedding = true; continue; } if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { - printf("Usage: %s [--conv-only] [--multihot-only]\n", argv[0]); + printf("Usage: %s [--multihot-only] [--embedding-only]\n", argv[0]); printf(" Precision: compile via tests/bench_gemm_conv_end2end.sh --float|--bf16\n"); + printf(" Default: conv layers + n3_multihot ref vs fast + n3_embedding ref vs fast (B=1024..32768).\n"); + printf(" --multihot-only multihot ref vs fast only\n"); + printf(" --embedding-only embedding ref vs fast only\n"); return 0; } fprintf(stderr, "Unknown arg: %s\n", argv[i]); @@ -436,5 +527,8 @@ int main(int argc, char** argv) { if (run_multihot) { if (run_n3_multihot_bench(warmup, iters)) return 1; } + if (run_embedding) { + if (run_n3_embedding_bench(warmup, iters)) return 1; + } return 0; } diff --git a/tests/bench_gemm_conv_end2end.sh b/tests/bench_gemm_conv_end2end.sh index 1aa51a6b83..431f13ca22 100755 --- a/tests/bench_gemm_conv_end2end.sh +++ b/tests/bench_gemm_conv_end2end.sh @@ -1,14 +1,15 @@ #!/usr/bin/env bash # Build and run end-to-end conv benchmark: gemm vs gemm_fast vs cudnn (fwd+bwd), layers 1 & 2, -# plus multihot: n3_multihot_kernel (reference) vs n3_multihot_kernel_fast (Im2ColFastMods). +# plus NMMO3 microbenches (each optional via flags; default runs conv + multihot + embedding). 
# Args: # --float | --fp32 (default) or --bf16 | --half -# --conv-only skip multihot microbench -# --multihot-only only multihot (skip conv layers / cudnn) +# --multihot-only n3_multihot ref vs fast only (B=1024..32768) +# --embedding-only n3_embedding ref vs fast only (same B grid) # # ./tests/bench_gemm_conv_end2end.sh # ./tests/bench_gemm_conv_end2end.sh --bf16 # ./tests/bench_gemm_conv_end2end.sh --multihot-only +# ./tests/bench_gemm_conv_end2end.sh --embedding-only set -euo pipefail ROOT="$(cd "$(dirname "$0")/.." && pwd)" cd "$ROOT" @@ -23,9 +24,9 @@ for arg in "$@"; do case "$arg" in --bf16|--half) PRECISION_FLAG="" ;; --float|--fp32) PRECISION_FLAG="-DPRECISION_FLOAT" ;; - --conv-only|--multihot-only) EXTRA_ARGS+=("$arg") ;; + --multihot-only|--embedding-only) EXTRA_ARGS+=("$arg") ;; *) - echo "Unknown argument: $arg (use --float, --bf16, --conv-only, or --multihot-only)" >&2 + echo "Unknown argument: $arg (use --float, --bf16, --multihot-only, or --embedding-only)" >&2 exit 1 ;; esac From 4be0af095229fa51bec33db776fc92b98dd3d24f Mon Sep 17 00:00:00 2001 From: jonah Date: Sun, 12 Apr 2026 07:51:04 -0700 Subject: [PATCH 9/9] n3_conv_bias_grad_nchw --- src/ocean.cu | 90 +++++++++++++++++++++++++++++--- tests/bench_gemm_conv_end2end.cu | 89 +++++++++++++++++++++++++++++-- tests/bench_gemm_conv_end2end.sh | 10 ++-- 3 files changed, 173 insertions(+), 16 deletions(-) diff --git a/src/ocean.cu b/src/ocean.cu index 2387e09f29..8ce6eb57da 100644 --- a/src/ocean.cu +++ b/src/ocean.cu @@ -2,6 +2,7 @@ // Included by pufferlib.cu — requires precision_t, PrecisionTensor, Allocator, puf_mm, etc. 
#include "cudnn_conv2d.cu" +#include "kernels.cu" // ---- NMMO3 constants ---- @@ -96,6 +97,8 @@ static const Im2ColFastMods kIm2ColModsC1( static const Im2ColFastMods kIm2ColModsC2( N3_C2_IC, N3_C1_OH, N3_C1_OW, N3_C2_OC, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW); static const FastDivMod kDmN3Player(N3_PLAYER); +static const FastDivMod kDmConvBiasSpatialC1(N3_C1_OH * N3_C1_OW); +static const FastDivMod kDmConvBiasSpatialC2(N3_C2_OH * N3_C2_OW); // ---- NMMO3 kernels ---- @@ -221,7 +224,82 @@ __global__ void bias_grad_kernel( } } -// NCHW bias grad: sum over (B, OH, OW) for each OC channel +// NCHW bias grad: sum over (B, OH, OW) for each OC channel (dm_spatial.d_ must equal spatial). +// NMMO3 spatial 12 / 2: stripe over b, vectorize along contiguous s; else flat-i + FastDivMod. +__global__ void n3_conv_bias_grad_nchw_fast( + precision_t* __restrict__ bgrad, const precision_t* __restrict__ grad, + int B, int OC, const FastDivMod dm_spatial) { + int oc = blockIdx.x; + if (oc >= OC) return; + const int spatial = (int)dm_spatial.d_; + const int sp_c1 = N3_C1_OH * N3_C1_OW; + const int sp_c2 = N3_C2_OH * N3_C2_OW; + float sum = 0.0f; + + if (spatial == sp_c1) { +#ifdef PRECISION_FLOAT + for (int b = threadIdx.x; b < B; b += blockDim.x) { + const float* row = grad + ((int64_t)b * OC + oc) * spatial; + float4 a0 = *reinterpret_cast(row); + float4 a1 = *reinterpret_cast(row + 4); + float4 a2 = *reinterpret_cast(row + 8); + sum += a0.x + a0.y + a0.z + a0.w + a1.x + a1.y + a1.z + a1.w + a2.x + a2.y + a2.z + a2.w; + } +#else + for (int b = threadIdx.x; b < B; b += blockDim.x) { + const __nv_bfloat16* row = grad + ((int64_t)b * OC + oc) * spatial; + const uint64_t* p = reinterpret_cast(row); + #pragma unroll + for (int j = 0; j < 3; ++j) { + union { + uint64_t u; + __nv_bfloat16 h[4]; + } w; + w.u = p[j]; + sum += to_float(w.h[0]) + to_float(w.h[1]) + to_float(w.h[2]) + to_float(w.h[3]); + } + } +#endif + } else if (spatial == sp_c2) { +#ifdef PRECISION_FLOAT + for (int b = 
threadIdx.x; b < B; b += blockDim.x) { + const float* row = grad + ((int64_t)b * OC + oc) * spatial; + float2 v = *reinterpret_cast(row); + sum += v.x + v.y; + } +#else + for (int b = threadIdx.x; b < B; b += blockDim.x) { + const __nv_bfloat16* row = grad + ((int64_t)b * OC + oc) * spatial; + union { + uint32_t u; + __nv_bfloat16 h[2]; + } w; + w.u = *reinterpret_cast(row); + sum += to_float(w.h[0]) + to_float(w.h[1]); + } +#endif + } else { + int total = B * spatial; + for (int i = threadIdx.x; i < total; i += blockDim.x) { + int bb, s; + dm_spatial.divmod(i, bb, s); + sum += to_float(grad[(int64_t)bb * OC * spatial + oc * spatial + s]); + } + } + for (int offset = 16; offset > 0; offset >>= 1) + sum += __shfl_down_sync(0xffffffff, sum, offset); + __shared__ float sdata[32]; + int lane = threadIdx.x % 32, warp = threadIdx.x / 32; + if (lane == 0) sdata[warp] = sum; + __syncthreads(); + if (warp == 0) { + sum = (lane < (blockDim.x + 31) / 32) ? sdata[lane] : 0.0f; + for (int offset = 16; offset > 0; offset >>= 1) + sum += __shfl_down_sync(0xffffffff, sum, offset); + if (lane == 0) bgrad[oc] = from_float(sum); + } +} + __global__ void n3_conv_bias_grad_nchw( precision_t* __restrict__ bgrad, const precision_t* __restrict__ grad, int B, int OC, int spatial) { @@ -720,9 +798,8 @@ static void nmmo3_encoder_backward(void* w, void* activations, PrecisionTensor g n3_concat_backward_conv_kernel<<>>( a->conv2.grad.data, grad_concat.data, B); - n3_conv_bias_grad_nchw<<conv2.OC, 256, 0, stream>>>( - a->conv2.bgrad.data, a->conv2.grad.data, - B, ew->conv2.OC, ew->conv2.OH * ew->conv2.OW); + n3_conv_bias_grad_nchw_fast<<conv2.OC, 256, 0, stream>>>( + a->conv2.bgrad.data, a->conv2.grad.data, B, ew->conv2.OC, kDmConvBiasSpatialC2); gemm_conv_backward_fast(&ew->conv2.w, a->conv2.saved_input.data, a->conv2.grad.data, a->conv2.wgrad.data, a->conv1.grad.data, a->col2.data, a->mm2.data, B, kIm2ColModsC2, stream); @@ -730,9 +807,8 @@ static void nmmo3_encoder_backward(void* w, void* 
activations, PrecisionTensor g n3_relu_backward_kernel<<conv1.OC * ew->conv1.OH * ew->conv1.OW), BLOCK_SIZE, 0, stream>>>( a->conv1.grad.data, a->conv1.out.data, B * ew->conv1.OC * ew->conv1.OH * ew->conv1.OW); - n3_conv_bias_grad_nchw<<conv1.OC, 256, 0, stream>>>( - a->conv1.bgrad.data, a->conv1.grad.data, - B, ew->conv1.OC, ew->conv1.OH * ew->conv1.OW); + n3_conv_bias_grad_nchw_fast<<conv1.OC, 256, 0, stream>>>( + a->conv1.bgrad.data, a->conv1.grad.data, B, ew->conv1.OC, kDmConvBiasSpatialC1); gemm_conv_backward_fast(&ew->conv1.w, a->conv1.saved_input.data, a->conv1.grad.data, a->conv1.wgrad.data, NULL, a->col1.data, a->mm1.data, B, kIm2ColModsC1, stream); diff --git a/tests/bench_gemm_conv_end2end.cu b/tests/bench_gemm_conv_end2end.cu index 3d3de1e6c1..564759cd5d 100644 --- a/tests/bench_gemm_conv_end2end.cu +++ b/tests/bench_gemm_conv_end2end.cu @@ -2,7 +2,8 @@ // Multihot microbench: n3_multihot_kernel (reference, not fast) vs n3_multihot_kernel_fast — correctness + timing, // B in {1024..32768} powers of two (run_n3_multihot_bench_B). // Embedding microbench: n3_embedding_kernel vs n3_embedding_kernel_fast — own CLI flag --embedding-only (same B grid). -// Build/run: tests/bench_gemm_conv_end2end.sh [--float|--bf16] [--multihot-only|--embedding-only] +// Conv bias grad: n3_conv_bias_grad_nchw vs n3_conv_bias_grad_nchw_fast (FastDivMod) — --conv-bias-grad-only. 
+// Build/run: tests/bench_gemm_conv_end2end.sh [--float|--bf16] [--multihot-only|--embedding-only|--conv-bias-grad-only] #include @@ -481,10 +482,75 @@ static int run_n3_embedding_bench(int warmup, int iters) { return 0; } +static int run_n3_conv_bias_grad_bench_B( + int B, int OC, int spatial, const FastDivMod dm_spatial, const char* tag, int warmup, int iters) { + const int grad_elems = B * OC * spatial; + cudaStream_t stream = 0; + const float rel_eps = 1e-5f; + + Allocator alloc{}; + PrecisionTensor d_grad{}, d_bref{}, d_bfast{}; + d_grad = {.shape = {(int64_t)grad_elems}}; + d_bref = {.shape = {OC}}; + d_bfast = {.shape = {OC}}; + alloc_register(&alloc, &d_grad); + alloc_register(&alloc, &d_bref); + alloc_register(&alloc, &d_bfast); + if (alloc_create(&alloc) != cudaSuccess) return 1; + + std::vector h_grad((size_t)grad_elems); + fill_rand_host(h_grad.data(), grad_elems, 77077u + (unsigned)B + (unsigned)spatial); + copy_fp32_h2d(h_grad.data(), d_grad.data, grad_elems); + + cudaMemsetAsync(d_bref.data, 0, (size_t)OC * sizeof(precision_t), stream); + n3_conv_bias_grad_nchw<<>>(d_bref.data, d_grad.data, B, OC, spatial); + cudaDeviceSynchronize(); + + cudaMemsetAsync(d_bfast.data, 0, (size_t)OC * sizeof(precision_t), stream); + n3_conv_bias_grad_nchw_fast<<>>(d_bfast.data, d_grad.data, B, OC, dm_spatial); + cudaDeviceSynchronize(); + + std::vector href, hfast; + copy_precision_d2h(d_bref.data, OC, &href); + copy_precision_d2h(d_bfast.data, OC, &hfast); + float mx, mn; + stats_diff(href.data(), hfast.data(), OC, &mx, &mn); + float rel = stats_rel_max(href.data(), hfast.data(), OC, rel_eps); + + printf("n3_conv_bias_grad %s B=%d OC=%d spatial=%d grad_elems=%d\n", tag, B, OC, spatial, grad_elems); + printf(" reference=n3_conv_bias_grad_nchw fast=n3_conv_bias_grad_nchw_fast (FastDivMod)\n"); + printf(" correctness fast vs reference: max|diff| %.6g mean|diff| %.6g max rel err %.6g\n", mx, mn, rel); + printf(" timing (%d warmup / %d iters), kernel only:\n", warmup, 
iters); + + auto launch_ref = [&](cudaStream_t s) { + n3_conv_bias_grad_nchw<<>>(d_bref.data, d_grad.data, B, OC, spatial); + }; + auto launch_fast = [&](cudaStream_t s) { + n3_conv_bias_grad_nchw_fast<<>>(d_bfast.data, d_grad.data, B, OC, dm_spatial); + }; + + float ms_ref = time_kernel_ms(stream, launch_ref, warmup, iters); + float ms_f = time_kernel_ms(stream, launch_fast, warmup, iters); + printf(" n3_conv_bias_grad_nchw (ref): %8.4f us/iter\n", ms_ref * 1000.0f); + printf(" n3_conv_bias_grad_nchw_fast: %8.4f us/iter (%.2fx vs ref)\n", ms_f * 1000.0f, ms_ref / ms_f); + printf("\n"); + + alloc_free(&alloc); + return 0; +} + +static int run_n3_conv_bias_grad_bench(int warmup, int iters) { + for (int B = 1024; B <= 32768; B *= 2) { + if (run_n3_conv_bias_grad_bench_B(B, N3_C2_OC, N3_C2_OH * N3_C2_OW, kDmConvBiasSpatialC2, "conv2", warmup, iters)) return 1; + if (run_n3_conv_bias_grad_bench_B(B, N3_C1_OC, N3_C1_OH * N3_C1_OW, kDmConvBiasSpatialC1, "conv1", warmup, iters)) return 1; + } + return 0; +} + int main(int argc, char** argv) { const int warmup = 50; const int iters = 200; - bool run_conv = true, run_multihot = true, run_embedding = true; + bool run_conv = true, run_multihot = true, run_embedding = true, run_conv_bias_grad = false; for (int i = 1; i < argc; ++i) { if (strcmp(argv[i], "--float") == 0 || strcmp(argv[i], "--fp32") == 0 || strcmp(argv[i], "--bf16") == 0 || strcmp(argv[i], "--half") == 0) { @@ -494,20 +560,30 @@ int main(int argc, char** argv) { run_conv = false; run_multihot = true; run_embedding = false; + run_conv_bias_grad = false; continue; } if (strcmp(argv[i], "--embedding-only") == 0) { run_conv = false; run_multihot = false; run_embedding = true; + run_conv_bias_grad = false; + continue; + } + if (strcmp(argv[i], "--conv-bias-grad-only") == 0) { + run_conv = false; + run_multihot = false; + run_embedding = false; + run_conv_bias_grad = true; continue; } if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { - printf("Usage: 
%s [--multihot-only] [--embedding-only]\n", argv[0]); + printf("Usage: %s [--multihot-only] [--embedding-only] [--conv-bias-grad-only]\n", argv[0]); printf(" Precision: compile via tests/bench_gemm_conv_end2end.sh --float|--bf16\n"); printf(" Default: conv layers + n3_multihot ref vs fast + n3_embedding ref vs fast (B=1024..32768).\n"); - printf(" --multihot-only multihot ref vs fast only\n"); - printf(" --embedding-only embedding ref vs fast only\n"); + printf(" --multihot-only multihot ref vs fast only\n"); + printf(" --embedding-only embedding ref vs fast only\n"); + printf(" --conv-bias-grad-only n3_conv_bias_grad_nchw vs _fast (conv1+conv2 shapes)\n"); return 0; } fprintf(stderr, "Unknown arg: %s\n", argv[i]); @@ -530,5 +606,8 @@ int main(int argc, char** argv) { if (run_embedding) { if (run_n3_embedding_bench(warmup, iters)) return 1; } + if (run_conv_bias_grad) { + if (run_n3_conv_bias_grad_bench(warmup, iters)) return 1; + } return 0; } diff --git a/tests/bench_gemm_conv_end2end.sh b/tests/bench_gemm_conv_end2end.sh index 431f13ca22..1eec994315 100755 --- a/tests/bench_gemm_conv_end2end.sh +++ b/tests/bench_gemm_conv_end2end.sh @@ -3,13 +3,15 @@ # plus NMMO3 microbenches (each optional via flags; default runs conv + multihot + embedding). # Args: # --float | --fp32 (default) or --bf16 | --half -# --multihot-only n3_multihot ref vs fast only (B=1024..32768) -# --embedding-only n3_embedding ref vs fast only (same B grid) +# --multihot-only n3_multihot ref vs fast only (B=1024..32768) +# --embedding-only n3_embedding ref vs fast only (same B grid) +# --conv-bias-grad-only n3_conv_bias_grad_nchw vs _fast, conv1+conv2 NMMO3 shapes # # ./tests/bench_gemm_conv_end2end.sh # ./tests/bench_gemm_conv_end2end.sh --bf16 # ./tests/bench_gemm_conv_end2end.sh --multihot-only # ./tests/bench_gemm_conv_end2end.sh --embedding-only +# ./tests/bench_gemm_conv_end2end.sh --conv-bias-grad-only set -euo pipefail ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" cd "$ROOT" @@ -24,9 +26,9 @@ for arg in "$@"; do case "$arg" in --bf16|--half) PRECISION_FLAG="" ;; --float|--fp32) PRECISION_FLAG="-DPRECISION_FLOAT" ;; - --multihot-only|--embedding-only) EXTRA_ARGS+=("$arg") ;; + --multihot-only|--embedding-only|--conv-bias-grad-only) EXTRA_ARGS+=("$arg") ;; *) - echo "Unknown argument: $arg (use --float, --bf16, --multihot-only, or --embedding-only)" >&2 + echo "Unknown argument: $arg (use --float, --bf16, --multihot-only, --embedding-only, or --conv-bias-grad-only)" >&2 exit 1 ;; esac