GPT from Scratch
These days, I’m exploring the field of natural language generation, using auto-regressive models such as GPT-2. HuggingFace transformers offers a host of pretrained language models, many of which can be used off the shelf with minimal fine-tuning. In this post, however, we will try to build a small GPT model from scratch using PyTorch. I’ve realized that sometimes I feel eerily insecure about using things without truly understanding how they work. This post is an attempt at understanding the GPT model. Turns out that its implementation is actually quite similar to the transformer implementation I wrote about in an earlier post, so you will see some repeats here and there. Also helpful was Karpathy’s minGPT repository, which I heavily referenced for this particular implementation. With that sorted, let’s get started!
Setup
We import the modules we will need for this tutorial. Most of them come from torch, with the exception of the built-in math module.
import math
import torch
from torch import nn
import torch.nn.functional as F
From using HuggingFace transformers, and also from looking at Karpathy’s implementation, I realized that it is customary to have a configuration object that contains all of a model’s initialization parameters. The snippet below, taken from Karpathy’s repository, demonstrates how we can build a basic class that holds the various constants and hyperparameters of the model we are about to build. Note that one can easily extend this configuration class to create GPT-2 or GPT-3 configurations, which would simply amount to more layers, a longer maximum sequence length, and a larger embedding dimension.
class GPTConfig:
    attn_dropout = 0.1
    embed_dropout = 0.1
    ff_dropout = 0.1

    def __init__(
        self, vocab_size, max_len, **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_len = max_len
        for key, value in kwargs.items():
            setattr(self, key, value)


class GPT1Config(GPTConfig):
    num_heads = 12
    num_blocks = 12
    embed_dim = 768
If you are already familiar with the transformer architecture, a lot of this will seem familiar. max_len refers to the maximum sequence length the model can process. Because transformer models process all inputs at once in parallel, their context window is necessarily finite (hence the introduction of models such as Transformer XL to remedy this limitation). vocab_size denotes the size of the vocabulary, or in other words, how many tokens the model is expected to know. num_blocks represents the number of transformer decoder layers; num_heads, the number of attention heads.
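To make the configuration mechanics concrete, here is a quick sketch of how one might instantiate a config; the values are arbitrary, and the ff_dropout override simply demonstrates that any extra keyword argument becomes an instance attribute.
config = GPT1Config(vocab_size=10, max_len=12, ff_dropout=0.2)
print(config.num_heads)   # 12, inherited class attribute
print(config.embed_dim)   # 768
print(config.ff_dropout)  # 0.2, overridden through **kwargs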
Implementation
We can now start building out the model. Shown below is the overarching model architecture, which, in my opinion, is surprisingly short and simple.
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.max_len = config.max_len
        self.tok_embed = nn.Embedding(
            config.vocab_size, embed_dim
        )
        self.pos_embed = nn.Parameter(
            torch.zeros(1, config.max_len, embed_dim)
        )
        self.dropout = nn.Dropout(config.embed_dropout)
        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.num_blocks)]
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, config.vocab_size)

    def forward(self, x, target=None):
        # batch_size = x.size(0)
        seq_len = x.size(1)
        assert seq_len <= self.max_len, "sequence longer than model capacity"

        tok_embedding = self.tok_embed(x)
        # tok_embedding.shape == (batch_size, seq_len, embed_dim)
        pos_embedding = self.pos_embed[:, :seq_len, :]
        # pos_embedding.shape == (1, seq_len, embed_dim)
        x = self.dropout(tok_embedding + pos_embedding)
        x = self.blocks(x)
        x = self.ln(x)
        x = self.fc(x)
        # x.shape == (batch_size, seq_len, vocab_size)
        return x
The reason the model seems so deceptively simple is that the bulk of it lives in GPT.blocks, the stack of transformer decoder blocks that does most of the heavy lifting. The only interesting logic in this class is the part where we combine the token and positional embeddings to form the input to the decoder stack.
One implementation detail I learned from Karpathy’s code is how he dealt with positional embeddings. Instead of using a dedicated nn.Embedding layer indexed by position ids, we can register a trainable lookup matrix as a positional embedding of sorts, then simply slice the matrix up to the appropriate sequence length, depending on the length of the input. I think this is a more elegant way of implementing positional embeddings than building a position index with torch.arange() on each forward pass, which is what would have been required had we followed the embedding layer approach.
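For comparison, a rough sketch of that alternative might look like the snippet below; it assumes a config object like the one instantiated earlier, and the positional helper is purely hypothetical and not used anywhere in the model above.
# hypothetical alternative: a dedicated positional embedding layer
pos_embed = nn.Embedding(config.max_len, config.embed_dim)

def positional(x):
    seq_len = x.size(1)
    # position ids have to be rebuilt on every forward pass
    positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
    # output shape == (1, seq_len, embed_dim), broadcast over the batch
    return pos_embed(positions)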
The final output is of shape (batch_size, seq_len, vocab_size). We can thus interpret the output as a next-token prediction at each position. To train the model and update its parameters, we can use techniques such as teacher forcing, in which the ground-truth sequence is fed in and each position is trained to predict the token that follows it. In an auto-regressive generation context, we would take the prediction for the token at the last position of the sequence, append it to the original input, and feed the entire modified sequence back into the model.
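As a rough sketch of that generation loop, a greedy-decoding helper along the lines of the one below would work; the generate function is hypothetical and not part of the model, and in practice one would typically sample from the predicted distribution instead of always taking the argmax.
@torch.no_grad()
def generate(model, x, steps):
    # x.shape == (batch_size, seq_len), a batch of token ids
    model.eval()
    for _ in range(steps):
        # crop the context to the model's maximum length if necessary
        x_cond = x[:, -model.max_len:]
        logits = model(x_cond)
        # take the prediction at the last position and pick the most likely token
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        x = torch.cat([x, next_token], dim=1)
    return x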
Decoder Block
Let’s take a look at the building block of the decoder, the transformer decoder block. A decoder block consists of multi-head self-attention, layer normalization, and a point-wise feedforward network, with a residual connection around each component. The feedforward network can be understood as a layer that temporarily expands the latent dimension in which the contextual embeddings live, enriching their representations. Again, the bulk of the interesting work is deferred to the multi-head self-attention layer.
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = MultiheadAttention(config)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(config.ff_dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x
The multi-head self-attention layer contains the interesting key-query-value operation. I’ll also take this as an opportunity to share some interpretations and details I have come across here and there.
- Self-attention can be seen as a graph neural net, or GNN, in which each token of the input sequence is a node and the edges denote the relationships between tokens. In an encoder block, the graph is complete: every node is connected to every other node. In a decoder, however, each token is only connected to the tokens that come before it in the input sequence.
- If one decides to reuse the key vectors as query vectors, effectively removing the query matrix $W_Q$, the graph becomes undirected. This is because the relationship between node A and node B is no different from that between node B and node A. In other words, $\text{Attention}(n_a, n_b) = \text{Attention}(n_b, n_a)$. In the original, more common transformer formulation, in which the query projection is distinct from the key projection, this symmetric relationship does not necessarily hold, making the attention layer a directed graph.
These are intuitive, heuristic interpretations at best, and I probably do not know enough about GNNs to make more nuanced comments beyond what I’ve written above. Nonetheless, I find this interpretation extremely interesting, and the small numerical check below illustrates the symmetry point.
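Here is a tiny numerical check, entirely separate from the model code: when the same projection produces both keys and queries, the raw score matrix is symmetric, whereas a distinct query projection generally breaks that symmetry.
torch.manual_seed(0)
x = torch.randn(4, 16)       # 4 tokens with embedding dimension 16
w_k = torch.randn(16, 16)    # shared key/query projection

k = x @ w_k
scores = (k @ k.T) / math.sqrt(16)
print(torch.allclose(scores, scores.T))  # True: queries == keys, so scores are symmetric

w_q = torch.randn(16, 16)    # distinct query projection
scores = ((x @ w_q) @ k.T) / math.sqrt(16)
print(torch.allclose(scores, scores.T))  # False in general: a directed graph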
Multi-Head Attention
Returning to where we were, below is an implementation of the multi-head attention layer. It is very similar to how we implemented the layer in the transformers post, so I’ll omit much of the detailed exposition. I’ve added the shape of each intermediate output as comments to help with understanding the flow of the forward pass.
class MultiheadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        assert embed_dim % self.num_heads == 0, "invalid heads and embedding dimension configuration"
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.proj_dropout = nn.Dropout(config.ff_dropout)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_len, config.max_len))
            .unsqueeze(0).unsqueeze(0)
        )

    def forward(self, x):
        batch_size = x.size(0)
        seq_len = x.size(1)
        # x.shape == (batch_size, seq_len, embed_dim)
        k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        # shape == (batch_size, num_heads, seq_len, head_dim)

        attn = torch.matmul(q, k_t) / math.sqrt(q.size(-1))
        # attn.shape == (batch_size, num_heads, seq_len, seq_len)
        mask = self.mask[:, :, :seq_len, :seq_len]
        attn = attn.masked_fill(mask == 0, float("-inf"))
        # softmax must come before dropout; dropping a masked (-inf) entry
        # would otherwise corrupt the attention weights
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        # attn.shape == (batch_size, num_heads, seq_len, seq_len)

        y = torch.matmul(attn, v)
        # y.shape == (batch_size, num_heads, seq_len, head_dim)
        y = y.transpose(1, 2)
        # y.shape == (batch_size, seq_len, num_heads, head_dim)
        y = y.reshape(batch_size, seq_len, -1)
        # y.shape == (batch_size, seq_len, embed_dim)
        y = self.proj_dropout(self.proj(y))
        return y
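Before dissecting the mask, a quick sanity check of the shapes may be helpful; this just instantiates the layers with a small config and confirms that both the attention layer and the full decoder block preserve the input shape.
config = GPT1Config(vocab_size=10, max_len=12)
attn_layer = MultiheadAttention(config)
block = Block(config)

x = torch.randn(2, 8, config.embed_dim)  # (batch_size, seq_len, embed_dim)
print(attn_layer(x).shape)  # torch.Size([2, 8, 768])
print(block(x).shape)       # torch.Size([2, 8, 768]); residual connections preserve the shape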
The part I found the most confusing was how masking interacts with the attention matrix. While I conceptually knew that the decoder is not supposed to see future tokens, which is why masking is necessary, it really helps to see what is going on with a more concrete example.
Below is a mask matrix, assuming that we have a decoder whose maximum sequence length is 5.
max_len = 5
mask = torch.tril(torch.ones(max_len, max_len)).unsqueeze(0).unsqueeze(0)
mask
tensor([[[[1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]]]])
The idea is that the model is allowed to see the values of the attention matrix at positions marked with a 1; positions marked with a 0 need to be masked out.
Let’s assume that the model receives a batch of inputs whose sequence length is 3. Then we would only use a portion of the mask matrix.
seq_len = 3
mask = mask[:, :, :seq_len, :seq_len]
mask
tensor([[[[1., 0., 0.],
[1., 1., 0.],
[1., 1., 1.]]]])
The point that confused me the most was how the mask interacts with the notion of a batch. Recall that we called .unsqueeze(0) on the mask matrix twice, adding singleton batch and head dimensions; this lets broadcasting handle the batch (and the heads) for us. Thus, the mask can be thought of as dealing with only one example of the batch, which makes reasoning about the mask and input sequences a little less challenging. In other words, the mask can now simply be seen as being applied to one example of length 3. For the first token, the model can only attend to the leading token itself; the tokens that follow are allowed to look back, but not ahead, hence the triangular shape of the mask.
Let’s put all of this into perspective by coming up with an attention matrix.
# attn.shape == (batch_size, num_heads, seq_len, seq_len)
batch_size = 3
num_heads = 2
attn = torch.randn(batch_size, num_heads, seq_len, seq_len)
attn.shape
torch.Size([3, 2, 3, 3])
Here, we have a batch of three examples, and our model has two heads. When we apply the mask to the attention matrix, we end up with the following result.
attn = attn.masked_fill(mask == 0, float("-inf"))
attn
tensor([[[[-0.6319, -inf, -inf],
[ 0.7736, -0.4394, -inf],
[ 0.2407, 0.8301, -0.2763]],
[[ 0.4821, -inf, -inf],
[ 1.3904, -2.0258, -inf],
[ 0.3205, 1.8750, -1.0537]]],
[[[ 0.3154, -inf, -inf],
[-2.1034, -0.2958, -inf],
[ 0.4362, -0.8575, 1.8995]],
[[ 0.5619, -inf, -inf],
[-0.3208, -0.6639, -inf],
[ 0.6854, -0.9504, 0.2803]]],
[[[ 0.0928, -inf, -inf],
[ 0.3951, -0.0538, -inf],
[-0.9994, -2.0981, -0.1262]],
[[-0.9176, -inf, -inf],
[-0.3652, -0.9505, -inf],
[-1.2675, 0.0186, 0.0417]]]])
Applying a softmax over this matrix yields the following result. Notice that each row sums to one; this is what makes attention a weighted average over the value vectors, and the masked positions receive exactly zero weight.
F.softmax(attn, dim=-1)
tensor([[[[1.0000, 0.0000, 0.0000],
[0.7708, 0.2292, 0.0000],
[0.2942, 0.5304, 0.1754]],
[[1.0000, 0.0000, 0.0000],
[0.9682, 0.0318, 0.0000],
[0.1671, 0.7907, 0.0423]]],
[[[1.0000, 0.0000, 0.0000],
[0.1409, 0.8591, 0.0000],
[0.1787, 0.0490, 0.7722]],
[[1.0000, 0.0000, 0.0000],
[0.5850, 0.4150, 0.0000],
[0.5372, 0.1046, 0.3582]]],
[[[1.0000, 0.0000, 0.0000],
[0.6104, 0.3896, 0.0000],
[0.2683, 0.0894, 0.6423]],
[[1.0000, 0.0000, 0.0000],
[0.6423, 0.3577, 0.0000],
[0.1202, 0.4348, 0.4450]]]])
Model
We can now put everything together. Let’s create a basic model configuration, then initialize the model.
vocab_size = 10
max_len = 12
config = GPT1Config(vocab_size, max_len)
model = GPT(config)
This is just a basic 12-layer decoder network. Nowadays, large LMs are gargantuan, so big, in fact, that they do not fit on a single GPU. Nonetheless, our mini GPT model is still pretty respectable in my opinion.
model
GPT(
(tok_embed): Embedding(10, 768)
(dropout): Dropout(p=0.1, inplace=False)
(blocks): Sequential(
(0): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(1): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(2): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(3): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(4): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(5): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(6): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(7): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(8): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(9): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(10): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(11): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU()
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
)
(ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(fc): Linear(in_features=768, out_features=10, bias=True)
)
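To put a rough number on “respectable,” we can count the parameters; with this toy configuration the total should come out to a little over 85 million, nearly all of it in the twelve decoder blocks, since a vocabulary of 10 tokens keeps the embedding and output layers tiny.
num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params:,}")  # a little over 85 million with this configuration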
Let’s create a dummy input and see if the model behaves the way we expect it to. First, let’s try passing in a degenerate input whose length exceeds the model’s capacity.
seq_len = 15
test_input = torch.randint(high=vocab_size, size=(batch_size, seq_len))
try:
    model(test_input).shape
except AssertionError as e:
    print(e)
sequence longer than model capacity
We get an appropriate assertion error, saying that the sequence is longer than the maximum sequence length that the model can process. Let’s see what happens if we pass in a valid input.
model(test_input[:, :max_len]).shape
torch.Size([3, 12, 10])
As expected, we get a valid output of shape (batch_size, seq_len, vocab_size).
Of course, we can continue from here by training the model, but that is probably something we can try another day on a Colab notebook as opposed to my local Jupyter environment.
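For the curious, here is a rough sketch of what a single training step might look like, using randomly generated token ids purely for illustration; the targets are simply the inputs shifted by one position so that each position learns to predict the next token, and the optimizer and learning rate choices are arbitrary.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # arbitrary choice

# dummy batch of token ids, purely for illustration
data = torch.randint(high=vocab_size, size=(batch_size, max_len))
x, target = data[:, :-1], data[:, 1:]  # shift by one for next-token prediction

logits = model(x)
# flatten (batch_size, seq_len, vocab_size) and (batch_size, seq_len) for cross-entropy
loss = F.cross_entropy(logits.reshape(-1, vocab_size), target.reshape(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()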
Conclusion
In this post, we took a look at how the transformer decoder works, using a mini GPT model as an example. Note that the GPT models all share a pretty similar architecture; the main differences lie in the size of the model and the corpus on which it was trained. Obviously, larger models require larger datasets.
I’ve realized that there are a lot of models that build on top of multi-head self-attention. These include models like the Reformer, which applies clever modifications to the attention algorithm to reduce its quadratic cost, as well as new notions of embeddings, such as relative positional embeddings. We will explore these topics in future posts.
I hope you’ve enjoyed reading. Catch you in the next one!