diff --git a/BUILD.bazel b/BUILD.bazel index deb376bf..7edb6fa8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -144,6 +144,7 @@ cc_test( ":kv_transcoding", ":mat", ":matmul", + ":ops", ":test_util", ":threading_context", ":weights", diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index a780aa41..6a4f893a 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -129,19 +129,16 @@ std::vector GenerateInputs() { TEST_F(GemmaBatchBench, RandomQuestionsBatched) { s_env->SetMaxGeneratedTokens(12); const std::vector inputs = GenerateInputs(); - const AttentionImpl modes[] = {AttentionImpl::kOld, AttentionImpl::kFlash}; - for (const AttentionImpl mode : modes) { - // Run multiple times so that auto-tuning is closer to complete. - fprintf(stderr, "Testing mode %s\n", GetAttentionImplName(mode).c_str()); - for (size_t rep = 0; rep < 4; ++rep) { - std::vector responses = BatchGemmaReply(inputs, mode); - for (size_t i = 0; - i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); ++i) { - fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i, - responses[i].c_str()); - } - PROFILER_PRINT_RESULTS(); + // Run multiple times so that auto-tuning is closer to complete. + for (size_t rep = 0; rep < 4; ++rep) { + std::vector responses = + BatchGemmaReply(inputs, AttentionImpl::kFlash); + for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); + ++i) { + fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i, + responses[i].c_str()); } + PROFILER_PRINT_RESULTS(); } } diff --git a/evals/wheat_from_chaff_test.cc b/evals/wheat_from_chaff_test.cc index f78ec17d..8981beb4 100644 --- a/evals/wheat_from_chaff_test.cc +++ b/evals/wheat_from_chaff_test.cc @@ -146,7 +146,7 @@ GemmaEnv* GemmaTest::s_env = nullptr; // Tests whether Gemma can find the right answer in varying levels of // background information, ranging from the bare facts to outright distraction. TEST_F(GemmaTest, WheatFromChaff) { - const AttentionImpl modes[] = {AttentionImpl::kOld, AttentionImpl::kFlash}; + const AttentionImpl modes[] = {AttentionImpl::kFlash}; fprintf(stderr, "Warmup, mode %s\n", GetAttentionImplName(modes[0]).c_str()); auto prompt = BuildPrompt({"quark_1.txt", "holiday_story.txt"}, kQuestions); auto response = GemmaReply(prompt, modes[0]); diff --git a/gemma/activations.h b/gemma/activations.h index c14b24e2..f00f81e2 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -70,11 +70,6 @@ struct AttentionActivations { ? layer_config.heads * 3 * layer_config.qkv_dim : layer_config.heads * layer_config.qkv_dim, allocator)), - q_T(MatFactory("q_T", layer_config.qkv_dim, - config.vocab_size == 0 - ? batch_size * layer_config.heads * 3 - : batch_size * layer_config.heads, - allocator)), vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)), vit_K_T(MatFactory( "K2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector), @@ -88,13 +83,6 @@ struct AttentionActivations { allocator, MatPadding::kPacked)), pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size, config.model_dim, allocator)), - // att is only valid for AttentionImpl::kOld. - att(MatFactory( - "att", batch_size, - layer_config.heads * - (runtime_config.attention_impl == AttentionImpl::kOld ? seq_len - : 1), - allocator)), att_out(MatFactory("att_out", batch_size, layer_config.heads * layer_config.qkv_dim, allocator)), @@ -133,20 +121,17 @@ struct AttentionActivations { // fill them in each MatMul call. q.AllocateAndAttachRowPtrs(row_ptrs); q_bf.AllocateAndAttachRowPtrs(row_ptrs); - q_T.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs); } void SetBatchSize(size_t batch_size) { q.OverrideRows(batch_size); q_bf.OverrideRows(batch_size); - // q_T rows are always qkv_dim! vit_Q.OverrideRows(batch_size); // vit_K_T and vit_V_T stay seq_len! pre_att_rms_out.OverrideRows(batch_size); - att.OverrideRows(batch_size); att_out.OverrideRows(batch_size); att_out_reps.OverrideRows(batch_size * rep_factor); // There is no override for [split_]flash_params, because we reserved an @@ -170,14 +155,12 @@ struct AttentionActivations { std::vector split_flash_params; MatStorageT q; // query MatStorageT q_bf; - MatStorageT q_T; // Transposed to maximize attention speed. MatStorageT vit_Q; MatStorageT vit_K_T; MatStorageT vit_V_T; MatStorageT pre_att_rms_out; - MatStorageT att; // attention vector MatStorageT att_out; // attention output MatStorageT att_out_reps; // attention output for each thread. MatStorageT softmax_max; // see OnlineSoftmaxState @@ -218,12 +201,10 @@ struct AttentionActivationsPtrs { activations.split_flash_params) { q = activations.q; q_bf = activations.q_bf; - q_T = activations.q_T; vit_Q = activations.vit_Q; vit_K_T = activations.vit_K_T; vit_V_T = activations.vit_V_T; pre_att_rms_out = activations.pre_att_rms_out; - att = activations.att; att_out = activations.att_out; att_out_reps = activations.att_out_reps; softmax_max = activations.softmax_max; @@ -236,13 +217,11 @@ struct AttentionActivationsPtrs { void SetBatchSize(size_t batch_size) { q.OverrideRows(batch_size); q_bf.OverrideRows(batch_size); - // q_T rows are always qkv_dim! vit_Q.OverrideRows(batch_size); // vit_K_T and vit_V_T stay seq_len! pre_att_rms_out.OverrideRows(batch_size); - att.OverrideRows(batch_size); att_out.OverrideRows(batch_size); softmax_max.OverrideRows(batch_size); softmax_d.OverrideRows(batch_size); @@ -268,8 +247,6 @@ struct AttentionActivationsPtrs { MatPtrT q; // Query matrix of size batch_size x (q_heads * qkv_dim). MatPtrT q_bf; - // Transposed query matrix for faster Q*K^T. - MatPtrT q_T; MatPtrT vit_Q; MatPtrT vit_K_T; @@ -277,20 +254,15 @@ struct AttentionActivationsPtrs { // Output of RMSNorm before attention, size batch_size x model_dim. MatPtrT pre_att_rms_out; - // Attention scores computed from Q*K^T, size batch_size x (q_heads * - // seq_len). - MatPtrT att; // Attention output computed from att * V, size batch_size x (q_heads * // qkv_dim). MatPtrT att_out; MatPtrT att_out_reps; - // The maximum logit value encountered when computing att_out from att, - // size batch_size x q_heads . See OnlineSoftmaxState for details. - // WARNING: Only filled in for AttentionImpl::kOld. + // The maximum logit value encountered when computing att_out, shape + // batch_size x q_heads . See OnlineSoftmaxState for details. MatPtrT softmax_max; - // The sum of scaled exponentials when computing att_out from att, - // size batch_size x q_heads . See OnlineSoftmaxState for details. - // WARNING: Only filled in for AttentionImpl::kOld. + // The sum of scaled exponentials when computing att_out, shape + // batch_size x q_heads . See OnlineSoftmaxState for details. MatPtrT softmax_d; // Accumulation of attention outputs over heads, size batch_size x // model_dim. diff --git a/gemma/attention.cc b/gemma/attention.cc index 5aab57ec..ddcca31e 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -19,17 +19,20 @@ #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS -#include "util/zones.h" -#include "hwy/base.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS #include "gemma/activations.h" #include "gemma/configs.h" // kMaxQKVDim +#include "gemma/kv_cache.h" +#include "gemma/query.h" #include "gemma/weights.h" +#include "ops/matmul.h" #include "util/threading.h" #include "util/threading_context.h" +#include "util/zones.h" +#include "hwy/base.h" #include "hwy/profiler.h" // Compiles this file for multiple architectures via "foreach_target.h", to @@ -42,6 +45,7 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" +#include "gemma/attention.h" // includes highway.h #include "gemma/flash_attention.h" #include "gemma/gemma-inl.h" #include "ops/ops-inl.h" @@ -50,21 +54,6 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -// Returns the number of floats per vector (aka NF). -size_t FloatsPerVector() { - using DF = hn::ScalableTag; - const DF df; - return hn::Lanes(df); -} - -// The k-cache and v-cache are setup without knowing NF. So if it hasn't been -// done already, reshape it to take NF into account. -void MaybeReshapeCache(const size_t default_cols, MatPtrT& cache) { - if (default_cols == cache.Cols()) { - cache.ReshapePackedRowsToCols(2 * FloatsPerVector()); - } -} - // Transposes a single row of the kv cache into the k-cache and v-cache. void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v, size_t qkv_dim) { @@ -120,32 +109,6 @@ void TransposeOOBKVCacheRow(KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v, } } -// Computes Q.K scores, which are "logits" (or scores) stored to att. -// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. -static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, - const hwy::Divisor& div_seq_len, - const float* HWY_RESTRICT q, - const MatPtrT& k, float* HWY_RESTRICT att, - ThreadingContext& ctx, const size_t worker) { - GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK); - const hn::ScalableTag dbf; - const size_t qkv_dim = k.Cols(); - HWY_ALIGN BF16 q_bf[kMaxQKVDim]; - - CompressPerThread tls; - const hn::ScalableTag df; - CompressTraits::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim), - 0); - - // --seq_len must be large enough to avoid wraparound. - HWY_DASSERT(last_pos < static_cast(div_seq_len.GetDivisor())); - for (size_t pos = start_pos; pos <= last_pos; ++pos) { - const float score = - Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim); - att[pos] = score; - } -} - void PositionalEncodingQK(float* qk, const size_t layer_idx, const AttentionActivationsPtrs& activations, ThreadingContext& ctx, const size_t worker, @@ -168,141 +131,6 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx, } } -// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into -// `att_out`. Equivalent in gemma/modules.py: -// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) -// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. -static HWY_INLINE void WeightedSumV( - const size_t start_pos, const size_t last_pos, - const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att, - const MatPtrT& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx, - const size_t worker) { - // --seq_len must be large enough to avoid wraparound. - HWY_DASSERT(last_pos < static_cast(div_seq_len.GetDivisor())); - // TODO: replace with MatMul(att, v) after it supports non-transposed B. - MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx, - worker); - for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { - MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols()); - } -} - -// Calculates the attention outputs for a single q, which may be updated -// in place for RMSNorm. -void SingleDotSoftmaxWeightedSum( - const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos, - float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, - const MatPtr& query_norm_scale, const size_t layer_idx, - const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, - float* HWY_RESTRICT att_out, const SMOptions& sm_options, - ThreadingContext& ctx, const size_t worker) { - const float att_cap = activations.config.att_cap; - const float query_scale = activations.query_scale; - // --seq_len must be large enough to avoid wraparound. - HWY_DASSERT(kv_last_pos < activations.SeqLen()); - const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; - - // Apply rope and scaling to Q. - if (query_norm_scale.HasPtr()) { - CallUpcasted(&query_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q, - layer_config.qkv_dim, ctx, worker); - }); - } - - PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos, - query_scale); - - QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx, - worker); - - // SoftMax with optional SoftCap yields "probabilities" in att. - const Logits logits(att, kv_last_pos + 1); - MaybeLogitsSoftCap(att_cap, logits, ctx, worker); - Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options); - - WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v, - att_out, ctx, worker); -} - -// The attention window usually starts at 0 unless `pos` is larger than -// the attention window size, then it is `pos` - window_size + 1. -size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) { - const size_t att_window_size = config.attention_window_sizes[layer_idx]; - return pos - HWY_MIN(att_window_size - 1, pos); -} - -void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, - const MatPtr& query_norm_scale, - AttentionActivationsPtrs& activations, - QBatch& qbatch, ThreadingContext& ctx) { - GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive); - - const hwy::Divisor div_qbatch(qbatch.Size()); - const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; - const size_t qkv_dim = layer_config.qkv_dim; - - // A "head group" in the context of GQA refers to a collection of query - // heads that share the same key and value heads. - const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; - - const size_t cache_layer_size = layer_config.CacheLayerSize(); - const size_t seq_len = activations.SeqLen(); - // All layers should have the same number of heads. - HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads); - - // For each head/token/query, compute Q.K, softmax, and weighted V. - const auto func = [&](const size_t task, size_t worker) HWY_ATTR { - const size_t tq_idx = activations.div_heads.Divide(task); - const size_t head = activations.div_heads.Remainder(task); - GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar); - - const size_t qi = div_qbatch.Remainder(tq_idx); - const size_t token_idx = div_qbatch.Divide(tq_idx); - auto& kv_cache = qbatch.KV(qi).kv_cache; - - // Find the token position in the query and calculate - // the range of cache positions to attend to. - const size_t pos = qbatch.Pos(qi) + token_idx; - const size_t start_pos = StartPos(pos, activations.config, layer_idx); - size_t last_pos = pos; - const size_t prefix_end = qbatch.PrefixEnd(qi); - if (prefix_end > 0 && prefix_end - 1 > last_pos) { - // last_pos in QDotK and WeightedSumV is inclusive. - last_pos = prefix_end - 1; - } - - float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim; - float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len; - float* HWY_RESTRICT att_out = - activations.att_out.Row(tq_idx) + head * qkv_dim; - SMOptions sm_options{.max_out = activations.softmax_max.Row(tq_idx) + head, - .d_out = activations.softmax_d.Row(tq_idx) + head}; - - // Make strided read-only views into the kv cache for - // this query and head. - const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2; - const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset; - MatPtrT k("k_view", Extents2D(seq_len, qkv_dim)); - k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride()); - MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); - v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); - - constexpr size_t offset = 0; // placeholder, do not remove - SingleDotSoftmaxWeightedSum(pos + offset, start_pos, last_pos, q, k, v, - query_norm_scale, layer_idx, activations, att, - att_out, sm_options, ctx, worker); - }; - - { - PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); - // Full parallelism is helpful, kAcrossClusters is insufficient. - HierarchicalParallelFor( - num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx, - Callers::kAttDotSoftmaxWeightedSum, func); - } -} - // Different functions use different naming conventions for the number of // tokens. Functions that are query-independent, such as RMSNorm*, call the // count `num_interleaved`. Functions that are query-dependent, such as @@ -447,17 +275,11 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, (void)layer_config; // only used in HWY_DASSERT ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); - if (attention_impl == AttentionImpl::kOld) { - DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale, - activations, qbatch, env.ctx); - } else { - // * 2 does not help on Turin. - FlashAttention(num_tokens, - /*target_parallelism=*/env.ctx.pools.MaxWorkers() * - AttentionActivations::kThreadReplicationFactor, - layer_idx, layer.query_norm_scale, activations, qbatch, - env.ctx, attention_impl); - } + FlashAttention(num_tokens, + /*target_parallelism=*/env.ctx.pools.MaxWorkers() * + AttentionActivations::kThreadReplicationFactor, + layer_idx, layer.query_norm_scale, activations, qbatch, + env.ctx, attention_impl); SumHeads(layer, activations, env); } diff --git a/gemma/attention.h b/gemma/attention.h index bb8a743d..f8e1e69e 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -21,42 +21,49 @@ #include #include "gemma/activations.h" -#include "gemma/query.h" -#include "gemma/weights.h" +#include "gemma/configs.h" // AttentionImpl +#include "gemma/kv_cache.h" // KV_t +#include "gemma/query.h" // QBatch +#include "gemma/weights.h" // LayerWeightsPtrs #include "ops/matmul.h" -#include "hwy/highway.h" +#include "hwy/highway.h" // HWY_VISIT_TARGETS +#include "hwy/per_target.h" // VectorBytes namespace gcpp { +// Returns the number of floats per vector (aka NF). +inline size_t FloatsPerVector() { return hwy::VectorBytes() / sizeof(float); } + +// The attention window usually starts at 0 unless `pos` is larger than +// the attention window size, then it is `pos` - window_size + 1. +inline size_t StartPos(size_t pos, const ModelConfig& config, + size_t layer_idx) { + const size_t att_window_size = config.attention_window_sizes[layer_idx]; + return pos - HWY_MIN(att_window_size - 1, pos); +} + +// The k-cache and v-cache are setup without knowing NF. So if it hasn't been +// done already, reshape it to take NF into account. Must be called before +// FlashAttention. +inline void MaybeReshapeCache(const size_t default_cols, MatPtrT& cache) { + if (default_cols == cache.Cols()) { + cache.ReshapePackedRowsToCols(2 * FloatsPerVector()); + } +} + // Passed to HWY_VISIT_TARGETS; declares for one target. #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ namespace NAMESPACE { \ - size_t FloatsPerVector(); \ - \ - void MaybeReshapeCache(size_t default_cols, MatPtrT& cache); \ - \ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \ KV_t* HWY_RESTRICT v, size_t qkv_dim); \ + void TransposeOOBKVCacheRow(KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v, \ + size_t qkv_dim); \ \ void PositionalEncodingQK(float* qk, size_t layer_idx, \ const AttentionActivationsPtrs& activations, \ ThreadingContext& ctx, size_t worker, size_t pos, \ float mul); \ \ - size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \ - \ - void SingleDotSoftmaxWeightedSum( \ - const size_t pos, const size_t start_pos, const size_t last_pos, \ - float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ - const MatPtr& query_norm_scale, size_t layer_idx, \ - const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \ - float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \ - \ - void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ - const MatPtr& query_norm_scale, \ - AttentionActivationsPtrs& activations, \ - QBatch& qbatch, ThreadingContext& ctx); \ - \ void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ const LayerWeightsPtrs& layer, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \ diff --git a/gemma/attention_test.cc b/gemma/attention_test.cc index f7193ac4..54b3690d 100644 --- a/gemma/attention_test.cc +++ b/gemma/attention_test.cc @@ -568,8 +568,6 @@ void RunAttentionTest(AttentionImpl attention_impl) { /*q_head=*/0, kGoldenQ); } -void TestGemmaAttentionOld() { RunAttentionTest(AttentionImpl::kOld); } - void TestGemmaAttentionFlash() { RunAttentionTest(AttentionImpl::kFlash); } } // namespace HWY_NAMESPACE @@ -580,7 +578,6 @@ HWY_AFTER_NAMESPACE(); namespace gcpp { HWY_BEFORE_TEST(AttentionTest); -HWY_EXPORT_AND_TEST_P(AttentionTest, TestGemmaAttentionOld); HWY_EXPORT_AND_TEST_P(AttentionTest, TestGemmaAttentionFlash); HWY_AFTER_TEST(); diff --git a/gemma/configs.cc b/gemma/configs.cc index 0a585da5..271ca8fc 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -714,7 +714,6 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { } constexpr std::pair kAttentionImplNameToEnum[] = { - {"old", AttentionImpl::kOld}, {"flash", AttentionImpl::kFlash}, {"flash_transposed_qs", AttentionImpl::kFlashTransposedQs}, {"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16}, @@ -732,9 +731,9 @@ AttentionImpl GetAttentionImpl(const std::string& impl_name) { for (const auto& [name, attention_impl] : kAttentionImplNameToEnum) { if (name == impl_name) return attention_impl; } - HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", + HWY_WARN("Unknown attention implementation: %s. Using kFlash.\n", impl_name.c_str()); - return AttentionImpl::kOld; + return AttentionImpl::kFlash; } } // namespace gcpp diff --git a/gemma/configs.h b/gemma/configs.h index c14455ea..77317cc9 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -35,8 +35,6 @@ namespace gcpp { constexpr size_t kMaxBF16PerVector = HWY_ARCH_MAX_BYTES / sizeof(BF16); -HWY_INLINE_VAR constexpr int kAttentionUseOld = 2; - HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024; #ifndef GEMMA_FUSED_FFN @@ -95,8 +93,7 @@ enum class KVEncoding { }; enum class AttentionImpl { - kOld, // Previous Attention implementation - kFlash, // Flash Attention (default) + kFlash = 0, // Flash Attention (default) kFlashTransposedQs, kFlashTransposedQsBF16, kFlashTransposedQsInt16, @@ -106,32 +103,6 @@ enum class AttentionImpl { std::string GetAttentionImplName(AttentionImpl impl); AttentionImpl GetAttentionImpl(const std::string& impl); -/* - * Returns a bitmask of flags to pass to attention functions based on the - * attention implementation selected. - * - * If `hwy_native_dot_bf16` is true, the function will use the old attention - * implementation, ignoring `impl`. - * - * `hwy_native_dot_bf16` needs to be passed in, because the HWY_NATIVE_DOT_BF16 - * macro is not available outside of highway instrumented translation units and - * cannot be made accessible from .h files. - */ -static inline int AttentionImplToFlags(AttentionImpl impl, - int hwy_native_dot_bf16) { - if (hwy_native_dot_bf16) return kAttentionUseOld; - - switch (impl) { - case AttentionImpl::kOld: - return kAttentionUseOld; - case AttentionImpl::kFlash: - case AttentionImpl::kFlashTransposedQs: - case AttentionImpl::kFlashTransposedQsBF16: - default: - return 0; - } -} - // Post attention and ffw normalization type. enum class PostNormType { None, diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index e6f02579..5d649758 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -49,8 +49,8 @@ TEST(ConfigsTest, TestAttentionImpl) { ASSERT_EQ(GetAttentionImpl(name), impl); } ASSERT_EQ(GetAttentionImplName(AttentionImpl::kSentinel), "unknown"); - ASSERT_EQ(GetAttentionImpl("unknown"), AttentionImpl::kOld); - ASSERT_EQ(GetAttentionImpl("invalid"), AttentionImpl::kOld); + ASSERT_EQ(GetAttentionImpl("unknown"), AttentionImpl::kFlash); + ASSERT_EQ(GetAttentionImpl("invalid"), AttentionImpl::kFlash); } } // namespace gcpp diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index a4bdf298..04c00051 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -27,6 +27,7 @@ #include "gemma/kv_transcoding.h" #include "gemma/weights.h" #include "ops/matmul.h" +#include "ops/ops.h" #include "util/test_util.h" #include "hwy/nanobenchmark.h" #ifndef HWY_DISABLED_TARGETS @@ -57,12 +58,170 @@ #include "gemma/configs.h" #include "gemma/flash_attention.h" #include "gemma/tiled_attention.h" +#include "ops/ops-inl.h" #include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +// Old attention implementation for comparison, formerly in gemma/attention.cc. + +// Computes Q.K scores, which are "logits" (or scores) stored to att. +// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. +static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, + const hwy::Divisor& div_seq_len, + const float* HWY_RESTRICT q, + const MatPtrT& k, float* HWY_RESTRICT att, + ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK); + const hn::ScalableTag dbf; + const size_t qkv_dim = k.Cols(); + HWY_ALIGN BF16 q_bf[kMaxQKVDim]; + + CompressPerThread tls; + const hn::ScalableTag df; + CompressTraits::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim), + 0); + + // --seq_len must be large enough to avoid wraparound. + HWY_DASSERT(last_pos < static_cast(div_seq_len.GetDivisor())); + for (size_t pos = start_pos; pos <= last_pos; ++pos) { + const float score = + Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim); + att[pos] = score; + } +} + +// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into +// `att_out`. Equivalent in gemma/modules.py: +// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) +// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. +static HWY_INLINE void WeightedSumV( + const size_t start_pos, const size_t last_pos, + const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att, + const MatPtrT& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx, + const size_t worker) { + // --seq_len must be large enough to avoid wraparound. + HWY_DASSERT(last_pos < static_cast(div_seq_len.GetDivisor())); + // TODO: replace with MatMul(att, v) after it supports non-transposed B. + MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx, + worker); + for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { + MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols()); + } +} + +// Calculates the attention outputs for a single q, which may be updated +// in place for RMSNorm. +void SingleDotSoftmaxWeightedSum( + const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos, + float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, + const MatPtr& query_norm_scale, const size_t layer_idx, + const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, + float* HWY_RESTRICT att_out, const SMOptions& sm_options, + ThreadingContext& ctx, const size_t worker) { + const float att_cap = activations.config.att_cap; + const float query_scale = activations.query_scale; + // --seq_len must be large enough to avoid wraparound. + HWY_DASSERT(kv_last_pos < activations.SeqLen()); + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; + + // Apply rope and scaling to Q. + if (query_norm_scale.HasPtr()) { + CallUpcasted(&query_norm_scale, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q, + layer_config.qkv_dim, ctx, worker); + }); + } + + PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos, + query_scale); + + QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx, + worker); + + // SoftMax with optional SoftCap yields "probabilities" in att. + const Logits logits(att, kv_last_pos + 1); + MaybeLogitsSoftCap(att_cap, logits, ctx, worker); + Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options); + + WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v, + att_out, ctx, worker); +} + +void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, + const MatPtr& query_norm_scale, + AttentionActivationsPtrs& activations, + MatPtrT att_storage, QBatch& qbatch, + ThreadingContext& ctx) { + GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive); + + const hwy::Divisor div_qbatch(qbatch.Size()); + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; + const size_t qkv_dim = layer_config.qkv_dim; + + // A "head group" in the context of GQA refers to a collection of query + // heads that share the same key and value heads. + const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; + + const size_t cache_layer_size = layer_config.CacheLayerSize(); + const size_t seq_len = activations.SeqLen(); + // All layers should have the same number of heads. + HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads); + + // For each head/token/query, compute Q.K, softmax, and weighted V. + const auto func = [&](const size_t task, size_t worker) HWY_ATTR { + const size_t tq_idx = activations.div_heads.Divide(task); + const size_t head = activations.div_heads.Remainder(task); + GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar); + + const size_t qi = div_qbatch.Remainder(tq_idx); + const size_t token_idx = div_qbatch.Divide(tq_idx); + auto& kv_cache = qbatch.KV(qi).kv_cache; + + // Find the token position in the query and calculate + // the range of cache positions to attend to. + const size_t pos = qbatch.Pos(qi) + token_idx; + const size_t start_pos = StartPos(pos, activations.config, layer_idx); + size_t last_pos = pos; + const size_t prefix_end = qbatch.PrefixEnd(qi); + if (prefix_end > 0 && prefix_end - 1 > last_pos) { + // last_pos in QDotK and WeightedSumV is inclusive. + last_pos = prefix_end - 1; + } + + float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim; + float* HWY_RESTRICT att = att_storage.Row(tq_idx) + head * seq_len; + float* HWY_RESTRICT att_out = + activations.att_out.Row(tq_idx) + head * qkv_dim; + SMOptions sm_options{.max_out = activations.softmax_max.Row(tq_idx) + head, + .d_out = activations.softmax_d.Row(tq_idx) + head}; + + // Make strided read-only views into the kv cache for + // this query and head. + const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2; + const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset; + MatPtrT k("k_view", Extents2D(seq_len, qkv_dim)); + k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride()); + MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); + v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); + + constexpr size_t offset = 0; // placeholder, do not remove + SingleDotSoftmaxWeightedSum(pos + offset, start_pos, last_pos, q, k, v, + query_norm_scale, layer_idx, activations, att, + att_out, sm_options, ctx, worker); + }; + + { + PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); + // Full parallelism is helpful, kAcrossClusters is insufficient. + HierarchicalParallelFor( + num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx, + Callers::kAttDotSoftmaxWeightedSum, func); + } +} + using FloatPtr = hwy::AlignedFreeUniquePtr; template @@ -136,8 +295,7 @@ void TestFlashAttention(size_t target_parallelism, AttentionImpl attention_impl) { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); - constexpr size_t kOuter = 1024; - constexpr size_t kInner = 256; + constexpr size_t kSeqLen = 1024; ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT); config.att_cap = 1024.0f; TensorInfoRegistry tensor_info_registry(config); @@ -154,26 +312,29 @@ void TestFlashAttention(size_t target_parallelism, Activations activations(runtime_config, config, runtime_config.prefill_tbatch_size, kv_cache.SeqLen(), env.ctx, env.row_ptrs); - std::vector tokens(kOuter); + std::vector tokens(kSeqLen); std::iota(tokens.begin(), tokens.end(), 1); PromptTokens prompt(tokens); AllQueries all_queries(hwy::Span(&prompt, 1), hwy::Span(&kv_cache, 1)); - QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries); - const size_t batch_size = kOuter; + QBatch qbatch(/*start=*/0, /*max_size=*/kSeqLen, all_queries); + const size_t batch_size = kSeqLen; std::vector> row_ptrs; AttentionActivations attention_storage( - config, layer_config, batch_size, kOuter, runtime_config, + config, layer_config, batch_size, kSeqLen, runtime_config, ctx.pools.MaxWorkers(), ctx.allocator, row_ptrs); - AttentionActivationsPtrs attention(config, kOuter, attention_storage); + AttentionActivationsPtrs att_activations(config, kSeqLen, attention_storage); + MatStorageT att("att", + Extents2D(batch_size, layer_config.heads * kSeqLen), + ctx.allocator, MatPadding::kOdd); const size_t qkv_dim = layer_config.qkv_dim; - ASSERT_EQ(qkv_dim, kInner); + ASSERT_EQ(qkv_dim, 256); const hwy::Divisor div_qbatch(qbatch.Size()); // A "head group" in the context of GQA refers to a collection of query // heads that share the same key and value heads. const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; const size_t seq_len = - static_cast(attention.div_seq_len.GetDivisor()); + static_cast(att_activations.div_seq_len.GetDivisor()); MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(), qbatch.KV(0).k_cache); MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(), @@ -205,12 +366,12 @@ void TestFlashAttention(size_t target_parallelism, TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim); } } - SetMat(1, attention.q); - DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention, - qbatch, ctx); + SetMat(1, att_activations.q); + DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, + att_activations, att, qbatch, ctx); // Copy the output to saved_att to allow for comparison. - auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); - SetMat(1, attention.q); + auto saved_att = MakeCopyOfMat(att_activations.att_out, ctx.allocator); + SetMat(1, att_activations.q); const size_t total_tasks = tokens.size() * div_qbatch.GetDivisor() * layer_config.heads; const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(), @@ -219,8 +380,8 @@ void TestFlashAttention(size_t target_parallelism, target_parallelism, kNF, kVTileSize, GetAttentionImplName(attention_impl).c_str()); FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale, - attention, qbatch, ctx, attention_impl); - AssertClose(attention.att_out, *saved_att); + att_activations, qbatch, ctx, attention_impl); + AssertClose(att_activations.att_out, *saved_att); ctx.profiler.PrintResults(); } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index ae94cc55..17eca7f3 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -81,20 +81,18 @@ namespace HWY_NAMESPACE { void Attention(LayerAttentionType type, const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, QBatch& qbatch, MatMulEnv& env) { + const int kFlags = 0; if (activations.attention_impl == AttentionImpl::kFlashTransposedQs || activations.attention_impl == AttentionImpl::kFlashTransposedQsBF16 || activations.attention_impl == AttentionImpl::kFlashTransposedQsInt16) { - TiledAttention( - activations.attention_impl, num_tokens, layer_idx, layer, - activations.attention, qbatch, env, - AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16)); + TiledAttention(activations.attention_impl, num_tokens, layer_idx, layer, + activations.attention, qbatch, env, kFlags); return; } if (type == LayerAttentionType::kGemma) { - // TODO: remove flag to enable FlashAttention. GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, - env, activations.attention_impl, /*flags=*/0); + env, activations.attention_impl, kFlags); } } diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 67472822..2b122760 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -19,7 +19,6 @@ #include #include -#include #include #include "gemma/configs.h" // ModelConfig diff --git a/gemma/tiled_attention_test.cc b/gemma/tiled_attention_test.cc index 79282c2d..76af6655 100644 --- a/gemma/tiled_attention_test.cc +++ b/gemma/tiled_attention_test.cc @@ -1,7 +1,5 @@ #include -#include -#include #include #include #include @@ -95,7 +93,679 @@ struct AttentionTestEnv { } bool transposed = - attention_impl == AttentionImpl::kFlashTransposedQsBF16 + attention_impl == AttentionImpl::kFlashTransposedQsBF16; + gcpp::KVEncoding encoding; + const Type type = kv_caches.back().compact_kv_cache_ptr.GetType(); + if (type == Type::kInt8) { + encoding = transposed ? gcpp::KVEncoding::kInt8TwoTranspositions + : gcpp::KVEncoding::kInt8; + } else if (type == Type::kBF16) { + encoding = transposed ? gcpp::KVEncoding::kBF16TwoTranspositions + : gcpp::KVEncoding::kBF16; + } else { + encoding = transposed ? gcpp::KVEncoding::kF32TwoTranspositions + : gcpp::KVEncoding::kF32; + } + std::optional bytes_opt = + gcpp::GetTileSizeBytes(encoding, qkv_dim); + HWY_ASSERT(bytes_opt.has_value()); + size_t bytes = bytes_opt.value(); + hwy::Span encoded( + reinterpret_cast( + kv_caches.back().compact_kv_cache_ptr.RowBytes(i)), + bytes); + bool encode_success = + gcpp::EncodeTile(encoding, decoded, qkv_dim, encoded); + HWY_ASSERT(encode_success); + } + } else { + FillMatPtrT(kv_caches.back().kv_cache); + } + all_queries.Append({ + .prompt = PromptTokens({1, 2, 3}), + .mutable_pos = static_cast(last_pos), + .initial_pos = 0, + .prefix_end = 0, + .kv_cache = kv_caches.back().ToPtr(), + }); + } + + activations = std::make_unique(runtime_config, model_config, + qbatch_size * num_tokens, + kv_seq_len, ctx, env.row_ptrs); + + qbatch = + std::make_unique(/*start_pos=*/0, qbatch_size, all_queries); + } + + void SetupWeights() { + int model_dim = layer_config.model_dim; + int qkv_dim = layer_config.qkv_dim; + int num_heads = layer_config.heads; + int num_kv_heads = layer_config.kv_heads; + + qkv1_w_storage = + MatStorageT("qkv1", Extents2D(model_dim, qkv_dim * num_heads), + ctx.allocator, MatPadding::kPacked); + qkv2_w_storage = MatStorageT( + "qkv2", Extents2D(model_dim, num_kv_heads * 2 * qkv_dim), ctx.allocator, + MatPadding::kPacked); + wo_w_storage = MatStorageT("wo", Extents2D(model_dim, model_dim), + ctx.allocator, MatPadding::kPacked); + + FillMatPtrT(wo_w_storage); + layer->att_weights = wo_w_storage; + FillMatPtrT(qkv1_w_storage); + FillMatPtrT(qkv2_w_storage); + layer->qkv_einsum_w1 = qkv1_w_storage; + layer->qkv_einsum_w2 = qkv2_w_storage; + + query_norm_scale = MatStorageT("query_norm", qkv_dim, ctx.allocator); + FillMatPtrT(query_norm_scale); + layer->query_norm_scale = query_norm_scale; + + key_norm_scale = MatStorageT("key_norm", qkv_dim, ctx.allocator); + FillMatPtrT(key_norm_scale); + layer->key_norm_scale = key_norm_scale; + } + + AttentionTestEnv(const AttentionTestEnv&) = delete; + AttentionTestEnv& operator=(const AttentionTestEnv&) = delete; + AttentionTestEnv(AttentionTestEnv&&) = delete; + AttentionTestEnv& operator=(AttentionTestEnv&&) = delete; + + ThreadingArgs threading_args; + ThreadingContext ctx; + MatMulEnv env; + LayerConfig layer_config; + ModelConfig model_config; + std::unique_ptr tensor_info_registry; + std::unique_ptr layer; + RuntimeConfig runtime_config; + InferenceArgs inference_args; + AllQueries all_queries; + std::vector kv_caches; + std::unique_ptr activations; + std::unique_ptr qbatch; + + // Weights storage for later tests + MatStorageT qkv1_w_storage; + MatStorageT qkv2_w_storage; + MatStorageT wo_w_storage; + MatStorageT query_norm_scale; + MatStorageT key_norm_scale; +}; + +void TestTransposeStridedQueries() { + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + size_t qkv_dim = 64; + size_t num_queries = 24; + AlignedPtr input_queries = + ctx.allocator.Alloc(qkv_dim * num_queries); + AlignedPtr output_queries = + ctx.allocator.Alloc(qkv_dim * num_queries); + for (size_t i = 0; i < num_queries; ++i) { + for (size_t j = 0; j < qkv_dim; ++j) { + input_queries[i * qkv_dim + j] = i * qkv_dim + j; + } + } + std::vector queries; + for (size_t i = 0; i < num_queries; ++i) { + queries.push_back(input_queries.get() + i * qkv_dim); + } + hwy::Span queries_span(queries.data(), queries.size()); + + TransposeStridedQueries( + queries_span, qkv_dim, + hwy::Span(output_queries.get(), qkv_dim * num_queries)); + for (size_t i = 0; i < num_queries; ++i) { + for (size_t j = 0; j < qkv_dim; ++j) { + EXPECT_EQ(output_queries[j * num_queries + i], + input_queries[i * qkv_dim + j]) + << "i=" << i << " j=" << j; + } + } +} + +void TestLocalAttentionForAllHeadsTokensAndBatch() { + size_t qkv_dim = 64; + size_t kv_seq_len = 64; + size_t num_kv_heads = 2; + size_t num_heads = 2; + size_t num_tokens = 2; + size_t last_pos = 62; // so token 0 will have 63 and token 1 will have 64 + // tokens to attend to. + float att_cap = 10.0f; + size_t layer_idx = 0; + size_t layers_total = 1; + size_t qbatch_size = 2; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, + num_heads, num_tokens, last_pos, att_cap, layer_idx, + layers_total, qbatch_size, attention_impl); + FillMatPtrT(test_env.activations->attention.q); + LocalAttentionForAllHeadsTokensAndBatch( + attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, test_env.ctx); + + // print states; + std::vector exp_denominator_sums_gold = {63, 63, 64, 64, + 63, 63, 64, 64}; + std::vector max_logits_gold = {10, 10, 10, 10, 10, 10, 10, 10}; + std::vector att_out_gold = { + 30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, + 30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, + 30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, + 30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, + 30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, + 30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, + 30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, + 30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, + 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475, + 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275, + 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075, + 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875, + 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675, + 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475, + 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275, + 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075, + 30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, + 30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, + 30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, + 30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, + 30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, + 30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, + 30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, + 30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, + 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505, + 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585, + 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665, + 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745, + 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825, + 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905, + 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985, + 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065, + 30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, + 30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, + 30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, + 30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, + 30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, + 30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, + 30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, + 30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, + 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475, + 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275, + 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075, + 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875, + 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675, + 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475, + 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275, + 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075, + 30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, + 30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, + 30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, + 30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, + 30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, + 30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, + 30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, + 30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, + 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505, + 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585, + 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665, + 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745, + 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825, + 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905, + 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985, + 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065, + }; + const size_t group_size = num_heads / num_kv_heads; + for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (size_t q_batch_idx = 0; q_batch_idx < qbatch_size; ++q_batch_idx) { + size_t b = token_idx * qbatch_size + q_batch_idx; + EXPECT_THAT( + absl::MakeSpan(test_env.activations->attention.softmax_d.Row(b), + num_heads), + Pointwise(FloatNear(1e-3f), absl::MakeSpan(exp_denominator_sums_gold) + .subspan(b * num_heads, num_heads))); + EXPECT_THAT( + absl::MakeSpan(test_env.activations->attention.softmax_max.Row(b), + num_heads), + Pointwise(FloatNear(1e-3f), absl::MakeSpan(max_logits_gold) + .subspan(b * num_heads, num_heads))); + for (size_t kv_h = 0; kv_h < num_kv_heads; ++kv_h) { + for (size_t g = 0; g < group_size; ++g) { + const size_t q_h = kv_h * group_size + g; + size_t expected_q_idx = b * num_heads + q_h; + EXPECT_THAT( + absl::MakeSpan(test_env.activations->attention.att_out.Row(b) + + q_h * qkv_dim, + qkv_dim), + Pointwise(FloatNear(1e-3f), + absl::MakeSpan(att_out_gold) + .subspan(expected_q_idx * qkv_dim, qkv_dim))); + } + } + } + } +} + +const std::vector AttentionMultipleTokensAttentionGoldens = { + 34.7414, 34.7717, 34.8022, 34.8327, 34.8631, 34.8936, 34.9241, 34.9545, + 34.985, 35.0156, 35.046, 35.0765, 35.1068, 35.1373, 35.1678, 35.1982, + 35.2286, 35.2592, 35.2895, 35.32, 35.3506, 35.381, 35.4115, 35.4421, + 35.4725, 35.503, 35.5334, 35.5638, 35.5943, 35.6247, 35.6552, 35.6857, + 35.7161, 35.7466, 35.7772, 35.8076, 35.8381, 35.8685, 35.8989, 35.9294, + 35.9598, 35.9902, 36.0208, 36.0512, 36.0816, 36.1122, 36.1426, 36.1731, + 36.2037, 36.2341, 36.2646, 36.295, 36.3254, 36.356, 36.3863, 36.4168, + 36.4474, 36.4778, 36.5082, 36.5388, 36.5692, 36.5997, 36.6301, 36.6605, + 34.6687, 34.6987, 34.7288, 34.759, 34.7891, 34.8192, 34.8495, 34.8795, + 34.9097, 34.9399, 34.97, 35.0002, 35.0302, 35.0604, 35.0906, 35.1206, + 35.1507, 35.181, 35.211, 35.2412, 35.2714, 35.3015, 35.3317, 35.3619, + 35.3921, 35.4222, 35.4523, 35.4824, 35.5126, 35.5427, 35.5728, 35.603, + 35.6331, 35.6633, 35.6935, 35.7236, 35.7538, 35.7838, 35.814, 35.8442, + 35.8742, 35.9043, 35.9346, 35.9646, 35.9948, 36.025, 36.0551, 36.0853, + 36.1155, 36.1456, 36.1759, 36.2059, 36.236, 36.2662, 36.2963, 36.3264, + 36.3566, 36.3867, 36.4169, 36.4471, 36.4772, 36.5074, 36.5374, 36.5676, + 37.0338, 37.0634, 37.0929, 37.1222, 37.1519, 37.1813, 37.2107, 37.2403, + 37.2698, 37.2992, 37.3288, 37.3584, 37.3877, 37.4174, 37.447, 37.4764, + 37.5056, 37.5352, 37.5646, 37.5938, 37.6234, 37.6528, 37.6821, 37.7117, + 37.7412, 37.7705, 37.8001, 37.8295, 37.8589, 37.8885, 37.918, 37.9473, + 37.977, 38.0065, 38.0358, 38.0655, 38.095, 38.1244, 38.1541, 38.1836, + 38.213, 38.2422, 38.2718, 38.3012, 38.3305, 38.36, 38.3895, 38.4187, + 38.4484, 38.4778, 38.5071, 38.5367, 38.5662, 38.5955, 38.6251, 38.6546, + 38.6839, 38.7136, 38.7431, 38.7725, 38.8021, 38.8316, 38.861, 38.8907, + 36.9872, 37.0167, 37.046, 37.0752, 37.1047, 37.1341, 37.1633, 37.1928, + 37.2222, 37.2514, 37.2809, 37.3103, 37.3396, 37.3691, 37.3985, 37.4278, + 37.4569, 37.4863, 37.5156, 37.5447, 37.5742, 37.6035, 37.6326, 37.6621, + 37.6914, 37.7206, 37.7501, 37.7794, 37.8086, 37.8381, 37.8674, 37.8966, + 37.9262, 37.9555, 37.9848, 38.0143, 38.0437, 38.0729, 38.1025, 38.1319, + 38.1612, 38.1903, 38.2197, 38.249, 38.2781, 38.3075, 38.3368, 38.366, + 38.3955, 38.4248, 38.4539, 38.4834, 38.5127, 38.5419, 38.5714, 38.6008, + 38.63, 38.6595, 38.6889, 38.7181, 38.7477, 38.777, 38.8063, 38.8358, + 39.0984, 39.1479, 39.1976, 39.2475, 39.297, 39.3468, 39.3967, 39.4463, + 39.4961, 39.546, 39.5957, 39.6455, 39.695, 39.7447, 39.7946, 39.8441, + 39.8939, 39.9438, 39.9934, 40.0431, 40.0931, 40.1427, 40.1925, 40.2425, + 40.2921, 40.342, 40.3915, 40.4412, 40.4911, 40.5407, 40.5904, 40.6403, + 40.6899, 40.7397, 40.7897, 40.8393, 40.8892, 40.9387, 40.9884, 41.0382, + 41.0878, 41.1375, 41.1874, 41.237, 41.2868, 41.3367, 41.3863, 41.4361, + 41.4861, 41.5358, 41.5856, 41.6351, 41.6849, 41.7347, 41.7843, 41.834, + 41.884, 41.9336, 41.9834, 42.0333, 42.083, 42.1328, 42.1823, 42.232, + 38.9699, 39.0188, 39.068, 39.1173, 39.1663, 39.2155, 39.2648, 39.3138, + 39.3631, 39.4124, 39.4615, 39.5108, 39.5597, 39.6089, 39.6581, 39.7071, + 39.7563, 39.8056, 39.8546, 39.9039, 39.9532, 40.0023, 40.0515, 40.1009, + 40.15, 40.1993, 40.2483, 40.2974, 40.3467, 40.3957, 40.4449, 40.4942, + 40.5433, 40.5925, 40.6419, 40.691, 40.7402, 40.7892, 40.8383, 40.8876, + 40.9366, 40.9857, 41.035, 41.0841, 41.1333, 41.1826, 41.2317, 41.2809, + 41.3303, 41.3794, 41.4287, 41.4777, 41.5268, 41.5761, 41.6251, 41.6743, + 41.7237, 41.7727, 41.8219, 41.8713, 41.9204, 41.9697, 42.0186, 42.0677, + 43.4945, 43.5425, 43.5902, 43.6376, 43.6856, 43.7334, 43.7808, 43.8289, + 43.8766, 43.9241, 43.9722, 44.02, 44.0675, 44.1157, 44.1635, 44.2111, + 44.2583, 44.3062, 44.3538, 44.4011, 44.449, 44.4966, 44.544, 44.5919, + 44.6396, 44.6869, 44.735, 44.7826, 44.8301, 44.8781, 44.9258, 44.9733, + 45.0213, 45.0691, 45.1166, 45.1647, 45.2125, 45.26, 45.3081, 45.356, + 45.4035, 45.4508, 45.4987, 45.5462, 45.5936, 45.6415, 45.6891, 45.7364, + 45.7844, 45.832, 45.8794, 45.9274, 45.9751, 46.0225, 46.0705, 46.1183, + 46.1657, 46.2138, 46.2615, 46.309, 46.3571, 46.4049, 46.4525, 46.5006, + 43.4125, 43.4603, 43.5077, 43.5549, 43.6027, 43.6502, 43.6974, 43.7453, + 43.7928, 43.84, 43.8879, 43.9355, 43.9828, 44.0307, 44.0783, 44.1256, + 44.1726, 44.2203, 44.2676, 44.3147, 44.3624, 44.4098, 44.4569, 44.5046, + 44.552, 44.5992, 44.6469, 44.6944, 44.7416, 44.7894, 44.8369, 44.8841, + 44.9319, 44.9795, 45.0267, 45.0746, 45.1222, 45.1694, 45.2173, 45.265, + 45.3123, 45.3593, 45.407, 45.4543, 45.5014, 45.5491, 45.5965, 45.6436, + 45.6913, 45.7387, 45.7859, 45.8336, 45.8811, 45.9283, 45.9761, 46.0236, + 46.0708, 46.1186, 46.1661, 46.2134, 46.2613, 46.3088, 46.3561, 46.404, + 34.7729, 34.8035, 34.8341, 34.8648, 34.8953, 34.9259, 34.9567, 34.9872, + 35.0179, 35.0486, 35.0792, 35.1098, 35.1404, 35.171, 35.2016, 35.2322, + 35.2628, 35.2935, 35.324, 35.3547, 35.3854, 35.416, 35.4466, 35.4774, + 35.508, 35.5387, 35.5692, 35.5998, 35.6305, 35.661, 35.6916, 35.7224, + 35.7529, 35.7836, 35.8143, 35.8449, 35.8755, 35.9061, 35.9367, 35.9674, + 35.9979, 36.0285, 36.0592, 36.0898, 36.1204, 36.1511, 36.1817, 36.2123, + 36.2431, 36.2737, 36.3044, 36.3349, 36.3655, 36.3962, 36.4267, 36.4574, + 36.4881, 36.5186, 36.5493, 36.58, 36.6106, 36.6413, 36.6718, 36.7024, + 34.6995, 34.7297, 34.76, 34.7904, 34.8206, 34.8509, 34.8813, 34.9115, + 34.9418, 34.9722, 35.0025, 35.0328, 35.063, 35.0933, 35.1237, 35.1539, + 35.1842, 35.2146, 35.2448, 35.2751, 35.3055, 35.3357, 35.3661, 35.3965, + 35.4268, 35.4571, 35.4873, 35.5176, 35.548, 35.5782, 35.6085, 35.6389, + 35.6691, 35.6994, 35.7298, 35.7601, 35.7904, 35.8206, 35.8509, 35.8813, + 35.9115, 35.9418, 35.9721, 36.0024, 36.0327, 36.0631, 36.0933, 36.1237, + 36.1541, 36.1843, 36.2147, 36.2449, 36.2752, 36.3056, 36.3358, 36.3661, + 36.3965, 36.4267, 36.457, 36.4874, 36.5177, 36.548, 36.5782, 36.6085, + 37.0829, 37.1127, 37.1423, 37.1717, 37.2015, 37.2312, 37.2607, 37.2905, + 37.3201, 37.3496, 37.3795, 37.4091, 37.4386, 37.4685, 37.4982, 37.5277, + 37.5571, 37.5868, 37.6164, 37.6458, 37.6755, 37.7051, 37.7346, 37.7643, + 37.7939, 37.8234, 37.8531, 37.8827, 37.9122, 37.942, 37.9716, 38.0011, + 38.0309, 38.0606, 38.0901, 38.1199, 38.1496, 38.1791, 38.209, 38.2387, + 38.2682, 38.2976, 38.3273, 38.3569, 38.3863, 38.416, 38.4456, 38.475, + 38.5048, 38.5344, 38.5638, 38.5936, 38.6232, 38.6527, 38.6825, 38.7121, + 38.7416, 38.7714, 38.8011, 38.8306, 38.8604, 38.8901, 38.9196, 38.9494, + 37.0359, 37.0655, 37.095, 37.1243, 37.154, 37.1835, 37.2129, 37.2425, + 37.2721, 37.3014, 37.3311, 37.3607, 37.39, 37.4198, 37.4493, 37.4787, + 37.508, 37.5376, 37.567, 37.5963, 37.6259, 37.6553, 37.6846, 37.7142, + 37.7437, 37.773, 37.8027, 37.8322, 37.8615, 37.8911, 37.9207, 37.95, + 37.9797, 38.0092, 38.0386, 38.0683, 38.0978, 38.1272, 38.1569, 38.1865, + 38.2159, 38.2451, 38.2747, 38.3042, 38.3334, 38.363, 38.3925, 38.4218, + 38.4514, 38.4809, 38.5102, 38.5398, 38.5693, 38.5986, 38.6283, 38.6578, + 38.6872, 38.7168, 38.7464, 38.7757, 38.8054, 38.835, 38.8644, 38.8941, + 39.1594, 39.2093, 39.2593, 39.3095, 39.3594, 39.4094, 39.4597, 39.5096, + 39.5597, 39.61, 39.6599, 39.7101, 39.7599, 39.8099, 39.8601, 39.91, + 39.96, 40.0102, 40.0601, 40.1102, 40.1605, 40.2104, 40.2605, 40.3108, + 40.3608, 40.411, 40.4608, 40.5108, 40.561, 40.6109, 40.661, 40.7112, + 40.7611, 40.8112, 40.8615, 40.9115, 40.9616, 41.0114, 41.0614, 41.1116, + 41.1615, 41.2115, 41.2617, 41.3116, 41.3617, 41.412, 41.4619, 41.512, + 41.5624, 41.6123, 41.6625, 41.7123, 41.7623, 41.8126, 41.8624, 41.9125, + 41.9627, 42.0127, 42.0628, 42.113, 42.163, 42.2131, 42.263, 42.313, + 39.0297, 39.079, 39.1284, 39.1781, 39.2274, 39.2769, 39.3265, 39.3759, + 39.4254, 39.4751, 39.5245, 39.5741, 39.6233, 39.6727, 39.7224, 39.7716, + 39.8211, 39.8708, 39.9201, 39.9696, 40.0193, 40.0686, 40.1182, 40.1679, + 40.2173, 40.2669, 40.3162, 40.3656, 40.4153, 40.4646, 40.514, 40.5637, + 40.6131, 40.6626, 40.7123, 40.7617, 40.8112, 40.8605, 40.9099, 40.9595, + 41.0088, 41.0583, 41.1079, 41.1573, 41.2068, 41.2565, 41.3058, 41.3554, + 41.4051, 41.4545, 41.5041, 41.5534, 41.6028, 41.6524, 41.7017, 41.7512, + 41.8009, 41.8502, 41.8998, 41.9495, 41.9988, 42.0484, 42.0977, 42.1471, + 43.5891, 43.6374, 43.6854, 43.7331, 43.7814, 43.8294, 43.8772, 43.9255, + 43.9736, 44.0214, 44.0698, 44.1179, 44.1657, 44.2141, 44.2623, 44.3101, + 44.3577, 44.4058, 44.4537, 44.5013, 44.5495, 44.5974, 44.6451, 44.6933, + 44.7413, 44.7889, 44.8372, 44.8852, 44.9329, 44.9812, 45.0293, 45.077, + 45.1254, 45.1734, 45.2212, 45.2696, 45.3177, 45.3655, 45.414, 45.4621, + 45.5099, 45.5575, 45.6057, 45.6535, 45.7011, 45.7493, 45.7973, 45.8449, + 45.8931, 45.9411, 45.9888, 46.037, 46.085, 46.1327, 46.1811, 46.2291, + 46.2768, 46.3252, 46.3733, 46.421, 46.4694, 46.5175, 46.5653, 46.6138, + 43.5064, 43.5544, 43.6022, 43.6497, 43.6978, 43.7456, 43.7931, 43.8412, + 43.889, 43.9366, 43.9847, 44.0326, 44.0802, 44.1284, 44.1763, 44.2239, + 44.2712, 44.3191, 44.3668, 44.4141, 44.4621, 44.5098, 44.5572, 44.6052, + 44.6529, 44.7004, 44.7484, 44.7962, 44.8436, 44.8918, 44.9395, 44.987, + 45.0352, 45.083, 45.1305, 45.1787, 45.2266, 45.2742, 45.3223, 45.3703, + 45.4179, 45.4652, 45.5131, 45.5608, 45.6081, 45.6561, 45.7038, 45.7512, + 45.7992, 45.8469, 45.8944, 45.9424, 45.9902, 46.0376, 46.0857, 46.1335, + 46.181, 46.2292, 46.277, 46.3245, 46.3727, 46.4206, 46.4682, 46.5164, +}; + +constexpr HWY_INLINE_VAR int kTiledFlags = 0; + +void TestAttentionMultipleTokens() { + int qkv_dim = 64; + int kv_seq_len = 64; + int num_kv_heads = 2; + int num_heads = 4; + int num_tokens = 2; + int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 + // will have 64 tokens to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, + num_heads, num_tokens, last_pos, att_cap, layer_idx, + layers_total, qbatch_size, attention_impl); + test_env.SetupWeights(); + FillMatPtrT(test_env.activations->attention.pre_att_rms_out); + FillMatPtrT(test_env.activations->attention.q); + FillMatPtrT(test_env.activations->attention.att_out); + FillMatPtrT(test_env.activations->attention.softmax_max); + FillMatPtrT(test_env.activations->attention.softmax_d); + + TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, + test_env.env, kTiledFlags); + + std::cerr << "q after TiledAttention\n"; + PrintMatPtr(test_env.activations->attention.q); + for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { + EXPECT_TRUE(hwy::CompareArraySimilar( + AttentionMultipleTokensAttentionGoldens.data() + + i * test_env.activations->attention.att_out.Cols(), + test_env.activations->attention.att_out.Row(i), + test_env.activations->attention.att_out.Cols(), 1e-3, + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_out mismatch for query: " << i; + } +} + +void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() { + int qkv_dim = 64; + int kv_seq_len = 34; + int num_kv_heads = 2; + int num_heads = 4; + int num_tokens = 2; + int last_pos = 31; // so in the tbatch token 0 will have 63 and token 1 + // will have 64 tokens to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + int attention_window_size = 32; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, attention_window_size, + num_kv_heads, num_heads, num_tokens, last_pos, + att_cap, layer_idx, layers_total, qbatch_size, + attention_impl); + test_env.SetupWeights(); + FillMatPtrT(test_env.activations->attention.pre_att_rms_out); + FillMatPtrT(test_env.activations->attention.q); + FillMatPtrT(test_env.activations->attention.att_out); + FillMatPtrT(test_env.activations->attention.softmax_max); + FillMatPtrT(test_env.activations->attention.softmax_d); + + TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, + test_env.env, kTiledFlags); + + std::vector att_out_golden_test_local = { + 39.3051, 39.3556, 39.4062, 39.4571, 39.5075, 39.5582, 39.6091, 39.6596, + 39.7103, 39.7612, 39.8118, 39.8626, 39.913, 39.9636, 40.0144, 40.0649, + 40.1155, 40.1664, 40.2169, 40.2676, 40.3185, 40.369, 40.4198, 40.4707, + 40.5213, 40.572, 40.6225, 40.6731, 40.724, 40.7744, 40.8251, 40.876, + 40.9265, 40.9772, 41.0281, 41.0787, 41.1295, 41.1799, 41.2305, 41.2813, + 41.3318, 41.3824, 41.4333, 41.4838, 41.5345, 41.5854, 41.6359, 41.6867, + 41.7376, 41.7882, 41.839, 41.8894, 41.94, 41.9908, 42.0413, 42.092, + 42.1429, 42.1934, 42.2441, 42.295, 42.3456, 42.3964, 42.4468, 42.4974, + 39.1614, 39.2113, 39.2613, 39.3114, 39.3613, 39.4113, 39.4616, 39.5115, + 39.5616, 39.6118, 39.6618, 39.7119, 39.7617, 39.8117, 39.8618, 39.9117, + 39.9617, 40.0119, 40.0618, 40.1118, 40.1621, 40.212, 40.2621, 40.3124, + 40.3623, 40.4125, 40.4623, 40.5123, 40.5625, 40.6123, 40.6624, 40.7126, + 40.7625, 40.8126, 40.8629, 40.9128, 40.9629, 41.0127, 41.0627, 41.1129, + 41.1627, 41.2127, 41.2629, 41.3128, 41.3629, 41.4131, 41.463, 41.5131, + 41.5634, 41.6134, 41.6635, 41.7133, 41.7634, 41.8135, 41.8634, 41.9134, + 41.9637, 42.0135, 42.0636, 42.1139, 42.1638, 42.214, 42.2637, 42.3137, + 43.8459, 43.895, 43.9437, 43.9921, 44.0411, 44.0898, 44.1383, 44.1874, + 44.2361, 44.2846, 44.3337, 44.3825, 44.4311, 44.4802, 44.529, 44.5776, + 44.6258, 44.6747, 44.7233, 44.7716, 44.8205, 44.8692, 44.9175, 44.9665, + 45.0151, 45.0635, 45.1125, 45.1612, 45.2096, 45.2586, 45.3074, 45.3558, + 45.4049, 45.4537, 45.5021, 45.5513, 45.6001, 45.6486, 45.6977, 45.7466, + 45.7951, 45.8434, 45.8923, 45.9409, 45.9891, 46.0381, 46.0867, 46.135, + 46.184, 46.2327, 46.281, 46.33, 46.3787, 46.4271, 46.4762, 46.5249, + 46.5733, 46.6224, 46.6712, 46.7197, 46.7688, 46.8176, 46.8661, 46.9153, + 43.7538, 43.8026, 43.851, 43.8992, 43.948, 43.9964, 44.0446, 44.0934, + 44.142, 44.1902, 44.239, 44.2876, 44.3358, 44.3847, 44.4333, 44.4816, + 44.5296, 44.5782, 44.6266, 44.6746, 44.7232, 44.7716, 44.8197, 44.8684, + 44.9168, 44.9649, 45.0136, 45.0621, 45.1102, 45.159, 45.2075, 45.2557, + 45.3045, 45.353, 45.4012, 45.4501, 45.4986, 45.5469, 45.5958, 45.6444, + 45.6927, 45.7406, 45.7893, 45.8376, 45.8856, 45.9343, 45.9827, 46.0307, + 46.0794, 46.1278, 46.1759, 46.2247, 46.2731, 46.3213, 46.3701, 46.4185, + 46.4667, 46.5155, 46.564, 46.6123, 46.6611, 46.7097, 46.7579, 46.8068, + 48.7531, 48.8438, 48.9348, 49.0262, 49.1169, 49.208, 49.2995, 49.3903, + 49.4815, 49.573, 49.6639, 49.7552, 49.8458, 49.9368, 50.0281, 50.1188, + 50.2099, 50.3013, 50.3921, 50.4832, 50.5747, 50.6656, 50.7568, 50.8484, + 50.9393, 51.0306, 51.1213, 51.2123, 51.3037, 51.3944, 51.4855, 51.577, + 51.6678, 51.759, 51.8505, 51.9414, 52.0327, 52.1233, 52.2143, 52.3056, + 52.3963, 52.4874, 52.5788, 52.6696, 52.7607, 52.8522, 52.9431, 53.0343, + 53.1259, 53.2168, 53.3081, 53.3988, 53.4898, 53.5812, 53.6719, 53.763, + 53.8545, 53.9453, 54.0365, 54.128, 54.2189, 54.3102, 54.4008, 54.4918, + 48.4943, 48.5838, 48.6737, 48.7639, 48.8535, 48.9435, 49.0338, 49.1235, + 49.2135, 49.3039, 49.3937, 49.4838, 49.5732, 49.6631, 49.7533, 49.8428, + 49.9328, 50.023, 50.1127, 50.2027, 50.293, 50.3827, 50.4728, 50.5632, + 50.653, 50.7432, 50.8327, 50.9226, 51.0128, 51.1024, 51.1924, 51.2827, + 51.3724, 51.4624, 51.5528, 51.6425, 51.7327, 51.8221, 51.912, 52.0022, + 52.0917, 52.1817, 52.2719, 52.3616, 52.4516, 52.5419, 52.6316, 52.7217, + 52.8121, 52.9019, 52.9921, 53.0816, 53.1715, 53.2617, 53.3513, 53.4413, + 53.5316, 53.6212, 53.7113, 53.8017, 53.8914, 53.9815, 54.071, 54.1609, + 57.7208, 57.8084, 57.8954, 57.9818, 58.0694, 58.1564, 58.2429, 58.3306, + 58.4177, 58.5043, 58.5921, 58.6793, 58.7659, 58.8537, 58.941, 59.0277, + 59.1137, 59.2011, 59.2878, 59.374, 59.4614, 59.5482, 59.6345, 59.722, + 59.8089, 59.8952, 59.9827, 60.0697, 60.1561, 60.2437, 60.3308, 60.4172, + 60.505, 60.5921, 60.6786, 60.7664, 60.8536, 60.9402, 61.0281, 61.1153, + 61.202, 61.2881, 61.3755, 61.4622, 61.5483, 61.6358, 61.7226, 61.8088, + 61.8963, 61.9832, 62.0695, 62.1571, 62.244, 62.3304, 62.4181, 62.5051, + 62.5916, 62.6793, 62.7664, 62.853, 62.9407, 63.0279, 63.1146, 63.2024, + 57.5554, 57.6426, 57.729, 57.815, 57.9021, 57.9887, 58.0747, 58.162, + 58.2486, 58.3347, 58.422, 58.5087, 58.5949, 58.6823, 58.7691, 58.8553, + 58.9409, 59.0278, 59.114, 59.1997, 59.2867, 59.373, 59.4588, 59.5458, + 59.6323, 59.7181, 59.8052, 59.8917, 59.9776, 60.0648, 60.1514, 60.2374, + 60.3246, 60.4113, 60.4974, 60.5847, 60.6714, 60.7576, 60.8449, 60.9317, + 61.018, 61.1036, 61.1905, 61.2767, 61.3624, 61.4494, 61.5357, 61.6215, + 61.7085, 61.7949, 61.8808, 61.9679, 62.0544, 62.1403, 62.2275, 62.3141, + 62.4001, 62.4873, 62.574, 62.66, 62.7474, 62.8341, 62.9202, 63.0076, + 39.3678, 39.4186, 39.4696, 39.5207, 39.5715, 39.6225, 39.6737, 39.7246, + 39.7756, 39.8268, 39.8777, 39.9288, 39.9796, 40.0305, 40.0816, 40.1324, + 40.1834, 40.2346, 40.2854, 40.3364, 40.3876, 40.4385, 40.4896, 40.5408, + 40.5917, 40.6428, 40.6936, 40.7446, 40.7957, 40.8466, 40.8975, 40.9487, + 40.9996, 41.0506, 41.1019, 41.1528, 41.2038, 41.2546, 41.3055, 41.3567, + 41.4075, 41.4584, 41.5096, 41.5605, 41.6115, 41.6627, 41.7136, 41.7646, + 41.8159, 41.8668, 41.9179, 41.9687, 42.0196, 42.0708, 42.1216, 42.1726, + 42.2238, 42.2746, 42.3256, 42.3769, 42.4278, 42.4789, 42.5296, 42.5806, + 39.2228, 39.2729, 39.3232, 39.3737, 39.4239, 39.4743, 39.5248, 39.575, + 39.6254, 39.676, 39.7263, 39.7767, 39.8268, 39.8771, 39.9276, 39.9778, + 40.0281, 40.0786, 40.1288, 40.1792, 40.2298, 40.28, 40.3304, 40.381, + 40.4313, 40.4818, 40.5319, 40.5822, 40.6327, 40.6829, 40.7333, 40.7838, + 40.834, 40.8844, 40.935, 40.9853, 41.0357, 41.0858, 41.1361, 41.1866, + 41.2368, 41.2871, 41.3376, 41.3878, 41.4382, 41.4888, 41.539, 41.5894, + 41.64, 41.6903, 41.7408, 41.7909, 41.8412, 41.8917, 41.9419, 41.9922, + 42.0428, 42.093, 42.1434, 42.194, 42.2442, 42.2947, 42.3448, 42.3951, + 43.9435, 43.9928, 44.0418, 44.0905, 44.1399, 44.1889, 44.2376, 44.287, + 44.3361, 44.3849, 44.4343, 44.4834, 44.5322, 44.5817, 44.6308, 44.6797, + 44.7283, 44.7774, 44.8263, 44.8749, 44.9241, 44.9731, 45.0217, 45.071, + 45.12, 45.1686, 45.2179, 45.2669, 45.3156, 45.365, 45.414, 45.4628, + 45.5122, 45.5613, 45.61, 45.6595, 45.7086, 45.7574, 45.8068, 45.856, + 45.9048, 45.9534, 46.0026, 46.0515, 46.1001, 46.1493, 46.1982, 46.2469, + 46.2961, 46.3451, 46.3938, 46.4431, 46.4921, 46.5408, 46.5901, 46.6392, + 46.6879, 46.7373, 46.7864, 46.8352, 46.8846, 46.9337, 46.9825, 47.032, + 43.8506, 43.8996, 43.9484, 43.9968, 44.0459, 44.0947, 44.1432, 44.1923, + 44.2411, 44.2896, 44.3388, 44.3876, 44.4362, 44.4854, 44.5343, 44.5829, + 44.6312, 44.6801, 44.7287, 44.7771, 44.826, 44.8747, 44.9231, 44.9721, + 45.0208, 45.0692, 45.1182, 45.167, 45.2154, 45.2645, 45.3133, 45.3617, + 45.4109, 45.4597, 45.5082, 45.5574, 45.6062, 45.6548, 45.704, 45.7529, + 45.8015, 45.8498, 45.8987, 45.9473, 45.9957, 46.0446, 46.0933, 46.1416, + 46.1906, 46.2394, 46.2878, 46.3368, 46.3856, 46.434, 46.4831, 46.5319, + 46.5803, 46.6295, 46.6783, 46.7268, 46.776, 46.8248, 46.8734, 46.9226, + 48.8777, 48.969, 49.0607, 49.1527, 49.2441, 49.3358, 49.4279, 49.5194, + 49.6112, 49.7034, 49.7949, 49.8868, 49.9781, 50.0697, 50.1617, 50.2531, + 50.3448, 50.4368, 50.5283, 50.62, 50.7122, 50.8037, 50.8956, 50.9878, + 51.0794, 51.1713, 51.2626, 51.3543, 51.4463, 51.5377, 51.6294, 51.7215, + 51.813, 51.9048, 51.997, 52.0885, 52.1805, 52.2717, 52.3633, 52.4553, + 52.5467, 52.6384, 52.7305, 52.8219, 52.9137, 53.0058, 53.0973, 53.1892, + 53.2814, 53.373, 53.4649, 53.5562, 53.6479, 53.7399, 53.8313, 53.923, + 54.0152, 54.1066, 54.1984, 54.2906, 54.3821, 54.4741, 54.5653, 54.6569, + 48.6164, 48.7066, 48.7971, 48.888, 48.9782, 49.0688, 49.1597, 49.25, + 49.3407, 49.4317, 49.5221, 49.6129, 49.703, 49.7934, 49.8843, 49.9745, + 50.065, 50.1559, 50.2462, 50.3368, 50.4278, 50.5181, 50.6089, 50.6999, + 50.7903, 50.8811, 50.9713, 51.0618, 51.1527, 51.2429, 51.3335, 51.4244, + 51.5147, 51.6054, 51.6964, 51.7868, 51.8776, 51.9677, 52.0581, 52.149, + 52.2392, 52.3297, 52.4206, 52.5109, 52.6015, 52.6925, 52.7828, 52.8736, + 52.9646, 53.055, 53.1458, 53.236, 53.3265, 53.4174, 53.5076, 53.5982, + 53.6891, 53.7794, 53.8701, 53.9611, 54.0515, 54.1423, 54.2324, 54.3228, + 57.914, 58.0021, 58.0897, 58.1767, 58.265, 58.3526, 58.4397, 58.528, + 58.6157, 58.7028, 58.7912, 58.879, 58.9662, 59.0547, 59.1426, 59.2299, + 59.3165, 59.4045, 59.4918, 59.5786, 59.6666, 59.754, 59.8408, 59.9289, + 60.0165, 60.1033, 60.1915, 60.2791, 60.3661, 60.4544, 60.542, 60.629, + 60.7174, 60.8051, 60.8922, 60.9806, 61.0684, 61.1556, 61.2441, 61.332, + 61.4193, 61.5059, 61.5939, 61.6812, 61.768, 61.856, 61.9434, 62.0302, + 62.1183, 62.2059, 62.2927, 62.3809, 62.4685, 62.5555, 62.6437, 62.7314, + 62.8184, 62.9068, 62.9945, 63.0816, 63.17, 63.2578, 63.345, 63.4335, + 57.7471, 57.8348, 57.9219, 58.0084, 58.0962, 58.1834, 58.27, 58.3578, + 58.4451, 58.5317, 58.6197, 58.707, 58.7937, 58.8817, 58.9691, 59.0559, + 59.1421, 59.2296, 59.3165, 59.4028, 59.4903, 59.5773, 59.6636, 59.7512, + 59.8383, 59.9247, 60.0124, 60.0995, 60.186, 60.2738, 60.361, 60.4476, + 60.5354, 60.6227, 60.7093, 60.7973, 60.8846, 60.9713, 61.0593, 61.1467, + 61.2335, 61.3197, 61.4072, 61.4941, 61.5804, 61.6679, 61.7549, 61.8412, + 61.9289, 62.0159, 62.1023, 62.19, 62.2772, 62.3636, 62.4514, 62.5386, + 62.6252, 62.7131, 62.8003, 62.887, 62.9749, 63.0622, 63.1489, 63.237}; + for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { + EXPECT_TRUE(hwy::CompareArraySimilar( + att_out_golden_test_local.data() + + i * test_env.activations->attention.att_out.Cols(), + test_env.activations->attention.att_out.Row(i), + test_env.activations->attention.att_out.Cols(), 1e-3, + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_out mismatch for query: " << i; + } +} + +void TestAttentionMultipleTokensBF16() { + int qkv_dim = 64; + int kv_seq_len = 64; + int num_kv_heads = 2; + int num_heads = 4; + int num_tokens = 2; + int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 + // will have 64 tokens to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, + num_heads, num_tokens, last_pos, att_cap, layer_idx, + layers_total, qbatch_size, attention_impl); + test_env.SetupWeights(); + FillMatPtrT(test_env.activations->attention.pre_att_rms_out); + FillMatPtrT(test_env.activations->attention.q); + FillMatPtrT(test_env.activations->attention.att_out); + FillMatPtrT(test_env.activations->attention.softmax_max); + FillMatPtrT(test_env.activations->attention.softmax_d); + + TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, + test_env.env, kTiledFlags); + for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { + EXPECT_TRUE(hwy::CompareArraySimilar( + AttentionMultipleTokensAttentionGoldens.data() + + i * test_env.activations->attention.att_out.Cols(), + test_env.activations->attention.att_out.Row(i), + test_env.activations->attention.att_out.Cols(), 1e-1, + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_out mismatch for query: " << i; + } +} + +void TestAttentionMultipleTokensInt8() { + int qkv_dim = 64; + int kv_seq_len = 64; + int num_kv_heads = 2; + int num_heads = 4; + int num_tokens = 2; + int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 + // will have 64 tokens to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, + num_heads, num_tokens, last_pos, att_cap, layer_idx, + layers_total, qbatch_size, attention_impl, + Type::kInt8); + test_env.SetupWeights(); + FillMatPtrT(test_env.activations->attention.pre_att_rms_out); + FillMatPtrT(test_env.activations->attention.q); + FillMatPtrT(test_env.activations->attention.att_out); + FillMatPtrT(test_env.activations->attention.softmax_max); + FillMatPtrT(test_env.activations->attention.softmax_d); + + TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, + test_env.env, kTiledFlags); + for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { + EXPECT_TRUE(hwy::CompareArraySimilar( + AttentionMultipleTokensAttentionGoldens.data() + + i * test_env.activations->attention.att_out.Cols(), + test_env.activations->attention.att_out.Row(i), + test_env.activations->attention.att_out.Cols(), 1e-1, + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_out mismatch for query: " << i; + } +} // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE diff --git a/util/mat.h b/util/mat.h index a4555794..3b74faad 100644 --- a/util/mat.h +++ b/util/mat.h @@ -86,7 +86,9 @@ class MatPtr : public IFields { // Only for use by ctor, `AllocateFor` and 'loading' memory-mapped tensors. void SetPtr(void* ptr, size_t stride) { - HWY_ASSERT(stride >= Cols()); + if (stride < Cols()) { + HWY_ABORT("%s: stride %zu < cols %zu\n", Name(), stride, Cols()); + } ptr_ = ptr; stride_ = static_cast(stride);