Add soft-capping to Gemma2
This is a cherry-pick of ggerganov/llama.cpp#8197
abetlen authored and jart committed Jul 1, 2024
1 parent 263d39b commit 140eed5
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions llama.cpp/llama.cpp
@@ -244,6 +244,8 @@ enum llm_kv {
LLM_KV_EXPERT_USED_COUNT,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
+ LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+ LLM_KV_FINAL_LOGIT_SOFTCAPPING,

LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -322,6 +324,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
+ { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
+ { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
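
The "%s" in these name patterns is substituted with the architecture string when a key is looked up, so for Gemma 2 the loader ends up reading GGUF metadata of roughly this form (the values shown are the in-code defaults added further down, not necessarily what a particular conversion wrote):

    gemma2.attn_logit_softcapping  = 50.0
    gemma2.final_logit_softcapping = 30.0
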
@@ -1525,6 +1529,9 @@ struct llama_hparams {
float f_norm_eps;
float f_norm_rms_eps;

+ float f_attn_logit_softcapping = 50.0f;
+ float f_final_logit_softcapping = 30.0f;

float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
@@ -1540,8 +1547,9 @@ struct llama_hparams {
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;

- bool causal_attn = true;
- bool use_alibi = false;
+ bool causal_attn   = true;
+ bool use_alibi     = false;
+ bool attn_soft_cap = false;

enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@@ -3995,6 +4003,9 @@ static void llm_load_hparams(
case LLM_ARCH_GEMMA2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+ ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+ hparams.attn_soft_cap = true;

switch (hparams.n_layer) {
case 18: model.type = e_model::MODEL_9B; break;
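
A note on the loader change above: the third argument (false) passed to ml.get_key marks both keys as optional, so Gemma 2 GGUFs converted before these metadata keys existed still load and simply keep the 50.0f / 30.0f defaults declared in llama_hparams. Setting hparams.attn_soft_cap = true here is what later switches on the tanh capping in llm_build_kqv and the flash-attention fallback at context creation, both further down in this diff.
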
@@ -6511,6 +6522,12 @@ static struct ggml_tensor * llm_build_kqv(
kq = ggml_scale(ctx, kq, 30);
}

+ if (hparams.attn_soft_cap) {
+     kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+     kq = ggml_tanh(ctx, kq);
+     kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+ }

kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
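
Taken together, the three new ggml ops apply Gemma 2's logit soft-capping: each attention score x becomes f_attn_logit_softcapping * tanh(x / f_attn_logit_softcapping), which leaves scores well below the cap almost untouched and smoothly saturates everything else into (-cap, +cap) before the mask and soft-max run. A scalar reference of the same transformation, as a standalone sketch (the helper name softcap and the sample values are only for illustration; in the library the computation stays inside the ggml graph built here):

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    // softcap(x) = cap * tanh(x / cap): near-identity for |x| << cap, bounded by +/- cap
    static float softcap(float x, float cap) {
        return cap * std::tanh(x / cap);
    }

    int main() {
        const float cap = 50.0f;  // default f_attn_logit_softcapping for Gemma 2
        for (float x : {1.0f, 25.0f, 75.0f, 200.0f}) {
            std::printf("softcap(%6.1f) = %7.3f\n", x, softcap(x, cap));
        }
        return 0;
    }
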

@@ -9920,7 +9937,7 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);

- Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+ Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
cb(Qcur, "Qcur_scaled", il);

Kcur = ggml_rope_ext(
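
The one-line change above swaps the usual 1/sqrt(head_dim) query scaling for 1/sqrt(n_embd / n_head), which matters for Gemma 2 because its head dimension is not n_embd / n_head. As a rough check, assuming the published Gemma 2 27B shape (n_embd = 4608, n_head = 32, head dim 128):

    // new scale: 1.0f / sqrtf(4608.0f / 32.0f) = 1.0f / sqrtf(144.0f) = 1.0f / 12.0f   (~0.0833)
    // old scale: 1.0f / sqrtf(128.0f)                                                  (~0.0884)

which lines up with the query scaling factor of 144 (query_pre_attn_scalar) in the released 27B configuration.
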
@@ -9988,6 +10005,11 @@ struct llm_build_context {
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);

+ // final logit soft-capping
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+ cur = ggml_tanh(ctx0, cur);
+ cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);

ggml_build_forward_expand(gf, cur);

return gf;
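
The output logits get the same shape of cap with f_final_logit_softcapping (30.0f by default). For a feel of the numbers:

    // with the default cap of 30.0f:
    // 30.0f * tanhf( 10.0f / 30.0f)  is about  9.65f   (small logits pass through nearly unchanged)
    // 30.0f * tanhf(100.0f / 30.0f)  is about 29.92f   (nothing escapes the (-30, +30) band)

Because tanh is monotonic, the greedy argmax token is unchanged, but large logits are compressed, so the softmax distribution that temperature or top-p sampling sees becomes somewhat flatter.
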
@@ -15398,6 +15420,11 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}

+ if (params.flash_attn && model->hparams.attn_soft_cap) {
+     LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+     params.flash_attn = false;
+ }

llama_context * ctx = new llama_context(*model);

const auto & hparams = model->hparams;
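
The guard above mirrors the flash-attention fallbacks just before it: the fused flash-attention path has no hook for the extra tanh step, so requesting it with a Gemma 2 model is simply turned back off. Assuming the usual -fa / --flash-attn switch and an illustrative model file name, a user would see something like:

    ./llamafile -m gemma-2-9b-it.Q4_K_M.gguf -fa -p "..."
    llama_new_context_with_model: flash_attn is not compatible with attn_soft_cap - forcing off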