KV Cache in LLMs

I am Amit Shekhar, Founder @ Outcome School. I have taught and mentored many developers whose efforts landed them high-paying tech jobs, helped many tech companies solve their unique problems, and created many open-source libraries that are used by top companies. I am passionate about sharing knowledge through open-source, blogs, and videos.

I teach AI, Machine Learning, and Android at Outcome School.

Join Outcome School and get a high-paying tech job.

In this blog, we will learn about KV Cache - where K stands for Key and V stands for Value - and why it is used in Large Language Models (LLMs) to speed up text generation.

We will start with how LLMs generate text one token at a time, understand the role of Key, Value, and Query inside the model, see the problem of repeated computation through an example, and then walk through how KV Cache solves this problem by storing and reusing past results.

Let's get started.

How LLMs Generate Text

Before we understand KV Cache, we first need to understand how LLMs generate text.

An LLM - Large Language Model - is a model that has been trained on a massive amount of text data. It can understand and generate human language. When we give it a sentence, it predicts what comes next.

LLMs generate text one token at a time. A token is a small piece of text - it can be a word, part of a word, or even a single character. For simplicity, let's treat each word as one token.

Let's say we give the model this input:

"I love"

The model looks at "I" and "love", and predicts the next token: "teaching".

Now the full sequence becomes:

"I love teaching"

The model looks at "I", "love", and "teaching", and predicts the next token: "AI".

Now the full sequence becomes:

"I love teaching AI"

This process continues, one token at a time, until the model decides to stop.

The important thing to notice here is: every time the model predicts a new token, it needs to look at all the previous tokens to decide what comes next.
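This generation loop can be sketched in a few lines of code. The "model" below is just a lookup table mapping a context to the next token, a toy stand-in for the real neural network, but the loop structure is the same.

```python
# A toy sketch of autoregressive generation. TOY_MODEL is a stand-in
# for the real neural network; the loop structure is what matters.

TOY_MODEL = {
    ("I", "love"): "teaching",
    ("I", "love", "teaching"): "AI",
    ("I", "love", "teaching", "AI"): "<eos>",  # model decides to stop
}

def generate(tokens, max_new_tokens=10):
    tokens = list(tokens)
    for _ in range(max_new_tokens):
        # The model looks at ALL previous tokens at every step.
        next_token = TOY_MODEL.get(tuple(tokens), "<eos>")
        if next_token == "<eos>":
            break
        tokens.append(next_token)
    return tokens

print(generate(["I", "love"]))  # ['I', 'love', 'teaching', 'AI']
```

Notice that every iteration passes the full token list to the model. That detail is exactly where the problem we discuss next comes from.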

What Happens Inside the Model

Now, let's understand what happens inside the model when it generates each token.

The model has a component called the attention layer. The job of the attention layer is to help the model figure out which previous tokens are important for predicting the next token. Not all previous tokens are equally useful. Some matter more than others.

Inside the attention layer, every token is converted into three things. Let's understand each of them with a simple analogy.

Imagine a classroom where a new student joins and wants to find out who can help with a specific topic:

  • Query (Q): The new student's question - "Who here knows about this topic?" Every token being predicted gets a Query. It represents what the token is looking for.
  • Key (K): The name tag each existing student wears that says what they know - "I know about math" or "I know about science." Every previous token gets a Key. It describes what information that token holds.
  • Value (V): The actual notes each student has. Once the new student finds the right person using their name tag (Key), the notes (Value) are what they actually use. Every previous token gets a Value. It carries the actual information.

So, the current token uses its Query to compare against the Keys of all previous tokens. This comparison produces attention scores - numbers that tell the model how much focus to give to each previous token. A higher score means that token is more relevant. A lower score means it is less relevant. The model then uses these scores to collect the Values from the relevant tokens.

This is how the attention layer works.
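The Query-Key-Value comparison above can be sketched numerically. The vectors below are tiny made-up 2-dimensional examples (real models use hundreds of dimensions per head, and many heads in parallel):

```python
import math

# A minimal sketch of attention for ONE current token, with made-up
# 2-dimensional vectors. Real models work the same way at scale.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend(query, keys, values):
    # Attention scores: how well the Query matches each previous Key.
    scores = [dot(query, k) / math.sqrt(len(query)) for k in keys]
    # Softmax turns the scores into weights that sum to 1.
    exps = [math.exp(s - max(scores)) for s in scores]
    weights = [e / sum(exps) for e in exps]
    # Output: a weighted mix of the Values of the relevant tokens.
    return [sum(w * v[i] for w, v in zip(weights, values))
            for i in range(len(values[0]))]

query = [1.0, 0.0]                  # what the current token looks for
keys = [[1.0, 0.0], [0.0, 1.0]]     # "name tags" of previous tokens
values = [[1.0, 2.0], [3.0, 4.0]]   # their actual information
out = attend(query, keys, values)
```

Here the Query matches the first Key more strongly, so the first token's Value gets a higher weight in the output.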

Now, let's see the problem.

The Problem: Repeated Computation

Every time the model predicts the next token, it computes the Key, Value, and Query for all tokens in the sequence - not just the new one.

Let's see what happens step by step with our example:

Step 1: Input is "I love"

The model computes Key, Value, and Query for:

  • "I"
  • "love"

It uses these to predict the next token: "teaching".

Step 2: Input is "I love teaching"

The model computes Key, Value, and Query for:

  • "I" (already computed in Step 1, but computed again)
  • "love" (already computed in Step 1, but computed again)
  • "teaching" (new)

It uses these to predict the next token: "AI".

Step 3: Input is "I love teaching AI"

The model computes Key, Value, and Query for:

  • "I" (already computed in Step 1 and Step 2, but computed again)
  • "love" (already computed in Step 1 and Step 2, but computed again)
  • "teaching" (already computed in Step 2, but computed again)
  • "AI" (new)

Do you see the problem?

The Key and Value for "I" were computed in Step 1. But the model computes them again in Step 2, and again in Step 3. The same goes for "love" and "teaching".

The model is repeating the same work for tokens it has already seen. This is wasted computation.

As the sequence grows longer, this problem gets worse. If the model has already generated 100 tokens, then at the next step, it would recompute the Key and Value for all 100 previous tokens just to predict one new token. This makes text generation very slow.
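We can count this wasted work directly. In the sketch below, compute_kv is a stand-in for the real Key/Value computation done inside the model, and we replay the three steps above without any cache:

```python
# Counting the repeated work in the three steps above. compute_kv is a
# stand-in for the real Key/Value computation inside the model.
kv_computations = 0

def compute_kv(token):
    global kv_computations
    kv_computations += 1
    return ("K_" + token, "V_" + token)

# Without a cache, every step recomputes K and V for the FULL input.
for step_input in [["I", "love"],
                   ["I", "love", "teaching"],
                   ["I", "love", "teaching", "AI"]]:
    kvs = [compute_kv(t) for t in step_input]

print(kv_computations)  # 2 + 3 + 4 = 9
```

Nine computations for just three steps, even though only four distinct tokens exist.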

The Solution: KV Cache

The idea behind KV Cache is simple: compute the Key and Value for each token once, store them, and reuse them in every future step.

Think of it like taking notes. Imagine you are in a meeting. Every time someone new speaks, instead of asking all the previous speakers to repeat what they said, you simply read from your notes and only listen to the new speaker. The notes are your cache.

Similarly, KV Cache is a memory where we save the Key and Value of every token that has already been processed. This way, the model does not have to compute them again.

Let's see how the same example works with KV Cache:

Step 1: Input is "I love"

The model computes Key, Value, and Query for:

  • "I"
  • "love"

It saves the Key and Value of "I" and "love" in the KV Cache.

It uses these to predict the next token: "teaching".

KV Cache now contains: Key and Value for "I", "love"

Step 2: Input is "teaching" (only the new token)

The model retrieves the Key and Value of "I" and "love" from the KV Cache. No recomputation needed.

It computes Key, Value, and Query only for the new token: "teaching".

It saves the Key and Value of "teaching" in the KV Cache.

It uses all of these together to predict the next token: "AI".

KV Cache now contains: Key and Value for "I", "love", "teaching"

Step 3: Input is "AI" (only the new token)

The model retrieves the Key and Value of "I", "love", and "teaching" from the KV Cache. No recomputation needed.

It computes Key, Value, and Query only for the new token: "AI".

It saves the Key and Value of "AI" in the KV Cache.

KV Cache now contains: Key and Value for "I", "love", "teaching", "AI"

So, instead of recomputing the Key and Value for every token at every step, the model computes them only for the new token and reuses the cached values for all previous tokens.

This is how KV Cache avoids repeated computation.
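The same three steps can be replayed with a cache. Using the same compute_kv stand-in as before, each position's Key and Value is computed only on first sight and reused afterwards:

```python
# The same three steps, now WITH a KV Cache. compute_kv is a stand-in
# for the real Key/Value computation inside the model.
kv_computations = 0
kv_cache = {}  # position -> (Key, Value), filled once and reused

def compute_kv(token):
    global kv_computations
    kv_computations += 1
    return ("K_" + token, "V_" + token)

for step_input in [["I", "love"],
                   ["I", "love", "teaching"],
                   ["I", "love", "teaching", "AI"]]:
    for i, token in enumerate(step_input):
        if i not in kv_cache:          # only new tokens are computed
            kv_cache[i] = compute_kv(token)

print(kv_computations)  # 2 + 1 + 1 = 4, instead of 9 without a cache
```

Four computations instead of nine, one per distinct token, and the gap widens as the sequence grows.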

Why Only Key and Value Are Cached, Not Query

You might naturally ask: why do we cache only the Key and Value, and not the Query?

The Query is only needed for the current token - the one being generated right now. The current token uses its Query to compare against the Keys of all previous tokens to find which ones are relevant. Once the prediction is done, that Query is no longer needed.

But the Key and Value of every past token are needed at every future step, because every new token must look at all the previous tokens to make its prediction.

So, we only need to store the Keys and Values. This is why it is called KV Cache - it caches the Keys and Values.

How Much Faster Does It Get?

Let's compare the two approaches side by side:

Without KV Cache:

Step 1: Compute K, V, Q for 2 tokens
Step 2: Compute K, V, Q for 3 tokens  (2 recomputed)
Step 3: Compute K, V, Q for 4 tokens  (3 recomputed)
Step 4: Compute K, V, Q for 5 tokens  (4 recomputed)
...
Step N: Compute K, V, Q for (N+1) tokens  (N recomputed)

The amount of computation keeps growing at every step.

With KV Cache:

Step 1: Compute K, V, Q for 2 tokens → Save K, V for 2 tokens in cache
Step 2: Compute K, V, Q for 1 token  → Reuse K, V for 2 tokens from cache
Step 3: Compute K, V, Q for 1 token  → Reuse K, V for 3 tokens from cache
Step 4: Compute K, V, Q for 1 token  → Reuse K, V for 4 tokens from cache
...
Step N: Compute K, V, Q for 1 token  → Reuse K, V for N tokens from cache

After the first step, the model computes only for one new token at each step instead of the entire sequence.

To put this into perspective: if the model is generating a sequence of 100 tokens, without KV Cache, the total number of Key and Value computations across all steps would be 2 + 3 + 4 + ... + 100 = 5,049 computations. With KV Cache, it would be 2 + 1 + 1 + ... + 1 = 101 computations. That is roughly 50 times fewer computations. The longer the sequence, the bigger the savings.
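The arithmetic above is easy to verify in code:

```python
# Verifying the arithmetic above for a 100-token sequence.
N = 100
without_cache = sum(range(2, N + 1))  # 2 + 3 + ... + 100
with_cache = 2 + (N - 1)              # 2 for the first step, then 1 each
print(without_cache, with_cache)      # 5049 101
```

The ratio is roughly 50x here, and it keeps growing linearly with sequence length.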

This is what makes KV Cache so effective at speeding up text generation.

The Trade-Off: Speed vs Memory

KV Cache makes generation faster, but it comes with a trade-off: it uses extra memory to store all the Key and Value information for every token that has been generated so far.

As the sequence gets longer, the cache grows larger. For very long sequences with thousands of tokens, the cache can consume a significant amount of memory.
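A back-of-the-envelope estimate shows why. The cache stores a Key and a Value vector per token, per attention head, per layer. The model dimensions below are illustrative assumptions, not any specific model:

```python
# A rough estimate of KV Cache memory. The dimensions below are
# illustrative assumptions, not the specs of any particular model.
num_layers = 32       # transformer layers (each keeps its own cache)
num_heads = 32        # attention heads per layer
head_dim = 128        # size of each Key/Value vector
bytes_per_number = 2  # 16-bit floating point
seq_len = 4096        # tokens processed so far

# 2x because BOTH a Key and a Value are stored per head, per layer,
# per token.
cache_bytes = (2 * num_layers * num_heads * head_dim
               * bytes_per_number * seq_len)
print(cache_bytes / 1024**3)  # 2.0 GiB for a single sequence
```

Two gigabytes for one sequence of 4,096 tokens, and serving many users in parallel multiplies this further.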

So, KV Cache is a trade-off: we use more memory to save computation time. For most use cases, this trade-off is well worth it, as the speed improvement is significant.

With this, we have understood KV Cache in LLMs.

In the next blog, we will learn about Paged Attention, which solves the memory issue of KV Cache.

Prepare yourself for AI Engineering Interview: AI Engineering Interview Questions

That's it for now.

Thanks

Amit Shekhar
Founder @ Outcome School
