Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ cc_test(
":kv_transcoding",
":mat",
":matmul",
":ops",
":test_util",
":threading_context",
":weights",
Expand Down
21 changes: 9 additions & 12 deletions evals/gemma_batch_bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,16 @@ std::vector<std::string> GenerateInputs() {
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
s_env->SetMaxGeneratedTokens(12);
const std::vector<std::string> inputs = GenerateInputs();
const AttentionImpl modes[] = {AttentionImpl::kOld, AttentionImpl::kFlash};
for (const AttentionImpl mode : modes) {
// Run multiple times so that auto-tuning is closer to complete.
fprintf(stderr, "Testing mode %s\n", GetAttentionImplName(mode).c_str());
for (size_t rep = 0; rep < 4; ++rep) {
std::vector<std::string> responses = BatchGemmaReply(inputs, mode);
for (size_t i = 0;
i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); ++i) {
fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
responses[i].c_str());
}
PROFILER_PRINT_RESULTS();
// Run multiple times so that auto-tuning is closer to complete.
for (size_t rep = 0; rep < 4; ++rep) {
std::vector<std::string> responses =
BatchGemmaReply(inputs, AttentionImpl::kFlash);
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
++i) {
fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
responses[i].c_str());
}
PROFILER_PRINT_RESULTS();
}
}

Expand Down
2 changes: 1 addition & 1 deletion evals/wheat_from_chaff_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ GemmaEnv* GemmaTest::s_env = nullptr;
// Tests whether Gemma can find the right answer in varying levels of
// background information, ranging from the bare facts to outright distraction.
TEST_F(GemmaTest, WheatFromChaff) {
const AttentionImpl modes[] = {AttentionImpl::kOld, AttentionImpl::kFlash};
const AttentionImpl modes[] = {AttentionImpl::kFlash};
fprintf(stderr, "Warmup, mode %s\n", GetAttentionImplName(modes[0]).c_str());
auto prompt = BuildPrompt({"quark_1.txt", "holiday_story.txt"}, kQuestions);
auto response = GemmaReply(prompt, modes[0]);
Expand Down
36 changes: 4 additions & 32 deletions gemma/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ struct AttentionActivations {
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim,
allocator)),
q_T(MatFactory("q_T", layer_config.qkv_dim,
config.vocab_size == 0
? batch_size * layer_config.heads * 3
: batch_size * layer_config.heads,
allocator)),
vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
vit_K_T(MatFactory(
"K2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
Expand All @@ -88,13 +83,6 @@ struct AttentionActivations {
allocator, MatPadding::kPacked)),
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
config.model_dim, allocator)),
// att is only valid for AttentionImpl::kOld.
att(MatFactory(
"att", batch_size,
layer_config.heads *
(runtime_config.attention_impl == AttentionImpl::kOld ? seq_len
: 1),
allocator)),
att_out(MatFactory("att_out", batch_size,
layer_config.heads * layer_config.qkv_dim,
allocator)),
Expand Down Expand Up @@ -133,20 +121,17 @@ struct AttentionActivations {
// fill them in each MatMul call.
q.AllocateAndAttachRowPtrs(row_ptrs);
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
}

void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim!

vit_Q.OverrideRows(batch_size);
// vit_K_T and vit_V_T stay seq_len!

pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_out_reps.OverrideRows(batch_size * rep_factor);
// There is no override for [split_]flash_params, because we reserved an
Expand All @@ -170,14 +155,12 @@ struct AttentionActivations {
std::vector<Tile148Params> split_flash_params;
MatStorageT<float> q; // query
MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.

MatStorageT<float> vit_Q;
MatStorageT<KV_t> vit_K_T;
MatStorageT<KV_t> vit_V_T;

MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
MatStorageT<float> att_out_reps; // attention output for each thread.
MatStorageT<float> softmax_max; // see OnlineSoftmaxState
Expand Down Expand Up @@ -218,12 +201,10 @@ struct AttentionActivationsPtrs {
activations.split_flash_params) {
q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T;
vit_Q = activations.vit_Q;
vit_K_T = activations.vit_K_T;
vit_V_T = activations.vit_V_T;
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
att_out_reps = activations.att_out_reps;
softmax_max = activations.softmax_max;
Expand All @@ -236,13 +217,11 @@ struct AttentionActivationsPtrs {
void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim!

vit_Q.OverrideRows(batch_size);
// vit_K_T and vit_V_T stay seq_len!

pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
softmax_max.OverrideRows(batch_size);
softmax_d.OverrideRows(batch_size);
Expand All @@ -268,29 +247,22 @@ struct AttentionActivationsPtrs {
MatPtrT<float> q;
// Query matrix of size batch_size x (q_heads * qkv_dim).
MatPtrT<BF16> q_bf;
// Transposed query matrix for faster Q*K^T.
MatPtrT<BF16> q_T;

MatPtrT<float> vit_Q;
MatPtrT<KV_t> vit_K_T;
MatPtrT<KV_t> vit_V_T;

// Output of RMSNorm before attention, size batch_size x model_dim.
MatPtrT<float> pre_att_rms_out;
// Attention scores computed from Q*K^T, size batch_size x (q_heads *
// seq_len).
MatPtrT<float> att;
// Attention output computed from att * V, size batch_size x (q_heads *
// qkv_dim).
MatPtrT<float> att_out;
MatPtrT<float> att_out_reps;
// The maximum logit value encountered when computing att_out from att,
// size batch_size x q_heads . See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
// The maximum logit value encountered when computing att_out, shape
// batch_size x q_heads . See OnlineSoftmaxState for details.
MatPtrT<float> softmax_max;
// The sum of scaled exponentials when computing att_out from att,
// size batch_size x q_heads . See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
// The sum of scaled exponentials when computing att_out, shape
// batch_size x q_heads . See OnlineSoftmaxState for details.
MatPtrT<float> softmax_d;
// Accumulation of attention outputs over heads, size batch_size x
// model_dim.
Expand Down
Loading
Loading