diff --git a/ds4.c b/ds4.c index ee069b76a..f30021bca 100644 --- a/ds4.c +++ b/ds4.c @@ -6911,10 +6911,10 @@ static void layer_grouped_out_one( const ds4_model * model, const ds4_layer_weights * layer, const float * heads) { - const uint32_t n_groups = 8; + const uint32_t n_groups = DS4_N_OUT_GROUP; const uint32_t group_heads = DS4_N_HEAD / n_groups; const uint32_t group_dim = DS4_N_HEAD_DIM * group_heads; - const uint32_t rank = 1024; + const uint32_t rank = DS4_N_LORA_O; float *low = xcalloc((size_t)n_groups * rank, sizeof(low[0])); @@ -6930,10 +6930,10 @@ static void layer_grouped_out_one_decode_scratch( const ds4_layer_weights * layer, const float * heads, ds4_cpu_decode_scratch * scratch) { - const uint32_t n_groups = 8; + const uint32_t n_groups = DS4_N_OUT_GROUP; const uint32_t group_heads = DS4_N_HEAD / n_groups; const uint32_t group_dim = DS4_N_HEAD_DIM * group_heads; - const uint32_t rank = 1024; + const uint32_t rank = DS4_N_LORA_O; memset(scratch->attn_low, 0, (size_t)n_groups * rank * sizeof(scratch->attn_low[0])); matvec_q8_0_grouped_rows_decode_scratch(scratch->attn_low, model, layer->attn_output_a, @@ -6947,10 +6947,10 @@ static void layer_grouped_out_batch( const ds4_layer_weights * layer, const float * heads, uint32_t n_tok) { - const uint32_t n_groups = 8; + const uint32_t n_groups = DS4_N_OUT_GROUP; const uint32_t group_heads = DS4_N_HEAD / n_groups; const uint32_t group_dim = DS4_N_HEAD_DIM * group_heads; - const uint32_t rank = 1024; + const uint32_t rank = DS4_N_LORA_O; float *low = xcalloc((size_t)n_tok * n_groups * rank, sizeof(low[0]));