Llama-3, A Deep Dive
In this post we explore Llama-3, a popular LLM. Because of its open-source nature and its compact architecture, it serves as a great example for studying how LLMs work and what the main research directions around optimizing them are. We focus mainly on the architecture, as it affects all other aspects of the model: how it is used, how it is optimized, and how it can be modified.
Tokenization
To feed text inputs to the LLM we first need to tokenize them. Tokenization refers to breaking the input into a sequence of discrete elements which are then processed by the model. Text tokenization happens by breaking sentences into words or sub-words and associating dense embeddings to each of the resulting text chunks. Image tokenization instead breaks up the image into non-overlapping square patches, for example of size \((14, 14)\), flattens them, and then passes them to the model. Tokenization is a very powerful approach because it allows us to convert any data modality into discrete chunks which can be processed by the same model. This is one approach to building multimodal architectures.
For text tokenization, most LLMs use byte-pair encoding (BPE). This algorithm iteratively builds up a vocabulary of recognized tokens out of a large body of text as follows. First we set the total vocabulary size, i.e. the number of distinct tokens that can be represented. Large values here are entirely possible, for example, Llama-3's vocabulary size is 128256! Then, in the simplest version of BPE, we convert the entire text corpus, a long string, into bytes. By working with bytes, we ensure that all symbols, including rare Unicode characters, can be processed. If the model encounters a new symbol at test time that has not been observed at training time, it will simply see it as a sequence of bytes. And if we have a token for every byte, then we can tokenize every symbol. But tokenizing every single byte separately is not efficient, so we merge some bytes together into new tokens.
We can represent the tokenizer as a dictionary from byte sequences to integers. At the beginning we add all byte values (0x00, 0x01, 0x02, ..., 0xff) as keys with their corresponding integer values from \(0\) to \(255\) as values. Next we iterate over the byte-encoded text corpus and count how many times each pair of neighboring bytes occurs. For example, the byte pair (0x74, 0x68) corresponding to "th" occurs much more frequently than (0x7a, 0x6d) corresponding to "zm". We find the most frequent byte pair, suppose (0x74, 0x68), merge it to get 0x74_0x68 and add it to the tokenizer dictionary with value \(256\), the lowest non-taken integer. Next, we replace all occurrences of (0x74, 0x68) with the new token 0x74_0x68. The effect is that we can now recognize the two bytes together, as a separate token, and this reflects how frequently they occur together. Subsequently, we repeat this process: find the next most frequent pair of tokens, merge it, add it to the dictionary, and replace its occurrences in the text. This process continues until we have the desired number of tokens in our vocabulary.
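The merge loop described above can be sketched in a few lines of Python. This is a minimal illustration of the simplest BPE variant (no regex pre-splitting, ties broken arbitrarily), not the code of any particular tokenizer library; the function name `train_bpe` is our own.

```python
from collections import Counter


def train_bpe(corpus: str, vocab_size: int) -> dict[bytes, int]:
    """Learn a byte-level BPE vocabulary from one string (simplest variant)."""
    # Start with one token per byte value, mapped to the integers 0..255.
    vocab = {bytes([b]): b for b in range(256)}
    # The corpus as a list of tokens; initially every token is a single byte.
    seq = [bytes([b]) for b in corpus.encode("utf-8")]

    while len(vocab) < vocab_size:
        # Count how often each pair of neighboring tokens occurs.
        pair_counts = Counter(zip(seq, seq[1:]))
        if not pair_counts:
            break  # fewer than two tokens left, nothing to merge
        (left, right), _ = pair_counts.most_common(1)[0]
        merged = left + right
        if merged not in vocab:
            vocab[merged] = len(vocab)  # lowest non-taken integer

        # Replace every occurrence of the pair with the merged token.
        new_seq, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and seq[i] == left and seq[i + 1] == right:
                new_seq.append(merged)
                i += 2
            else:
                new_seq.append(seq[i])
                i += 1
        seq = new_seq
    return vocab


# Tiny usage example: with room for a single merge, "th" becomes token 256.
vocab = train_bpe("th th th", vocab_size=257)
print(vocab[b"th"])
```

Real tokenizers are trained on far larger corpora and use much faster data structures, but the logic is the same: count pairs, merge the winner, repeat.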
As a result, more common byte sequences will tend to have lower integer values. Frequent long sequences like prefixes or suffixes will obtain their own tokens. With a large vocabulary and a sufficiently big corpus we will learn tokens corresponding to very specific terms and only their "roots". For example, GPT4's tokenizer breaks down institutionalism into three tokens - institution, al, ism. In practice it is common to first break down the whole text into individual words using a regex pattern and then apply BPE on the encoded sequence of words obtained. Once the tokenizer is obtained, to encode a sentence we break it into words and tokenize each word separately. To tokenize a word, we iterate over its neighboring pairs and merge them according to the tokenizer. Once we can't merge anymore, we return the integers corresponding to the tokens. Decoding is trivial and simply looks up the tokens from their integer values using a reverse dictionary that is constructed from the main one.
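To make encoding and decoding concrete, here is a small sketch that works on a hand-made vocabulary. It assumes that, among all neighboring pairs that can be merged, we always merge the one learned earliest (lowest ID), which is a common convention but an assumption on our side; `encode_word` and `decode` are illustrative names.

```python
def encode_word(word: str, vocab: dict[bytes, int]) -> list[int]:
    """Greedily merge neighboring tokens of a word using a BPE vocabulary."""
    parts = [bytes([b]) for b in word.encode("utf-8")]
    while len(parts) > 1:
        # All neighboring pairs whose concatenation is a known token.
        candidates = [
            (vocab[parts[i] + parts[i + 1]], i)
            for i in range(len(parts) - 1)
            if parts[i] + parts[i + 1] in vocab
        ]
        if not candidates:
            break  # nothing can be merged anymore
        # Assumption: prefer the merge that was learned first (lowest ID).
        _, i = min(candidates)
        parts[i:i + 2] = [parts[i] + parts[i + 1]]
    return [vocab[p] for p in parts]


def decode(ids: list[int], vocab: dict[bytes, int]) -> str:
    """Look up each ID in the reverse dictionary and join the bytes."""
    reverse = {i: tok for tok, i in vocab.items()}
    return b"".join(reverse[i] for i in ids).decode("utf-8", errors="replace")


# A toy vocabulary: all single bytes plus two learned merges.
toy_vocab = {bytes([b]): b for b in range(256)}
toy_vocab[b"th"] = 256
toy_vocab[b"the"] = 257

ids = encode_word("then", toy_vocab)
print(ids, decode(ids, toy_vocab))
```

Note that decoding joins the bytes before converting to text, because a single token may contain only part of a multi-byte Unicode character.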
LLMs are trained using next-token prediction. Hence, the most natural way to use them at test time is to have them predict the natural continuation of some text. For this, we prepare a list of strings, representing a batch of text inputs. Few-shot or in-context learning examples can be mixed into the individual string prompts. Then, each of the input samples is encoded by the tokenizer. Let's take a look at the Llama-3 tokenizer. It is based on OpenAI's tiktoken, but with some special tokens added - in particular we have special tokens <|begin_of_text|>, <|end_of_text|>, <|eot_id|> (end of turn), <|start_header_id|>, <|end_header_id|> and a bunch of other reserved ones. These are added after the main tokens and therefore have the largest IDs. Encoding proceeds by splitting the input string into smaller chunks and each chunk into shorter substrings. Each substring is then mapped to an integer using the tokenizer mapping. The beginning-of-sequence (BOS) and end-of-sequence (EOS) tokens may be added.
Subsequently, the model receives tokenized inputs in the form of a List[List[int]] and the inner lists have different lengths. Depending on whether a maximum generated length is set, an integer tensor of shape \((B, N)\), let's call it tokens, is created. Here \(B\) is the batch size, i.e. number of parallel sequences, and \(N\) is either the length of the longest prompt plus the length of the text to generate (if set), or some predefined maximum sequence length, whichever is shorter. Limiting the total sequence length to process, which consists of the prompt length and the response length, is done to prevent the attention from blowing up in memory. In any case, the input tensor tokens is filled with the input prompts and padded up to length \(N\).
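As a rough illustration of this step, the sketch below builds the padded \((B, N)\) table from a list of token lists. It uses plain Python lists instead of tensors, and the function name, argument names, and the choice of `-1` as padding value are our own assumptions.

```python
def build_token_table(prompts, max_gen_len, max_seq_len, pad_id=-1):
    """Pad a batch of token lists into a rectangular (B, N) table."""
    max_prompt_len = max(len(p) for p in prompts)
    if max_prompt_len > max_seq_len:
        raise ValueError("a prompt is longer than the maximum sequence length")
    # N: longest prompt plus generation budget, capped by the maximum length.
    total_len = min(max_seq_len, max_prompt_len + max_gen_len)

    tokens = [[pad_id] * total_len for _ in prompts]
    for row, prompt in zip(tokens, prompts):
        row[:len(prompt)] = prompt  # left-aligned prompt, padding on the right
    return tokens


table = build_token_table([[5, 6, 7], [8]], max_gen_len=2, max_seq_len=10)
for row in table:
    print(row)
```

Every row has the same length \(N\), so the batch can be processed as one tensor even though the prompts differ in length.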
Next, the model initializes a counter prev_pos to \(0\) and iterates another counter current_pos from the length of the shortest prompt to \(N\). This is the generation loop. At each iteration it takes the tokens for all batch samples from position prev_pos to current_pos and performs a forward pass, obtaining the next-token logits over the vocabulary. Once the next token is selected, of shape \((B,)\), it is placed into tokens at current_pos, but only for those samples whose token at current_pos is not a prompt token. Subsequently, prev_pos is set to current_pos and current_pos is incremented by 1. Additionally, we mark whether the selected token is an EOS token, and once every batch sample has produced an EOS token, the generation loop halts.
In Llama-3 at the first iteration in the generation loop the input tokens have shape \((B, M)\) where \(M\) is the length of the shortest prompt. In subsequent iterations the inputs to the forward pass are of shape \((B, 1)\). This is intentional and results from the KV-caching that is going on. Without it there will be a lot of redundant computation in calculating the keys and values of those tokens corresponding to the prompt and the answer generated so far. The autoregressive generation in LLMs requires each new token to attend to (mix together with) all previous tokens. The KV-cache does not entirely eliminate this necessity. It only avoids recomputing the keys/values for tokens which have been seen so far.
After the generation loop ends we do some slight post-processing. If a generated sequence contains an EOS token, we cut it at that token. This is needed because the loop keeps running as long as some, but not all, batch sequences have produced an EOS token. Lastly, we decode the produced integers back to text, at which point our text has been generated.
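The whole generation procedure can be sketched with a toy stand-in for the model. Everything here is an assumption made for illustration: `toy_model` simply predicts "last token plus one", the vocabulary has 8 tokens, token `0` plays the role of EOS, and decoding is greedy. Only the loop structure mirrors the description above.

```python
PAD, EOS, VOCAB = -1, 0, 8


def toy_model(chunk):
    """Stand-in for the forward pass: next-token logits for each sample."""
    logits = []
    for row in chunk:
        nxt = (row[-1] + 1) % VOCAB  # toy rule: predict last token + 1
        logits.append([1.0 if v == nxt else 0.0 for v in range(VOCAB)])
    return logits


def generate(prompts, max_gen_len, max_seq_len):
    min_len = min(len(p) for p in prompts)
    max_len = max(len(p) for p in prompts)
    if max_len > max_seq_len:
        raise ValueError("prompt too long")
    total_len = min(max_seq_len, max_len + max_gen_len)
    tokens = [p + [PAD] * (total_len - len(p)) for p in prompts]
    is_prompt = [[t != PAD for t in row] for row in tokens]
    eos_reached = [False] * len(prompts)

    prev_pos = 0
    for cur_pos in range(min_len, total_len):
        # First call sees the first min_len tokens, later calls a single token.
        logits = toy_model([row[prev_pos:cur_pos] for row in tokens])
        for b, row_logits in enumerate(logits):
            next_token = max(range(VOCAB), key=row_logits.__getitem__)  # greedy
            if is_prompt[b][cur_pos]:
                continue  # never overwrite prompt tokens
            tokens[b][cur_pos] = next_token
            if next_token == EOS:
                eos_reached[b] = True
        prev_pos = cur_pos
        if all(eos_reached):
            break

    # Post-processing: drop the prompt and cut each answer at its first EOS.
    outputs = []
    for row, prompt in zip(tokens, prompts):
        answer = row[len(prompt):len(prompt) + max_gen_len]
        if EOS in answer:
            answer = answer[:answer.index(EOS)]
        outputs.append(answer)
    return outputs


print(generate([[5, 6], [3]], max_gen_len=4, max_seq_len=16))
```

In this run the first sample hits EOS early but the loop keeps going until the second sample finishes, which is exactly why the final cut at EOS is needed.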
Architecture
Now let's look at the Llama-3 architecture. The input is tokens, an integer-valued tensor of shape \((B, N)\). Suppose \(M_{\max}\) is the length of the longest prompt, \(S\) is the maximum sequence length, and \(G\) is the number of tokens to generate. Then \(N = \min(S, M_{\max} + G)\). Usually the actual model would throw an error if the input prompt is longer than \(S\). This is a very important constraint for these models. It is up to the LLM inference engine, which we discuss shortly, to make sure that the prompt is not too long.
The first thing the model does is to convert the token IDs to latent vectors. This happens by looking up the tokens integers in a big embedding table of shape \((V, D)\) where \(V\) is the vocabulary size and \(D\) is the vector dimension. This table is trained with gradient descent and represents the semantic information for each token. To get the embedding of token \(i\) we look up row \(i\) of this table.
The next step is to add positional encodings. Let's consider why this is needed. An RNN would process a sequence token by token, in order, and earlier tokens will be processed before later tokens. In that sense, the order of the tokens is encoded by the recurrent operations themselves. But with attention, we process all tokens in parallel. Permuting the keys and values will have no effect on the attention layer outputs, because of the weighted sum aggregation of values, which is permutation invariant. Permuting the queries will permute the output vectors in the same way, but so what? The following feed-forward layers also process all tokens in parallel so they cannot understand order. This reasoning shows that to consider the order of the sequence elements, a transformer would need to have positional features mixed in with the semantic features of the queries, keys and values.
Absolute positional encodings typically represent token positions as integers from \(0\) to \(S - 1\) and compute sinusoidal features from them. These features are added before the first self-attention. In contrast, relative positional encodings capture the relative distance between tokens. For a total sequence length \(S\), there are up to \(S - 1\) tokens after a given one, and up to \(S - 1\) before it, so there are \(2S - 1\) possible offsets. The resulting positional matrix is of shape \((S, 2S - 1)\), with integer values like \((\ldots, -2, -1, 0, 1,\ldots)\). We typically learn an embedding of shape \((2S-1, D)\), \(D\) being the hidden dimension, which is reshaped and added to the attention coefficients. The problem with these relative encodings is that at test time, when we are generating a sequence autoregressively, the encodings for the past tokens change whenever we predict a new token. It's very hard to use a KV-cache in that case, which greatly limits their use cases.
Llama-3 and a bunch of other models instead use rotary positional encodings. The idea here is that the embeddings, which are high-dimensional vectors, will be rotated according to their absolute positions. Suppose token \(i\) occurs at absolute position \(m\). To encode its position, we will rotate it by an angle \(m\theta\), where \(\theta\) is fixed. Immediately it becomes evident that tokens appearing later in the sentence will have higher \(m\) and thus will be rotated more, which is a form of absolute encoding. But moreover, the relative positions between two tokens \(i\) and \(j\) are also preserved by this scheme. Vector \(i\) will be rotated by \(m_i \theta\) and vector \(j\) by \(m_j \theta\). If \(i\) and \(j\) occur later in the sentence, at index \(k\), but have the same relative distance, they'll be rotated by \((k + m_i) \theta\) and \((k + m_j) \theta\). Suppose before encoding their positions, the angle between \(i\) and \(j\), representing how semantically similar \(i\) and \(j\) are, is \(\alpha\). After encoding them, that angle will be \(\alpha + (m_j - m_i)\theta\), irrespective of \(k\), the absolute position of the pair in the sequence. In other words, the relative angle of \(i\) and \(j\) after encoding, will be preserved, wherever \(i\) and \(j\) occur, as long as they have the same relative distance between them.
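A quick numerical check of this argument in two dimensions: we rotate a "query" and a "key" by angles proportional to their positions and look at their dot product, which is what attention uses. The vectors, the value of \(\theta\), and the positions are arbitrary choices for illustration.

```python
import math


def rotate(v, angle):
    """Rotate a 2D vector counter-clockwise by the given angle."""
    x, y = v
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)


def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]


theta = 0.1
q, k = (1.0, 2.0), (0.5, -1.0)


def score(m_q, m_k):
    """Dot product after encoding positions m_q and m_k by rotation."""
    return dot(rotate(q, m_q * theta), rotate(k, m_k * theta))


# Same relative distance (3), very different absolute positions.
print(score(2, 5), score(40, 43))
# A different relative distance gives a different score.
print(score(2, 6))
```

The first two scores agree up to floating point error, since shifting both positions by the same amount rotates both vectors by the same angle and leaves their dot product unchanged.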
Overall, the formula for rotary embeddings is
\[f_{q, k}(\mathbf{x}_m, m) = \mathbf{R}^d_{\Theta, m} \mathbf{W}_{q, k} \mathbf{x}_m.\]
Here \(\mathbf{x}_m\) is the vector at position \(m\), \(\mathbf{W}_{q, k}\) is the matrix that computes queries and keys, and \(\mathbf{R}^d_{\Theta, m}\) is a big rotation matrix. What does it look like? In general, rotation matrices for \(d\)-dimensional vectors are complicated. Instead, we rotate each pair of dimensions with a different \(\theta\). The first two elements of the vector will be rotated by \(m\theta_1\), the next two by \(m\theta_2\), and so on, up to the last two, which will be rotated by \(m\theta_{d/2}\). Usually, the thetas are chosen as \(\Theta = \{ \theta_i = 10000^{-2(i-1)/d}, i \in [1, 2, \ldots, d/2] \}\) (Llama-3 uses a larger base of \(500000\) instead of \(10000\)). Overall, how an element from a vector is transformed depends on both where the vector is, \(m\), and what its dimension index is, \(i\).
In practice, suppose there are \(H\) attention heads and define \(D' = D/H\). We create the \(\Theta\) array and a regular array \(\text{m} = \{0, 1, \ldots, 2S-1\}\). Then, the outer product \(\text{m} \ \Theta^{T}\) is of shape \((2S, D'/2)\) and contains the rotation angles for every position and every pair of dimensions. We then create a complex-valued tensor where each element is a complex exponential with unit magnitude and angle given by \(\text{m} \ \Theta^{T}\). This tensor, let's call it \(P\), represents the angles to rotate by.
Now, when we call the model at test time, we also feed it the current location within the sequence, as stated before. The model can then extract from \(P\) those angles that correspond to the absolute locations of tokens. In the self-attention blocks, after we compute the keys, queries, and values, we apply the function torch.view_as_complex on the keys and queries, effectively viewing them as complex numbers. We then element-wise multiply them with the complex numbers representing the angles to rotate by. Basically, multiplying a complex number \(z\) by \(e^{i\phi}\) rotates \(z\) by \(\phi\). After the multiplication we use torch.view_as_real to turn them back into real numbers, and reshape as needed.
Thus, rotary embeddings take the queries/keys, view them as complex numbers, multiply them by some precomputed complex exponentials, and turn them back into real-valued queries/keys. They are useful because:
- They incorporate both absolute and relative information.
- They are efficient to compute, because we avoid matrix multiplication.
- Old tokens do not change their encodings when new tokens are added, allowing KV-caching.
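The complex-number recipe can be sketched with Python's built-in complex type instead of torch tensors. This is a single-vector illustration under our own naming (`precompute_rotations`, `apply_rope`), using the common base of \(10000\); it is not the model's actual code.

```python
import cmath


def precompute_rotations(head_dim, max_pos, base=10000.0):
    """P[m][i] = exp(1j * m * theta_i): one unit complex number per position and pair."""
    thetas = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    return [[cmath.exp(1j * m * t) for t in thetas] for m in range(max_pos)]


def apply_rope(x, position, P):
    """Rotate consecutive pairs of x by the angles stored for this position."""
    # View (x0, x1), (x2, x3), ... as complex numbers x0 + i*x1, x2 + i*x3, ...
    pairs = [complex(x[2 * i], x[2 * i + 1]) for i in range(len(x) // 2)]
    rotated = [z * r for z, r in zip(pairs, P[position])]
    # View the complex numbers as real pairs again.
    out = []
    for z in rotated:
        out += [z.real, z.imag]
    return out


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


P = precompute_rotations(head_dim=4, max_pos=64)
q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]

# The query-key score depends only on the offset between positions (here 4).
print(dot(apply_rope(q, 3, P), apply_rope(k, 7, P)))
print(dot(apply_rope(q, 33, P), apply_rope(k, 37, P)))
```

Because the encoding of a token depends only on its own position, keys that were rotated and cached earlier stay valid when new tokens are appended.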
Let's now look at the forward pass. After extracting the token embeddings and the angles corresponding to the current sequence positions, we create a causal mask to prevent attending to future tokens. This is only needed at the first call, when we process the prompts of shape \((B, M)\). Subsequent calls have the tokens of shape \((B, 1)\). The embeddings of shape \((B, M, D)\) and the extracted angles are then passed to the first of many transformer blocks.
In a Llama-3 block, we first normalize the token embeddings using RMSNorm:
\[\text{RMSNorm}(\mathbf{x}) = \frac{\mathbf{x}}{\sqrt{\frac{1}{D} \sum_{i=1}^{D} x_i^2 + \epsilon}} \odot \mathbf{g},\]
where \(\mathbf{g}\) is a learned per-channel scale and \(\epsilon\) is a small constant for numerical stability. It is applied over the channel dimension and supposedly works just as well as LayerNorm, while being less computationally demanding. Next comes the self-attention. We compute K, Q, V, and rotate the keys and queries, as discussed above. Now comes the KV-caching. Each attention module has a K-cache and a V-cache, both of shape \((B, S, H, D/H)\), initialized to zeros. We store the newly computed K, V into their respective locations, and then retrieve all K, V up to the current location. We apply attention in the standard way, using the causal mask.
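For reference, here is RMSNorm as it is commonly defined, written for a single embedding vector: divide by the root mean square of the channels and multiply by a learned per-channel scale. The names `weight` and `eps` and the default value of `eps` are our own choices.

```python
import math


def rms_norm(x, weight, eps=1e-5):
    """Normalize x by its root mean square, then apply a per-channel scale."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


x = [3.0, 4.0]
y = rms_norm(x, weight=[1.0, 1.0], eps=0.0)
# With unit weights the output has a root mean square of 1.
print(y, math.sqrt(sum(v * v for v in y) / len(y)))
```

Unlike LayerNorm, there is no mean subtraction and no bias term, which is where the savings come from.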
Note what happens on the next forward passes. The shape of tokens is \((B, 1)\). The embeddings are \((B, 1, D)\) and we only compute the new K, Q, V. We store K, V in the caches and retrieve everything up to the current position. For the attention we only have the new queries but we have all previous keys/values. Thus, we still attend over all past tokens. It's just that the KV-caching saves us from computing the keys and values of past tokens, which definitely saves on compute. We don't use a causal mask because the queries are \((B, 1, D)\) and there are no future tokens to attend to.
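The following sketch shows the caching idea for a single sequence and a single attention head, with plain Python lists. It is heavily simplified and rests on our own assumptions: the cache is a growing list rather than a preallocated tensor, the prompt is processed token by token (which plays the role of the causal mask), and rotary encodings are left out.

```python
import math


def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values):
    """Attention output for one query over the given keys and values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


class CachedAttention:
    def __init__(self, Wq, Wk, Wv):
        self.Wq, self.Wk, self.Wv = Wq, Wk, Wv
        self.k_cache, self.v_cache = [], []

    def forward(self, new_embeddings):
        """Process only the new tokens; old keys/values come from the cache."""
        outputs = []
        for x in new_embeddings:
            self.k_cache.append(matvec(self.Wk, x))
            self.v_cache.append(matvec(self.Wv, x))
            q = matvec(self.Wq, x)
            outputs.append(attend(q, self.k_cache, self.v_cache))
        return outputs


def full_causal_attention(Wq, Wk, Wv, xs):
    """Reference without a cache: recompute all keys and values from scratch."""
    K = [matvec(Wk, x) for x in xs]
    V = [matvec(Wv, x) for x in xs]
    return [attend(matvec(Wq, x), K[:t + 1], V[:t + 1]) for t, x in enumerate(xs)]


Wq = [[1.0, 0.0], [0.0, 1.0]]
Wk = [[0.5, -1.0], [1.0, 0.5]]
Wv = [[1.0, 2.0], [0.0, 1.0]]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]

attn = CachedAttention(Wq, Wk, Wv)
attn.forward(xs[:3])            # first call: the whole prompt
new_out = attn.forward(xs[3:])  # later call: a single new token
print(new_out[0], full_causal_attention(Wq, Wk, Wv, xs)[3])
```

Both printed vectors are the same: the new token still attends over all previous tokens, but their keys and values are read from the cache instead of being recomputed.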
Llama-3 then adds the attention outputs to the inputs in a residual connection. An RMSNorm, a feed-forward layer, and another residual connection follow. This is the full transformer block. In Llama-3 8B there are 32 of those. At the end there is a linear layer projecting to the size of the vocabulary and producing the logits for the next token.
Once we have the logits there are different ways to select the next token. For a greedy approach, one would select the token with the highest logit. Otherwise we can softmax them and sample. In a top-\(k\) fashion we sample from the \(k\) tokens with highest probability. In a top-\(p\) fashion, called nucleus sampling, we sample from the smallest set of most probable tokens whose total probability just exceeds \(p\).
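These selection strategies are easy to sketch for a single list of logits. The helper names are our own, and the filters return a dictionary from token ID to renormalized probability; real implementations work on batched tensors.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy(logits):
    """Index of the highest logit."""
    return max(range(len(logits)), key=logits.__getitem__)


def _renormalize(probs, kept):
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


def top_k(probs, k):
    """Keep the k most probable tokens."""
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return _renormalize(probs, order[:k])


def top_p(probs, p):
    """Keep the most probable tokens until their total probability exceeds p."""
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative > p:
            break
    return _renormalize(probs, kept)


def sample(dist, rng):
    """Draw one token ID from a {token_id: probability} dictionary."""
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]


probs = [0.5, 0.3, 0.15, 0.05]
rng = random.Random(0)
print(greedy([0.1, 2.0, -1.0]))
print(top_k(probs, 3))
print(top_p(probs, 0.7))
print(sample(top_p(probs, 0.7), rng))
```

With \(p = 0.7\) only the first two tokens survive, because their probabilities \(0.5 + 0.3\) are the first to exceed the threshold.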
Finally, because of its scale, Llama-3 uses specialized layers from the fairscale library, amenable to parallelization. These include a VocabParallelEmbedding for storing the token embeddings and RowParallelLinear, ColumnParallelLinear which can split their tensors across the rows/columns, calculate on multiple GPUs, and then all-reduce the result. The MLP used in each transformer block has the form RowParallelLinear(silu(ColumnParallelLinear(x)) * ColumnParallelLinear(x)).
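To see why this column-then-row split works, here is a pure-Python sketch of the gated MLP and of a "sharded" version that simulates several devices in a loop. The weight names `W_gate`, `W_up`, `W_down` are our own labels, weights are stored as (out_features, in_features) lists, and the final sum stands in for the all-reduce; this is not fairscale code.

```python
import math


def silu(v):
    return v / (1.0 + math.exp(-v))


def linear(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def feed_forward(x, W_gate, W_up, W_down):
    """down(silu(gate(x)) * up(x)) on a single device."""
    hidden = [silu(g) * u for g, u in zip(linear(W_gate, x), linear(W_up, x))]
    return linear(W_down, hidden)


def sharded_feed_forward(x, W_gate, W_up, W_down, n_shards):
    """Same computation with the hidden units split across n_shards 'devices'."""
    step = len(W_gate) // n_shards  # assumes the hidden size is divisible
    out = [0.0] * len(W_down)
    for s in range(n_shards):
        rows = slice(s * step, (s + 1) * step)
        # Column-parallel part: each shard owns a slice of the hidden units.
        hidden = [silu(g) * u
                  for g, u in zip(linear(W_gate[rows], x), linear(W_up[rows], x))]
        # Row-parallel part: each shard multiplies by its slice of W_down ...
        partial = linear([r[rows] for r in W_down], hidden)
        # ... and the partial outputs are summed (the all-reduce).
        out = [a + b for a, b in zip(out, partial)]
    return out


W_gate = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
W_up = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0], [1.0, 0.0]]
W_down = [[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 0.0, -1.0]]
x = [1.0, -2.0]

print(feed_forward(x, W_gate, W_up, W_down))
print(sharded_feed_forward(x, W_gate, W_up, W_down, n_shards=2))
```

The elementwise gating only mixes values within the same hidden unit, so each shard can work independently until the final sum.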
Server
Let's also take a quick look at how to actually serve an LLM. There are a lot of important tricks to optimize this part, including model quantization, pruning, and continuous batching, which we won't cover. However, one interesting thing is conversation context management, i.e. how to make the LLM pay attention to earlier parts of a conversation, perhaps many user prompts ago. How to best do this is not something that the community has converged on. Obviously, if the conversation becomes so long that it exceeds the maximum sequence length of the model, you'd have to cut the conversation. But where? How do you construct a short prompt that summarizes the whole conversation so far?
We'll consider Ollama, an inference engine for LLMs, which uses llama.cpp for the actual underlying function calls. Let's forget about chatting and take a look at text completion first. I've obtained most of what follows from tracing the Ollama server logs and it may not be 100% precise.
From ollama/server/routes.go we see that /api/generate gets handled by GenerateHandler. Within that function we schedule a llmServer, basically an instance of the llama.cpp server. The request itself contains various specifications such as the model to use, the user and system prompts, any conversation context, and images. A key thing that happens here is to apply a template to the prompt. The template contains the precious metadata of who is saying what and has a particular format. For example, for Llama-3 the prompt Hello? will be turned into
<|start_header_id|> user <|end_header_id|> Hello? <|eot_id|>
<|start_header_id|> assistant <|end_header_id|>
There are three roles: user, system, and assistant, and they are enclosed in a header. The special token <|eot_id|> marks the end of a turn. Note how "templating" the conversation allows us to turn something like a chat, where there are questions, into a sequence which is to be continued by the model. This happens by appending the assistant header to the end, so that the question has been rephrased as a sequence completion task.
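A toy version of this templating step could look as follows. The function name is hypothetical and the spacing simply follows the rendering shown in this post; the exact whitespace of the real Llama-3 template may differ.

```python
def apply_chat_template(messages):
    """Turn (role, content) pairs into one string that ends with an open assistant turn."""
    parts = []
    for role, content in messages:
        parts.append(f"<|start_header_id|> {role} <|end_header_id|> {content} <|eot_id|>")
    # The trailing assistant header is what turns the chat into a completion task.
    parts.append("<|start_header_id|> assistant <|end_header_id|>")
    return "\n".join(parts)


print(apply_chat_template([("user", "Hello?")]))
```

Whatever the model generates next is, by construction, the assistant's reply, and generation stops when it emits <|eot_id|>.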
We can pass a longer context, which is a list of token IDs representing continued conversation. For example, the context
[128006,882,128007,271,9906,30,128009,128006,78191,128007,271,9906,0,1102,596,6555,
311,3449,499,13,2209,1070,2555,358,649,1520,499,449,11,477,1053,499,1093,311,6369,30]
corresponds to
<|start_header_id|> user <|end_header_id|> Hello? <|eot_id|>
<|start_header_id|> assistant <|end_header_id|>
Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?
So, in GenerateHandler we detokenize the optional context, combine it with the current prompt, and apply the template, producing the full input to the model. We then call a function named Completion, found in ollama/llm/server.go. This function does a lot of things but importantly sends a POST request to /completion for the underlying llmServer to complete.
Now we need to see how the llama.cpp server handles /completion requests. This is found in ollama/llm/ext_server/server.cpp. We see that the llmServer creates a new "task", with a new ID, and pushes it into a task queue using the request_completion function. The actual queue is implemented in a struct called llama_server_queue from the ollama/llm/ext_server/utils.hpp file. It enters a loop in which it waits until a new task arrives, processes the task by copying data to a "slot" (a kind of single execution unit for llama.cpp), and runs all slots. When running the slots, update_slots, the actual function that does all the computation, is called.
Here there is a lot of logic related to actual slots, which involves storing task information, context management, and token processing. What concerns us is what happens if the input prompt is beyond the model's maximum sequence length. In that case the prompt is truncated. This happens by retaining the first n_keep tokens from the message (num_keep in Ollama's options), cutting the remaining tokens into two parts, and discarding the first part. Essentially, if the prompt is A B C D E F G H I J K L, n_keep is \(2\), and the context length is \(10\), then after truncation the prompt will be: A B H I J K L . . . This truncation keeps the beginning and end of the context, so the model roughly knows how the conversation started and what was said last. In many cases this is enough to produce meaningful long conversations. But obviously, it is not able to reference facts and results from the middle of the conversation.
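The truncation rule as described here fits in a few lines. This is a simplified sketch of the description, not a port of the llama.cpp code; the function name is our own, and a single halving may not be enough for extremely long prompts.

```python
def truncate_prompt(tokens, n_keep, n_ctx):
    """Keep the first n_keep tokens, then drop the first half of the rest."""
    if len(tokens) <= n_ctx:
        return list(tokens)  # fits into the context, nothing to do
    head = tokens[:n_keep]
    rest = tokens[n_keep:]
    tail = rest[len(rest) // 2:]  # discard the first part, keep the second
    return head + tail


prompt = list("ABCDEFGHIJKL")
print(" ".join(truncate_prompt(prompt, n_keep=2, n_ctx=10)))
```

For the example from the text this prints A B H I J K L, leaving three free positions in a context of length \(10\) for the generated tokens.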
Ollama also supports a /api/chat interface. The only difference here is that the context in the case of /api/chat is passed as a list of Message instances, which are just structs holding the role and content of a prompt or answer. What happens if this sequence of messages becomes too long? The logic is essentially the same: we have to truncate the conversation history. This happens in the chatPrompt function in ollama/server/prompt.go. Basically, the sequence of messages is cut in a way that does not violate two requirements. The first is that any system prompts should be preserved. This is useful because the system prompts typically set the identity and guardrails for the chatbot and these should be consistent across the whole conversation. The second is that the latest message should be kept. This ensures that the answer is maximally relevant to the last user prompt. On a side note, I admire the humorous test cases found in ollama/server/prompt_test.go. They show that we simply discard older non-system messages.
{
    name:  "truncate messages",
    limit: 1,
    msgs: []api.Message{
        {Role: "user", Content: "You're a test, Harry!"},
        {Role: "assistant", Content: "I-I'm a what?"},
        {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
    },
    expect: expect{
        prompt: "A test. And a thumping good one at that, I'd wager. ",
    },
},
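The two requirements can be captured in a short Python sketch. It is our own illustration of the behavior described above, not a translation of Ollama's Go code: the function name is hypothetical and a whitespace word count stands in for real token counting.

```python
def truncate_messages(messages, limit, count_tokens=lambda text: len(text.split())):
    """Keep all system messages and the latest message, then as many recent
    non-system messages as still fit within the token limit."""
    last = len(messages) - 1
    keep = {i for i, m in enumerate(messages) if m["role"] == "system" or i == last}
    used = sum(count_tokens(messages[i]["content"]) for i in keep)

    # Walk backwards from the most recent message and stop at the first
    # one that no longer fits, so older messages are the ones discarded.
    for i in range(last - 1, -1, -1):
        if i in keep:
            continue
        cost = count_tokens(messages[i]["content"])
        if used + cost > limit:
            break
        keep.add(i)
        used += cost
    return [messages[i] for i in sorted(keep)]


msgs = [
    {"role": "user", "content": "You're a test, Harry!"},
    {"role": "assistant", "content": "I-I'm a what?"},
    {"role": "user", "content": "A test. And a thumping good one at that, I'd wager."},
]
print(truncate_messages(msgs, limit=1))
```

With a limit of \(1\) only the latest user message survives, matching the expectation in the test case above; the latest message is kept even though it exceeds the limit on its own.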
Of course, there are many more things that can be optimized when it comes to serving these models - the context can be artificially extended using Self-Extend, the model can be heavily quantized to take up less memory, optimized formats like gguf can be used for faster loading, speculative decoding can be used to generate multiple tokens at once, and so on. All exciting stuff we'll leave for the future.