Grouped Query Attention, and tokenomics
Exploring, validating, and delving deeper into the origins of AI inference-time common-knowledge.
My blogs are usually extremely freaking long. I’m going to consciously try to be more succinct here. Let me know how I do!
Attention is a word every AI enjoyer knows, and has an opinion on. It’s a great sequence mixing mechanism, and extraordinarily great at capturing long-range dependencies.
But all the homies do a double take on the mechanism since it’s horribly inefficient on today’s GPUs. We’ve split an attention-leveraging model’s response process into two phases:
Prefill: The phase where we process all of the input tokens and generate the first response token.
Decode: The phase where we generate each subsequent token, one at a time.
We’ll also introduce the idea of the KV-cache, a great speedup mechanism, but also the root of some critical problems at decode time, alongside some irking storage nags.
Let’s convince ourselves of some common truths & results the industry heavily leverages.
The standard multi-head attention (MHA) mechanism is horribly memory bound in the decode process due to the addition of KV-caching, but compute bound in prefill.
Why does the KV cache help, then?
KV-caching for MHA takes up more storage than we need. Grouped Query Attention (GQA) helps with this.
Multi-Head Attention (MHA), and kv-caching
Motivating Prefill
Attention, in general, can be thought of as the following sequence of operations (decode-related changes omitted):
import math
import torch
import torch.nn.functional as F

# [B, n, T, h] x [B, n, h, T] -> [B, n, T, T]
attn_logits = (queries @ keys_transposed) / math.sqrt(self.latent_attention_dim)
# causal mask: -inf above the diagonal, 0 elsewhere, so a token cannot attend to later tokens
mask = torch.full((T_q, T_k), float('-inf')).triu(1)
# [B, n, T, T] -> [B, n, T, T] (no change in shape)
attn_probs = F.softmax(attn_logits + mask, dim=-1)
# [B, n, T, T] x [B, n, T, h] -> [B, n, T, h]
attn_scores = attn_probs @ values

It consists of three operations:
QK^T + mask, computing the matrix product of the queries and transposed keys. You may hear people refer to this as logit computation.
Softmax, which transforms the logit matrix into a probability distribution A (one distribution per query row).
AV, computing the final attention output (called attn_scores in the code above).
The majority of our floating point operations (FLOPs) come from the QK^T and AV products, so we’ll neglect the FLOPs from the softmax.
Now, let’s look at the FLOPs for each of these computations. Let’s define some terms. Let the following variables come to life!
n: The number of attention heads
T: The sequence length, i.e., the number of tokens we attend over
B: The batch size we are processing with
h: The attention head dimension, or the latent space dimension in which we are doing our computations
Both QK^T and AV do somewhat similar matrix multiplications. To build up to the expression, let’s assume we do a general matrix multiplication (GEMM) with matrix shapes A (M, K) x B (K, N) = C (M, N). Typically, we just abbreviate this to GEMM(M, N, K).
If we assume that we need to do an add and multiply per partial product, we do 2 floating-point operations. We have K partial products per output result, and M * N output results. So, our total FLOPs in this GEMM is 2 x M x N x K = 2MNK.
Similarly, we load A, with shape (M, K), and B, with shape (K, N), and write C, with shape (M, N). So, our total memory traffic is (M x K + K x N + M x N) x data_width bytes. Data width just refers to how many bytes each matrix element takes, which is set by the number format: 4 for FP32, 2 for FP16, etc.
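As a quick sanity check of the GEMM accounting above, here is a minimal sketch in plain Python. The helper names gemm_flops and gemm_bytes are made up for illustration and don’t come from any library.

```python
def gemm_flops(m: int, n: int, k: int) -> int:
    """FLOPs for A (m, k) x B (k, n) = C (m, n): one multiply and one add per partial product."""
    return 2 * m * n * k


def gemm_bytes(m: int, n: int, k: int, data_width: int = 2) -> int:
    """Bytes moved: load A and B, write C, at data_width bytes per element."""
    return (m * k + k * n + m * n) * data_width


# Example: a square 1024 x 1024 x 1024 GEMM in a 2-byte format such as FP16.
flops = gemm_flops(1024, 1024, 1024)
moved = gemm_bytes(1024, 1024, 1024, data_width=2)
print(f"FLOPs: {flops}, bytes: {moved}, FLOPs per byte: {flops / moved:.1f}")
```

Dividing the first number by the second gives the arithmetic intensity, which is the quantity we care about for the rest of this post.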
Now, we can express our QK^T and AV products as GEMMs. QK^T is effectively GEMM(T, T, h), and AV is equivalent to GEMM(T, h, T).
Since a GEMM costs 2MNK FLOPs, both products cost 2T^2h per head, per batch element. So our final FLOP count is 4BnT^2h.
In practice, we leverage fused GPU kernels for computing attention, so we avoid having to write the full TxT attention matrix. By doing so, our memory cost is limited to loading the inputs - queries, keys, values - and then writing the final attention output.
That’s a total of: (queries) BnTh + (keys) BnTh + (values) BnTh + (attention output) BnTh = 4BnTh elements, or 4BnTh * data_width bytes.
So, our arithmetic intensity is 4BnT^2h / (4BnTh * DW) = T / DW, where DW is the data width.
In practice, attention is computed in BF16 at pretrain time, and FP16 at post-train / inference time. So, DW = 2.
That means our intensity scales with the sequence length processed! For sequence lengths greater than 1024, we are safely compute bound. In practice, hardware often has much lower ridge points (the arithmetic intensity at which we transition from memory bound to compute bound), generally under 100. With DW = 2, T / DW exceeds 100 once T is above 200, so sequence lengths of 256, 512, and so on can be thought of as compute bound.
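To see the T / DW scaling in numbers, here is a small sketch that plugs the FLOP and byte counts from above into plain Python. The function name prefill_attention_intensity and the example shapes are my own picks for illustration.

```python
def prefill_attention_intensity(B: int, n: int, T: int, h: int, data_width: int = 2) -> float:
    """Arithmetic intensity (FLOPs per byte) of fused prefill attention."""
    flops = 4 * B * n * T * T * h                 # QK^T and AV, 2T^2h each per head
    bytes_moved = 4 * B * n * T * h * data_width  # load Q, K, V and write the output
    return flops / bytes_moved


# B, n and h cancel out, so only T and the data width matter.
for T in (128, 256, 1024, 8192):
    print(T, prefill_attention_intensity(B=1, n=16, T=T, h=128))
```

Each printed intensity is just T / 2, which is what the derivation predicts for a 2-byte format.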
KV-Caching and Generation
The current arithmetic intensity is great for pre-training! We don’t need to do much else, and in practice we don’t do many optimizations for attention at pretrain, sans fusing the attention computation to avoid writing the matrix to HBM.
But generation gives us more room to optimize. Unlike pre-training, we don’t do next token prediction once and then sample new data. Instead, we keep on doing next token prediction until we hit the end, typically signalled by the EOS token. That leaves a lot of overlap within the attention mechanism. Here’s the general generation flow:
(B, S0, ) --> gen. token (transformer) --> (B, S0 + 1, ) --> ... --> gen. token --> (B, S0 + k, )
Specifically, the overlap comes from the prior computation of queries, keys, and values. In generation, the weights are static between token generations. That means that if you provide a batch of token ids of shape (B, T, ) twice, each layer generates the same output both times (up to numerical noise). And because attention is causally masked, appending a new token doesn’t change what any layer computes for the earlier positions. Why this matters for us: at each timestep, we can be sure that the keys, queries, and values of all the previous tokens haven’t changed.
What matters for this discussion is the attention layer. Since its inputs for the earlier positions stay the same, so do their keys, queries, and values. So do we really need to recompute the keys, queries, and values for each of the previous tokens each and every time?
No! Let’s ask the question for each of the 3 attention entities:
Do we need to recompute queries: No! In fact, we don’t even need the queries for tokens once we’re done computing the next token. Attention itself can be thought of answering the following question at each timestep: What are the most relevant keys for a given token’s query head? And we don’t need to answer this for previous tokens since we already have the next token in mind!
So, we can get away with computing only the final token’s queries. No need to keep the previous tokens’ queries around.
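Do we need the previous keys: Yes! At each step, the new query is compared against every previous key to find the most relevant information. But those keys don’t change between steps, so we can store them instead of recomputing them.
Do we need the previous values: Yes! Once we find the most relevant keys, we need the associated values to compute our final output. They don’t change either, so we can store them too.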
So, we need to preserve our keys and values, but can discard the queries. This leads to the idea of maintaining a key-value (KV) cache, which would store the results after computing them, removing computation we just identified as redundant. Doing this will reduce our runtime, but lead to the decode process becoming memory bound — loading all of our previous keys and values will now be pretty slow. But still faster than re-doing the computation each and every time!
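Here is a toy sketch of the idea in plain Python: one head, batch size 1, no real model. The project function is a made-up stand-in for the learned Q/K/V projections, and KVCache is a hypothetical helper, not an API from any framework. The point is only that attending with cached keys and values gives the same answer as recomputing them from scratch at every step.

```python
import math


def attend(query, keys, values):
    """Single-head attention for one query over lists of key and value vectors."""
    scale = math.sqrt(len(query))
    logits = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(logits)  # subtract the max for a numerically stable softmax
    weights = [math.exp(logit - peak) for logit in logits]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


def project(token_id):
    """Toy stand-in for the learned Q/K/V projections (an assumption, not a real layer)."""
    q = [math.sin(token_id + i) for i in range(4)]
    k = [math.cos(token_id + i) for i in range(4)]
    v = [math.sin(2 * token_id + i) for i in range(4)]
    return q, k, v


class KVCache:
    """Keeps every key and value seen so far; queries are used once and dropped."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, token_id):
        q, k, v = project(token_id)  # only the newest token is projected
        self.keys.append(k)
        self.values.append(v)
        return attend(q, self.keys, self.values)


def attend_without_cache(token_ids):
    """Recompute every key and value, then attend with the last token's query."""
    projected = [project(t) for t in token_ids]
    q_last = projected[-1][0]
    return attend(q_last, [p[1] for p in projected], [p[2] for p in projected])


tokens = [3, 1, 4, 1, 5]
cache = KVCache()
for t in tokens:
    cached_out = cache.step(t)

print(cached_out)
print(attend_without_cache(tokens))
```

The cached path projects one token per step, while the uncached path projects the whole sequence every time. The price is that the cache grows by one key and one value per token, per head, per layer.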
Estimating Decode FLOPs for MHA
Let’s redo the same kind of calculation we did in prefill for our decode problem.
Now, our matrices have changed a bit. Instead of queries having a shape of [B, n, T, h], we now only have [B, n, 1, h], since we compute the last timestep only. So, our QK^T FLOPs will drop due to fewer elements in the query matrix, and so will the bytes we move.
Because the M dimension of the GEMM is now 1, the attention probability matrix A has a shape of [B, n, 1, T], and the output of the AV computation is [B, n, 1, h].
I’ll go through the calculations a little faster now that we have some experience. If this is confusing, try to replicate the calculation step-by-step like I did above!
FLOPs = 4BnTh
Bytes = (queries) Bnh + (keys) BnTh + (values) BnTh + (output) Bnh = 2Bnh + 2BnTh.
That leads to an arithmetic intensity of 4BnTh / ((2Bnh + 2BnTh) * DW). For sequence lengths where T >> 1, the 2Bnh term is negligible next to 2BnTh, so the expression is approximately 4BnTh / (2BnTh * DW) = 2 / DW.
So, our arithmetic intensity is constant, and low! If DW = 2, the arithmetic intensity approaches 1, which is firmly memory bound, as most ridge points are at least 30.
So, why do we blame the memory bound-ness on KV cache? Because the denominator term — 2BnTh — comes from loading our keys and values. So now, we can claim that decode is indeed a memory bound workload.
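Here is the same calculation as a small plain-Python sketch, so you can watch the intensity flatten out as T grows. The function name decode_attention_intensity and the shapes below are illustrative choices of mine.

```python
def decode_attention_intensity(B: int, n: int, T: int, h: int, data_width: int = 2) -> float:
    """Arithmetic intensity (FLOPs per byte) of one MHA decode step with a KV cache."""
    flops = 4 * B * n * T * h                      # one query row against T keys, then T values
    elements = 2 * B * n * h + 2 * B * n * T * h   # query + output, plus cached keys + values
    return flops / (elements * data_width)


for T in (16, 256, 4096, 16384):
    print(T, round(decode_attention_intensity(B=1, n=16, T=T, h=128), 4))
```

No matter how long the sequence gets, the value creeps towards 2 / DW = 1 and never passes it, so a longer context doesn’t buy back any intensity the way it did in prefill.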
Grouped Query Attention — a stopgap, not a solution
The field has responded to the incredibly large KV-cache problem, and the best stopgap available is grouped query attention (GQA). I wouldn’t really call this a solution — if you want to see my rationale for that, just look at the enormous numbers we’re getting for storage in the next section.
Let’s revisit the bytes computation from the last section:
Bytes = (queries) Bnh + (keys) BnTh + (values) BnTh + (output) Bnh = 2Bnh + 2BnTh.
What if, instead of loading n heads of keys and values, we load n_kv heads, and just duplicate them n / n_kv times to create an equivalent tensor?
It turns out that we can tune this n_kv parameter so that model quality degrades only slightly. And this n / n_kv reuse factor (let’s call it ‘r’) means that we can cut the bytes calculation down even further. FLOPs don’t actually change in this case because we still compute over the duplicated KV heads.
Bytes = (queries) Bnh + (keys) B(n_kv)Th + (values) B(n_kv)Th + (output) Bnh = 2Bnh + 2B(n_kv)Th.
Once we re-apply the same approximations as we did the last time, we arrive at the following:
4BnTh/(2B(n_kv)Th * DW) = 2r / DW
What this means is that, using n_kv as a parameter, we can tune the arithmetic intensity of decode attention, nudging the (grouped query) attention mechanism back towards the compute-bound regime.
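A short plain-Python sketch of that tuning knob follows. The function name gqa_decode_intensity and the head counts in the loop are assumptions for illustration, not taken from any particular model.

```python
def gqa_decode_intensity(B: int, n: int, n_kv: int, T: int, h: int, data_width: int = 2) -> float:
    """Arithmetic intensity (FLOPs per byte) of one GQA decode step with a KV cache."""
    if n % n_kv != 0:
        raise ValueError("n must be a multiple of n_kv")
    flops = 4 * B * n * T * h                         # unchanged: all n query heads still compute
    elements = 2 * B * n * h + 2 * B * n_kv * T * h   # only n_kv key/value heads are loaded
    return flops / (elements * data_width)


# n_kv = n is plain MHA; n_kv = 1 is the extreme where every query head shares one KV head.
n = 32
for n_kv in (32, 8, 4, 1):
    r = n // n_kv
    print(f"n_kv={n_kv:2d} r={r:2d} intensity={gqa_decode_intensity(1, n, n_kv, 16384, 128):.2f}")
```

With DW = 2 the intensity comes out close to r, so even a reuse factor of 8 stays well short of the ridge points discussed earlier. That is part of why I call this a stopgap.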
The storage problem
Okay, loading the KV cache takes time. And to be fair, this is a huge amount of data. How large? Let’s compute it!
We already know that 2BnTh * DW gives us the total bytes of KV cache we load. But this is per layer! Models usually have 50-60 layers. If we try to do a generation with a sequence length of 16384 for example (and this is on the lower end BTW for most useful applications, which often leverage reasoning approaches), we obtain the following.
Let’s calculate this for Gemma3-27B, which uses 16 KV heads (so n below is the KV head count) and a head dim of 128 across 62 layers, assuming a batch size of 256:
KV Cache Mem = 2BnTh * DW * num_layers = 2 * 256 * 16 * 16384 * 128 * 2 * 62 = 2,130,303,778,816 B
That’s 2.13 terabytes of storage purely for KV cache!!! Let that sink in.
As a function of sequence length, the KV cache size is 130,023,424 * sequence_length bytes. Plugging in sequence_length = 1M (2^20 tokens), we get a total of about 136 TB of KV cache.
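If you want to play with these numbers yourself, here is a minimal plain-Python calculator using the same formula. The function name kv_cache_bytes is made up for illustration, and the configuration is simply the one used in the calculation above.

```python
def kv_cache_bytes(batch: int, kv_heads: int, seq_len: int, head_dim: int,
                   num_layers: int, data_width: int = 2) -> int:
    """Total KV cache size in bytes: keys and values, for every layer."""
    return 2 * batch * kv_heads * seq_len * head_dim * data_width * num_layers


# Same configuration as the calculation above.
config = dict(batch=256, kv_heads=16, head_dim=128, num_layers=62, data_width=2)

at_16k = kv_cache_bytes(seq_len=16384, **config)
per_token = kv_cache_bytes(seq_len=1, **config)
at_1m = kv_cache_bytes(seq_len=2 ** 20, **config)

print(f"16K tokens: {at_16k:,} B ({at_16k / 1e12:.2f} TB)")
print(f"per token of sequence length: {per_token:,} B")
print(f"1M tokens: {at_1m:,} B ({at_1m / 1e12:.1f} TB)")
```

Every term is a plain multiplier, so halving the batch size or the number of KV heads halves the total.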
There is a chance that KV-cache storage for trillion parameter models (we’re only at 27B right now!) can reach into the petabytes. Doing inference at scale quickly becomes a challenge.
Conclusion
And we’ve basically made it to the state of the art in attention architectural design! Yes, additional optimizations exist. But most of them don’t touch the core attention operation, possibly because doing so trades away model quality even less gracefully than GQA does.


