diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 9fb6e4dd..016f8a08 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -591,7 +591,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( if (att_cap > 0.0f) { VF4 cap = hn::Set(df4, att_cap); VF4 one_over_cap = hn::Set(df4, one_over_att_cap); - new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap))); + new_max = hn::Mul(cap, hn::FastTanh(df4, hn::Mul(new_max, one_over_cap))); } VF4 local_max = new_max; VF4 old_max_vf = hn::Set(df4, kNegInf); @@ -733,7 +733,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( if (att_cap > 0.0f) { VF8 cap = hn::Set(df8, att_cap); VF8 one_over_cap = hn::Set(df8, one_over_att_cap); - new_max = hn::Mul(cap, hn::Tanh(df8, hn::Mul(new_max, one_over_cap))); + new_max = hn::Mul(cap, hn::FastTanh(df8, hn::Mul(new_max, one_over_cap))); } VF8 local_max = new_max; VF8 old_max_vf = hn::Set(df8, kNegInf); @@ -1309,27 +1309,27 @@ static HWY_INLINE void ApplySoftCap(DF df, float att_cap, float one_over_cap, if (att_cap > 0.0f) { VF cap = hn::Set(df, att_cap); VF one_over_cap_vec = hn::Set(df, one_over_cap); - x0 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x0, one_over_cap_vec))); + x0 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x0, one_over_cap_vec))); if constexpr (kVTileSize >= 2) { - x1 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x1, one_over_cap_vec))); + x1 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x1, one_over_cap_vec))); } if constexpr (kVTileSize >= 3) { - x2 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x2, one_over_cap_vec))); + x2 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x2, one_over_cap_vec))); } if constexpr (kVTileSize >= 4) { - x3 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x3, one_over_cap_vec))); + x3 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x3, one_over_cap_vec))); } if constexpr (kVTileSize >= 5) { - x4 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x4, one_over_cap_vec))); + x4 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x4, one_over_cap_vec))); } if constexpr (kVTileSize >= 6) { - x5 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x5, one_over_cap_vec))); + x5 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x5, one_over_cap_vec))); } if constexpr (kVTileSize >= 7) { - x6 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x6, one_over_cap_vec))); + x6 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x6, one_over_cap_vec))); } if constexpr (kVTileSize >= 8) { - x7 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x7, one_over_cap_vec))); + x7 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x7, one_over_cap_vec))); } } }