Transformer From Scratch In PyTorch: Model
The Transformer architecture, first introduced in (Vaswani et al., 2017), is an encoder-decoder model that can be used in many supervised sequence-learning scenarios. The success of the Transformer is primarily due to its performance, its simple architecture, and its ability to process the elements of a sequence in parallel, which drastically speeds up training. This is in contrast with traditional sequence-learning models, such as recurrent neural networks, which process the elements of a sequence one at a time.
In this post, we'll build the Transformer model from scratch in PyTorch with an emphasis on modularity and performance. Note that in our implementation, we will be following the Pre-Layer Normalization version of the Transformer.
Imports
Here, we summarize the imports and global variables we will be using in our implementation.
import logging
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

LOGGER = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Overview
Before diving into the code, we give a high-level overview of the Transformer architecture. The Transformer model follows an encoder-decoder design: the source sequence is fed into an encoder, and the target sequence, together with the encoder output, is fed into a decoder. The decoder then outputs probabilities over the target vocabulary.
Below is a diagram of the Transformer architecture that I made.
This diagram actually shows the architecture of the Pre-Layer Normalization Transformer, which differs from the original Post-Layer Normalization Transformer of Attention Is All You Need. The difference will be explained later in this post. We follow the Pre-Layer Normalization Transformer because it is the variant used in most applications, having been shown to be superior in (Xiong et al., 2020).
Using the diagram above, we can summarize the forward pass of the Transformer at a high level.
- Given a batch of pairs of input and output sequences, we embed the inputs and outputs using their respective embedding matrices and add positional encodings to each.
- The input embeddings travel through a stack of encoder layers (6 in the original paper). Each encoder layer consists of two sublayers: multihead attention (to be defined) and a feed-forward network.
- The output embeddings travel through a stack of decoder layers (also 6 in the original paper). Each decoder layer consists of three sublayers: masked (to be defined) multihead attention, an encoder-decoder attention layer, and a feed-forward network.
- Finally, the decoder output is projected back into a vector space whose dimension matches the size of the target vocabulary, and a softmax is applied along the vocabulary dimension at each position.
Unlike previous sequence-learning models, the Transformer design allows for a great deal of parallelization, which significantly speeds up training. It also avoids many of the exploding- and vanishing-gradient issues that RNNs and LSTMs are known to suffer from.
Now that we understand at a high level how the architecture operates, we turn to the individual components needed to build the Transformer encoder and decoder. These components, together with the way the encoder outputs are fed into the decoder, are what make the Transformer successful.
Attention mechanism
We'll first start with attention. In the original paper, the Transformer architecture relies on a relatively simple model for the concept of attention. Generally, attention is a function that takes in
- A query $q \in \mathbb{R}^{d_k}$
- A set of key-value pairs $(k_1, v_1), \dots, (k_n, v_n)$, where $k_i \in \mathbb{R}^{d_k}$ and $v_i \in \mathbb{R}^{d_v}$ for $i = 1, \dots, n$
and returns a vector that is a weighted sum of the values,
$$\text{Attention}(q, \{(k_i, v_i)\}) = \sum_{i=1}^{n} \alpha_i v_i, \qquad \alpha_i = \frac{\exp\left(q \cdot k_i / \sqrt{d_k}\right)}{\sum_{j=1}^{n} \exp\left(q \cdot k_j / \sqrt{d_k}\right)}.$$
More generally, if we have $m$ queries, we can stack them into a matrix $Q \in \mathbb{R}^{m \times d_k}$, and similarly stack the keys and values into matrices $K \in \mathbb{R}^{n \times d_k}$ and $V \in \mathbb{R}^{n \times d_v}$.
This then leads us to define attention more generally as
$$\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V.$$
The PyTorch code for this is then as follows.
def attention(
Q: torch.tensor,
K: torch.tensor,
V: torch.tensor,
dropout: Optional[nn.Dropout] = None,
mask: Optional[torch.tensor] = None,
) -> tuple[torch.tensor, Optional[torch.tensor]]:
"""
Computes attention given query, keys, values.
If we have n-many key-value pairs of dimension dk, dv respectively
and m-many queries of dimension dk, then
- Q has shape (batch_size, m, dk)
- K has shape (batch_size, n, dk)
- V has shape (batch_size, n, dv)
In the transformer architecture we have
- m = n = seq_len
- dk = dv = dmodel
Thus, the attention_weights has shape (batch_size, seq_len, seq_len)
and the attended_values has shape (batch_size, seq_len, d_model)
"""
LOGGER.debug(
f"computing attention with dimensions {Q.size()=} {K.size()=} {V.size()=}"
f" with mask.size()={mask.size() if mask is not None else None}"
)
dk = float(Q.size(-1))
# Compute attention
scale = torch.sqrt(torch.tensor(dk)).to(device)
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# Apply attention mask (if provided).
if mask is not None:
LOGGER.debug(f"Applying {mask.size()=} to {attention_scores.size()=}")
attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(attention_scores, dim=-1)
if dropout is not None:
attention_weights = dropout(attention_weights)
# Calculate the weighted sum of values
attended_values = torch.matmul(attention_weights, V)
return attended_values, attention_weights
We make a few comments on this code.
- This code computes attention for a triplet of 2D tensors (plus a leading batch dimension). In fact, it handles higher-dimensional tensors as well, which lets us compute attention for whole batches (and, later, for multiple heads) in parallel instead of calling this function in a for-loop. A quick shape check follows below.
- We apply dropout to the attention weights, as is done in the original Transformer implementation (and in most reimplementations).
- The attention function optionally takes in a mask, which effectively zeros out the attention weights at certain positions. This is necessary for many of the attention computations in the Transformer, which we'll discuss shortly.
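Here is that quick check, with arbitrary dimensions that are not tied to any particular model size.
# Arbitrary dimensions: batch_size=2, m=n=4, dk=dv=8
Q = torch.rand(2, 4, 8, device=device)
K = torch.rand(2, 4, 8, device=device)
V = torch.rand(2, 4, 8, device=device)
attended_values, attention_weights = attention(Q, K, V)
assert attended_values.size() == (2, 4, 8)
assert attention_weights.size() == (2, 4, 4)
# With no dropout, each row of the attention weights sums to 1
assert torch.allclose(attention_weights.sum(dim=-1), torch.ones(2, 4, device=device))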
Multihead attention
In the Transformer architecture, instead of performing a single attention computation on the full $d_{model}$-dimensional queries, keys, and values, we project them $h$ times with different learned projections down to $d_k$, $d_k$, and $d_v$ dimensions, and run the attention computation on each projection in parallel.
Each attention computation is referred to as an attention head, and the $i$-th head is computed as
$$\text{head}_i = \text{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V).$$
The matrices $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, and $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$ are learned projections.
These attention heads are then concatenated and applied to a final projection matrix $W^O \in \mathbb{R}^{hd_v \times d_{model}}$:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O.$$
In the Transformer we take $d_k = d_v = d_{model}/h$, so the concatenated heads have dimension $d_{model}$ again.
To write this in code, we could explicitly define $3h$ separate torch.nn.Linear
layers (plus one for the output projection). However, note that we can do some of this computation in parallel.
For example, instead of $h$ separate query projections $W_i^Q$, we can define one matrix $W^Q \in \mathbb{R}^{d_{model} \times d_{model}}$, apply it to the queries in a single matrix multiplication, and then split the result into $h$ heads (and likewise for the keys and values).
To do this, we write a method called split_heads
which takes in a PyTorch tensor and splits its last dimension into num_heads-many chunks.
def split_heads(Q: torch.tensor, num_heads: int):
"""
Split the last dimension into (num_heads, head_dim).
Reshape the tensor to (batch_size, seq_length, num_heads, head_dim)
and then transpose to get (batch_size, num_heads, seq_length, head_dim).
"""
batch_size, seq_length, d_model = Q.size()
head_dim = d_model // num_heads
# Reshape to separate heads
Q = Q.view(batch_size, seq_length, num_heads, head_dim)
# Transpose to get (batch_size, num_heads, seq_length, head_dim)
Q = Q.transpose(1, 2)
return Q
This method splits each 2D matrix, column-wise, into num_heads-many chunks of size head_dim, and rearranges the result so that the head dimension comes before the sequence dimension. Both view
and transpose
return views of the original tensor, so no data is copied here. (An alternative implementation using torch.chunk
followed by torch.stack
or torch.cat
would copy the chunks, since those functions always allocate a new tensor; in our version the copy only happens later, in MultiheadAttention below, when the heads are recombined via reshape on a non-contiguous tensor.) For example, applying split_heads with num_heads=2 to the tensor torch.arange(1., 17.).reshape(1, 4, 4) and indexing the result at [0] gives:
tensor([[[ 1., 2.],
[ 5., 6.],
[ 9., 10.],
[13., 14.]],
[[ 3., 4.],
[ 7., 8.],
[11., 12.],
[15., 16.]]])
class MultiheadAttention(nn.Module):
"""
Class to compute multihead attention with num_heads-many heads
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super(MultiheadAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = d_model // num_heads
# Linear projection for the attention heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
# Linear projection for the output layer
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(
self,
Q: torch.tensor,
K: torch.tensor,
V: torch.tensor,
mask: Optional[torch.tensor] = None,
):
LOGGER.debug(
f"Computing multihead attention with {Q.size()=} {K.size()=} {V.size()=}"
f" with mask.size()={mask.size() if mask is not None else None}"
)
Q = self.W_q(Q)
K = self.W_k(K)
V = self.W_v(V)
batch_size = Q.size(0)
d_model = Q.size(-1)
# Split into multiple heads. Shape should now be (batch_size, num_heads, seq_length, head_dim)
Q = split_heads(Q, self.num_heads)
K = split_heads(K, self.num_heads)
V = split_heads(V, self.num_heads)
        # Add a head dimension so the mask broadcasts across heads, e.g. (batch_size, 1, seq_length) -> (batch_size, 1, 1, seq_length)
if mask is not None:
mask = mask.unsqueeze(1)
# Compute attention
output, attention_weights = attention(Q, K, V, dropout=self.dropout, mask=mask)
# Concatenate the heads and compute transformation
output = output.permute(0, 2, 1, 3).reshape(batch_size, -1, d_model)
output = self.W_o(output)
return output, attention_weights
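As a quick usage sketch (the sizes here are arbitrary), multihead attention preserves the (batch_size, seq_length, d_model) shape of its input and returns one attention map per head.
mha = MultiheadAttention(d_model=512, num_heads=8).to(device)
x = torch.rand(2, 10, 512, device=device)  # (batch_size, seq_length, d_model)
output, attn_weights = mha(x, x, x)  # self-attention: Q = K = V = x
assert output.size() == (2, 10, 512)
assert attn_weights.size() == (2, 8, 10, 10)  # (batch_size, num_heads, seq_length, seq_length)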
Position-wise Networks
Another relatively simple component of the Transformer architecture is the position-wise feed-forward network. This is a standard feed-forward network that processes each sequence element independently. If our input consists of embeddings $x_1, \dots, x_n$ with $x_i \in \mathbb{R}^{d_{model}}$, then the network computes
$$\text{FFN}(x_i) = \operatorname{ReLU}(x_iW_1 + b_1)W_2 + b_2$$
where $W_1 \in \mathbb{R}^{d_{model} \times d_{ff}}$, $b_1 \in \mathbb{R}^{d_{ff}}$, $W_2 \in \mathbb{R}^{d_{ff} \times d_{model}}$, and $b_2 \in \mathbb{R}^{d_{model}}$ are learned parameters ($d_{ff} = 2048$ in the original paper).
class PositionwiseFeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super(PositionwiseFeedForward, self).__init__()
self.W_1 = nn.Linear(d_model, d_ff)
self.relu = nn.ReLU()
self.W_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.tensor) -> torch.tensor:
"""
Computes
FFN(x_i) = ReLU(x_iW_1 + b_1)W_2 + b_2.
- x has shape batch_size \\times seq_length \\times d_model
"""
output = self.W_1(x)
output = self.dropout(self.relu(output))
output = self.W_2(output)
return output
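A brief check (with made-up sizes) that the network preserves the model dimension and acts on each position independently:
ffn = PositionwiseFeedForward(d_model=512, d_ff=2048).to(device)
ffn.eval()  # disable dropout so the outputs below are deterministic
x = torch.rand(2, 10, 512, device=device)
assert ffn(x).size() == (2, 10, 512)
# Position 0 of the output depends only on position 0 of the input
assert torch.allclose(ffn(x)[:, 0], ffn(x[:, :1])[:, 0], atol=1e-5)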
Layer Normalization
Another component of the architecture is layer normalization. This is a strategy to stabilize training and reduce training time, introduced in (Ba et al., 2016). In the Transformer architecture, it is applied in combination with a residual connection in each sublayer of the encoder and decoder. Specifically, each sublayer computes
$$x_{\text{out}} = \text{LayerNorm}(x + \text{Sublayer}(x)).$$
This is how layer normalization was incorporated into the Transformer architecture as presented in the original paper, and it is called post-layer normalization. However, most if not all implementations of the Transformer actually perform pre-layer normalization, which we compute as
$$x_{\text{out}} = x + \text{Sublayer}(\text{LayerNorm}(x)),$$
since, as shown in (Xiong et al., 2020), it leads to much better training behavior. In any case, the PyTorch module for layer normalization is given below.
class LayerNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-5):
"""
Computes layer normalization.
LayerNorm(x) =
        \\gamma \\cdot \\frac{x - \\mu}{\\sqrt{\\sigma^2 + \\epsilon}} + \\beta
where
- \\gamma is a scale parameter
- \\mu is the mean
- \\sigma is the standard deviation
- \\epsilon is an offset for numerical stability
- \\beta is a shift parameter.
For training purposes \\sqrt{\\sigma^2 + \\epsilon} ~= \\sigma + \\epsilon.
"""
super(LayerNorm, self).__init__()
self.d_model = d_model
self.eps = eps
# Learnable scale and shift parameters
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
    def forward(self, x: torch.tensor) -> torch.tensor:
# Calculate mean and standard deviation along the last dimension
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True)
# Apply LayerNorm formula
x_normalized = self.gamma * (x - mean) / (std + self.eps) + self.beta
return x_normalized
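Since gamma and beta are initialized to ones and zeros, a freshly constructed LayerNorm should (approximately) standardize its input along the last dimension. A quick check with arbitrary data:
layer_norm = LayerNorm(d_model=512)
x = 10 * torch.rand(2, 10, 512) + 5  # arbitrary scale and offset
out = layer_norm(x)
assert out.size() == x.size()
assert torch.allclose(out.mean(dim=-1), torch.zeros(2, 10), atol=1e-4)
assert torch.allclose(out.std(dim=-1), torch.ones(2, 10), atol=1e-2)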
Positional encoding
Before introducing the actual encoder, we must also implement positional encoding. The purpose of positional encoding is to inject the model input with information about the position of each token in its sequence, since the attention computation itself is otherwise order-agnostic.
In the original Transformer paper, this is done with sinusoidal positional encoding, which builds a matrix $PE \in \mathbb{R}^{\text{max\_len} \times d_{model}}$ where
$$PE(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right).$$
This matrix is then added to the embedded input and output sequences.
def positional_encoding(max_len: int, d_model: int):
"""
Computes positional encoding according to
PE(pos, 2i) = sin(pos/10000^{2i / dmodel})
PE(pos, 2i + 1) = cos(pos/10000^{2i / dmodel})
"""
div_terms = torch.pow(torch.tensor(10_000.0), torch.arange(0, d_model, 2) / d_model)
pos_enc = torch.arange(max_len, dtype=torch.float32).unsqueeze(1).repeat(1, d_model)
pos_enc[:, 0::2] = torch.sin(pos_enc[:, 0::2] / div_terms)
pos_enc[:, 1::2] = torch.cos(pos_enc[:, 1::2] / div_terms)
return pos_enc
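A quick look at the resulting matrix (the sizes here are chosen arbitrarily): every row is a bounded, distinct encoding of its position.
pos_enc = positional_encoding(max_len=100, d_model=512)
assert pos_enc.size() == (100, 512)
assert pos_enc.abs().max() <= 1.0  # sinusoids are bounded
assert len(torch.unique(pos_enc, dim=0)) == 100  # every position gets a distinct encoding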
Encoder
At this point, we have everything we need to define the encoder layer. The encoder stacks 6 of these layers (in the original architecture) before sending its output to the decoder. Each encoder layer consists of layer normalization, multihead self-attention, layer normalization again, and a position-wise feed-forward network, with a residual connection around each sublayer. We write the PyTorch module as below. Note that this implements pre-layer normalization, which differs from the original Transformer architecture's post-layer normalization.
class EncoderLayer(nn.Module):
"""
Implements a single Encoder layer with pre-layer normalization.
"""
def __init__(self, d_model: int, num_heads: int, d_ffn: int, dropout: float = 0.1):
super(EncoderLayer, self).__init__()
# Self-attention sub-layer
self.self_attention = MultiheadAttention(d_model, num_heads, dropout=dropout)
# Position-wise feedforward sub-layer
self.feedforward = PositionwiseFeedForward(d_model, d_ffn, dropout=dropout)
# Layer Normalization
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.tensor, mask: Optional[torch.tensor] = None):
# Multihead self-attention sub-layer
x_norm = self.norm1(x)
attention_output, _ = self.self_attention(x_norm, x_norm, x_norm, mask=mask)
x = x + self.dropout(attention_output)
# Position-wise feedforward sub-layer
x_norm = self.norm2(x)
ff_output = self.feedforward(x_norm)
output = x + self.dropout(ff_output)
return output
We can then use this EncoderLayer
class to define our main Encoder
class, which can instantiate and
connect any number of encoder layers together.
class Encoder(nn.Module):
"Class for encoder, which consists of N-many EncoderLayers"
def __init__(
self,
num_stacks: int,
d_model: int,
num_heads: int,
d_ffn: int,
dropout: float = 0.1,
):
super(Encoder, self).__init__()
self.layers = nn.ModuleList(
[
EncoderLayer(
d_model=d_model, num_heads=num_heads, d_ffn=d_ffn, dropout=dropout
)
for _ in range(num_stacks)
]
)
def forward(self, x: torch.tensor, mask: torch.tensor):
"Pass the input (and mask) through each layer in turn."
for layer in self.layers:
x = layer(x, mask)
return x
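Here is a hedged usage sketch with toy sizes: the Encoder takes an already-embedded input of shape (batch_size, seq_length, d_model) together with a padding mask, and returns a tensor of the same shape.
encoder = Encoder(num_stacks=6, d_model=512, num_heads=8, d_ffn=2048).to(device)
x = torch.rand(2, 10, 512, device=device)  # embedded source sequences
src_mask = torch.ones(2, 1, 10, device=device).bool()  # no padding in this toy batch
assert encoder(x, src_mask).size() == (2, 10, 512)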
Decoder
Similarly, we can define our DecoderLayer
class to represent a single decoder layer.
In the Transformer model, the decoder likewise stacks 6 decoder layers.
class DecoderLayer(nn.Module):
"""
Implements a single Decoder layer with pre-layer normalization.
"""
def __init__(self, d_model: int, num_heads: int, d_ffn: int, dropout: float = 0.1):
super(DecoderLayer, self).__init__()
# Self-attention sub-layer
self.self_attention = MultiheadAttention(d_model, num_heads, dropout=dropout)
# Encoder-Decoder attention sub-layer
self.encoder_attention = MultiheadAttention(d_model, num_heads, dropout=dropout)
# Position-wise feedforward sub-layer
self.feedforward = PositionwiseFeedForward(d_model, d_ffn, dropout=dropout)
# Layer normalization
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.norm3 = LayerNorm(d_model)
self.norm4 = LayerNorm(d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.tensor,
encoder_output: torch.tensor,
self_mask: Optional[torch.tensor] = None,
encoder_mask: Optional[torch.tensor] = None,
):
# Self-attention sub-layer
x_norm = self.norm1(x)
self_attention_output, _ = self.self_attention(
x_norm, x_norm, x_norm, mask=self_mask
)
x = x + self.dropout(self_attention_output)
# Encoder-Decoder attention sub-layer
x_norm = self.norm2(x)
encoder_output_norm = self.norm3(encoder_output)
encoder_attention_output, encoder_attention_weights = self.encoder_attention(
x_norm, encoder_output_norm, encoder_output_norm, mask=encoder_mask
)
x = x + self.dropout(encoder_attention_output)
# Position-wise feedforward sub-layer
x_norm = self.norm4(x)
ff_output = self.feedforward(x_norm)
x = x + self.dropout(ff_output)
return x, encoder_attention_weights
This can then be used analogously in our main Decoder
class as follows.
class Decoder(nn.Module):
"Class for decoder, which consists of N-many DecoderLayers"
def __init__(
self,
num_stacks: int,
d_model: int,
num_heads: int,
d_ffn: int,
dropout: float = 0.1,
):
super(Decoder, self).__init__()
self.layers = nn.ModuleList(
[
DecoderLayer(
d_model=d_model, num_heads=num_heads, d_ffn=d_ffn, dropout=dropout
)
for _ in range(num_stacks)
]
)
def forward(
self,
x: torch.tensor,
encoder_output: torch.tensor,
self_mask: torch.tensor,
encoder_mask: torch.tensor,
):
"Pass the input (and mask) through each layer in turn."
attn_weights = []
for layer in self.layers:
x, encoder_decoder_attn_weights = layer(
x, encoder_output, self_mask, encoder_mask
)
attn_weights.append(encoder_decoder_attn_weights)
return x, attn_weights
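Again, a hedged sketch with toy sizes. The decoder consumes the embedded target sequence, the encoder output, a self-attention mask that hides future positions, and the source padding mask. (We build the triangular mask inline here; a reusable future_mask helper is defined in the next section.)
decoder = Decoder(num_stacks=6, d_model=512, num_heads=8, d_ffn=2048).to(device)
tgt = torch.rand(2, 7, 512, device=device)  # embedded target sequences
enc_output = torch.rand(2, 10, 512, device=device)  # stand-in for the encoder output
tgt_mask = torch.tril(torch.ones(1, 7, 7, device=device)).bool()  # hide future positions
src_mask = torch.ones(2, 1, 10, device=device).bool()  # no source padding in this toy batch
out, attn_weights = decoder(tgt, enc_output, tgt_mask, src_mask)
assert out.size() == (2, 7, 512)
assert len(attn_weights) == 6  # one set of encoder-decoder attention weights per layer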
Masking
At this point, we could introduce the full Transformer model. Before we do, however, we will discuss masking, as it is a critical component of both the Transformer's training and its forward pass.
Generally, masking is the idea of hiding certain fields in a data structure used by an algorithm. Usually, the data structure
is a matrix, and hiding the fields is achieved by setting them to zero. In our case, a mask is a boolean
matrix in which False
values mean "treat this value as zero".
In the transformer architecture, masking is used for two purposes.
- Padding. A forward pass of the Transformer is generally run on batches. This allows us to achieve parallelization by stacking a batch of sequences into a single tensor and sending it through the Transformer at once. However, different sequences have different lengths, an issue we resolve by padding the sequences so that they all have the same shape. This is fine, but it creates a problem in our attention calculation: without a mask, the padded positions would contribute to the attention weights.
- Future Words. Training the Transformer is also done in batches, which consist of pairs
of input and output sequences. However, the point of the Transformer is to be an autoregressive model, meaning that it
predicts tokens via (1) a given start token and (2) all previously generated tokens. Therefore, when training to predict an output sequence,
the model should not be allowed to see future tokens in the ground-truth output sequence.
That is, when predicting token $y_t$, it must only use the previously seen tokens $y_1, \dots, y_{t-1}$.
These two issues arise because the forward pass and training procedure are heavily parallelized. Specifically, they arise because of the attention calculation
$$\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V.$$
The solution to both issues is the same: we simply mask certain entries of the matrix $QK^{T}/\sqrt{d_k}$ before the softmax, so as to prevent attention towards specific positions in $V$.
This explains why our attention
function above takes an optional mask
argument; applying the mask is exactly what achieves this. To demonstrate, here's an example of
using a mask in the attention calculation, and what doing so achieves.
# Some random Transformer params
batch_size = 2
seq_len = 10
d_model = 5
# Construct Q, K, V matrices (on the same device the attention function uses)
Q = torch.rand(size=(batch_size, seq_len, d_model), device=device)
K = torch.rand(size=(batch_size, seq_len, d_model), device=device)
V = torch.rand(size=(batch_size, seq_len, d_model), device=device)
# Construct a mask, hiding the column at index 5 and the last column
mask = torch.ones(batch_size, 1, seq_len, device=device).bool()
mask[:, :, -1] = False
mask[:, :, 5] = False
# Compute attention
attn_vals, attn_weights = attention(Q, K, V, mask=mask)
# The attention weights at index 5 and at the last index should be zero
assert torch.all(attn_weights[:, :, -1] == 0)
assert torch.all(attn_weights[:, :, 5] == 0)
As we can see, the mask let us "zero out" whichever columns we wanted in the
attention calculation. One question remains: how do we generate these masks in code?
To answer that, we briefly demonstrate the construction of both types of mask.
Below, we construct a padding mask for a batch of sequences of different lengths.
from torch.nn.utils.rnn import pad_sequence
# List of sequences
tensors = [
torch.tensor([1, 2, 3, 4]),
torch.tensor([5, 6, 7]),
torch.tensor([8, 9]),
torch.tensor([10, 11, 12, 13, 14])
]
# Pad the tensors to the maximum length
padded_tensors = pad_sequence(tensors, batch_first=True, padding_value=0)
print(padded_tensors)
# will print
tensor([[ 1, 2, 3, 4, 0],
[ 5, 6, 7, 0, 0],
[ 8, 9, 0, 0, 0],
[10, 11, 12, 13, 14]])
# Construct the padding mask with shape (batch_size, 1, seq_len)
src_mask = (padded_tensors != 0).unsqueeze(1)
print(src_mask)
# will print
tensor([[[ True,  True,  True,  True, False]],
        [[ True,  True,  True, False, False]],
        [[ True,  True, False, False, False]],
        [[ True,  True,  True,  True,  True]]])
Constructing the future mask, which hides future tokens, is also straightforward.
def future_mask(sequence_length):
"""
    Creates a lower-triangular boolean mask of shape (1, n, n)
    used to hide future positions
"""
attn_shape = (1, sequence_length, sequence_length)
future_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
return future_mask == 0
For example, this function returns output like below.
future_mask(5)
# will print
tensor([[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]]])
Note that, stored naively, the padding and future masks would contain a lot of duplicated data. We can avoid this by using broadcasting
in conjunction with PyTorch's masked_fill
method. In this implementation,
- Padding masks have shape (batch_size, 1, sequence_length), so that they broadcast across the query positions (and, after being unsqueezed, across the heads) in the attention calculation.
- Future masks have shape (1, sequence_length, sequence_length), so that they broadcast across the batch (and heads) in the attention calculation.
A sketch of how the two kinds of mask can be combined for the decoder follows below.
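In practice, the decoder's self-attention mask combines both kinds: a padding mask for the target sequence and the future mask. Here is a sketch of how one might build it (using 0 as the pad index, an assumption made purely for illustration).
tgt = torch.tensor([[7, 3, 9, 0, 0],
                    [4, 8, 2, 6, 1]])  # target batch padded with 0
tgt_padding_mask = (tgt != 0).unsqueeze(1)  # (batch_size, 1, tgt_seq_len)
tgt_future_mask = future_mask(tgt.size(1))  # (1, tgt_seq_len, tgt_seq_len)
tgt_mask = tgt_padding_mask & tgt_future_mask  # broadcasts to (batch_size, tgt_seq_len, tgt_seq_len)
print(tgt_mask.size())
# torch.Size([2, 5, 5])
The combined mask has shape (batch_size, tgt_seq_len, tgt_seq_len) rather than (1, tgt_seq_len, tgt_seq_len), but it broadcasts through the attention calculation in exactly the same way.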
Transformer
We now introduce the Transformer model itself. Below is our implementation, assembled from the components above.
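One piece we have not defined yet is the Embeddings module used below for the source and target tokens. Here is a minimal sketch, assuming (as in the original paper) a learned nn.Embedding whose output is scaled by $\sqrt{d_{model}}$.
class Embeddings(nn.Module):
    """
    Token embeddings scaled by sqrt(d_model).
    """
    def __init__(self, vocab_size: int, d_model: int):
        super(Embeddings, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
    def forward(self, x: torch.tensor) -> torch.tensor:
        # x has shape (batch_size, seq_len) and contains token indices
        return self.embedding(x) * (self.d_model ** 0.5)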
class Transformer(nn.Module):
def __init__(
self,
src_vocab_size: int,
tgt_vocab_size: int,
num_encoder_stacks: int = 6,
num_decoder_stacks: int = 6,
d_model: int = 512,
d_ffn: int = 2048,
num_encoder_heads: int = 8,
num_decoder_heads: int = 8,
max_seq_len: int = 100,
dropout: float = 0.1,
):
super(Transformer, self).__init__()
self.d_model = d_model
self.max_seq_len = max_seq_len
self.src_embedding = Embeddings(src_vocab_size, d_model)
self.tgt_embedding = Embeddings(tgt_vocab_size, d_model)
self.encoder = Encoder(num_encoder_stacks, d_model, num_encoder_heads, d_ffn)
self.decoder = Decoder(num_decoder_stacks, d_model, num_decoder_heads, d_ffn)
# Mark positional encoder as not learnable, so that .parameters() doesn't pass it to optimizer
self.register_buffer(
"positional_encoder", positional_encoding(max_seq_len, d_model)
)
self.output_layer = nn.Linear(d_model, tgt_vocab_size)
self.src_dropout = nn.Dropout(dropout)
self.tgt_dropout = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)
# Initialize parameters with Glorot / fan_avg.
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def encode(self, src: torch.tensor, src_mask: torch.tensor):
LOGGER.debug(
"Computing forward pass of encoder with "
f"{src.size()=}, {src_mask.size()=}"
)
# Embed inputs, add position encoding, apply dropout
src = self.src_embedding(src)
src = src + self.positional_encoder[: src.size(1)]
src = self.src_dropout(src)
# Encode the source sequence
enc_output = self.encoder(src, src_mask)
return enc_output
def decode(
self,
tgt: torch.tensor,
enc_output: torch.tensor,
tgt_mask: torch.tensor,
src_mask: torch.tensor,
):
LOGGER.debug(
"Computing forward pass of decoder with "
f"{tgt.size()=}, {enc_output.size()=}, {tgt_mask.size()=}, {src_mask.size()=}"
)
# Embed targets, add position encoding, apply dropout
tgt = self.tgt_embedding(tgt)
tgt = tgt + self.positional_encoder[: tgt.size(1)]
tgt = self.tgt_dropout(tgt)
# Decode the target sequence using the encoder output
dec_output, encoder_decoder_attn_weights = self.decoder(
tgt, enc_output, tgt_mask, src_mask
)
return dec_output, encoder_decoder_attn_weights
def forward(
self,
src: torch.tensor,
tgt: torch.tensor,
tgt_mask: torch.tensor,
src_mask: torch.tensor,
):
"""
Forward pass of Transformer.
- src has size (batch_size, src_seq_len)
- tgt has size (batch_size, tgt_seq_len)
        - src_mask has size (batch_size, 1, src_seq_len), and
        prevents attention to padding positions
        - tgt_mask has size (1, tgt_seq_len, tgt_seq_len), or
        (batch_size, tgt_seq_len, tgt_seq_len) when combined with a
        target padding mask, and prevents attention to future
        positions (and target padding)
- output has size (batch_size, tgt_seq_len, tgt_vocab_size)
"""
LOGGER.debug(
f"computing forward pass with {src.size()=} "
f"{tgt.size()=} {src_mask.size()=} {tgt_mask.size()=}"
)
enc_output = self.encode(src, src_mask)
dec_output, encoder_decoder_attn_weights = self.decode(
tgt, enc_output, tgt_mask, src_mask
)
# Compute output layer
logits = self.output_layer(dec_output)
return logits, encoder_decoder_attn_weights
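Finally, here is a hedged end-to-end sketch (toy vocabulary sizes and random token sequences, purely for illustration) showing how the masks and the model fit together.
model = Transformer(src_vocab_size=1000, tgt_vocab_size=1200, max_seq_len=50).to(device)
src = torch.randint(1, 1000, (2, 10), device=device)  # (batch_size, src_seq_len)
tgt = torch.randint(1, 1200, (2, 7), device=device)  # (batch_size, tgt_seq_len)
src_mask = torch.ones(2, 1, 10, device=device).bool()  # no source padding in this toy batch
tgt_mask = future_mask(7).to(device)  # hide future target positions
logits, attn_weights = model(src, tgt, tgt_mask, src_mask)
assert logits.size() == (2, 7, 1200)  # (batch_size, tgt_seq_len, tgt_vocab_size)
Note that forward returns raw logits rather than probabilities; this is convenient for training, since nn.CrossEntropyLoss applies the log-softmax internally.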