Paper reading group - Attention is All You Need

AI
NLP
Attention
Transformer
Author

Alex Paul Kelly

Published

March 30, 2024

Introduction

Every second Saturday we have a paper reading group where we submit papers we are interested in reading and vote for our favorite paper to read together. This week we discussed the paper “Attention is All You Need” (2017). This paper introduced the Transformer model, which has since become the foundation for many state-of-the-art models such as ChatGPT. This is my take on the Transformer: why it is so important and why it is used so widely. The aim is not to go into extensive detail, as numerous resources already exist online, but to give an intuitive understanding and to share resources I found useful for understanding Transformers from different perspectives: visual, mathematical, code-based and conceptual. I will assume you have a basic understanding of neural networks, backpropagation, etc., and I will use “token” and “word” interchangeably (tokens are what language models actually operate on).

Why transformers

To understand why Transformers matter, you first need to understand what came before. RNNs are sequence-based models: each word or token has to be processed before the next token can be processed. This doesn’t lend itself to utilizing GPUs, as the model is limited by waiting for the previous token. They are deep networks with many layers and hidden states that carry information from previously seen inputs. When long sequences are passed through, the chance of vanishing and exploding gradients increases, which reduces the effectiveness of the network, a flaw in the design. To resolve these issues, LSTMs (long short-term memory models) were designed with two pathways: one for learning from the more immediate tokens and another for a longer memory. LSTMs still have issues with long sequences but are an improvement on RNNs. However, they still didn’t provide accurate enough predictions of tokens, they were still sequential, and they didn’t use all the resources of the GPU, which means slower training, as you can’t scale onto either a larger GPU or multiple GPUs.
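
To make this bottleneck concrete, here is a minimal PyTorch sketch (the sizes are arbitrary, chosen purely for illustration): each step needs the hidden state produced by the previous step, so the loop cannot be parallelized across the sequence.

```python
import torch
import torch.nn as nn

# Minimal sketch of the sequential bottleneck in an RNN: each step
# depends on the previous hidden state, so the loop below cannot be
# run for all tokens at once (sizes here are illustrative).
rnn_cell = nn.RNNCell(input_size=16, hidden_size=32)

tokens = torch.randn(10, 1, 16)   # a sequence of 10 token embeddings, batch of 1
hidden = torch.zeros(1, 32)       # initial hidden state

for token in tokens:              # tokens must be processed one at a time
    hidden = rnn_cell(token, hidden)
```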

Next came the Transformer, introduced in the paper “Attention is All You Need”. Transformers are fantastic models that change how models are organized: they are built around the idea of attention, hence the paper’s name. They use Query, Key and Value vectors (within attention heads) to keep multiple perspectives on the relationships between tokens (or words) and their meanings. This makes the Transformer model highly parallelizable and efficient, leading to significant improvements in training speed and performance across many modalities, most notably text, as described in the paper. Attention also improves performance on longer sequences. In summary, this has the following benefits:

  • It does not use recurrence or convolution, the traditional methods for sequence modeling used by models such as RNNs and LSTMs
  • Relies entirely on attention mechanisms to draw global dependencies between input and output
  • Parallelization allows for faster training times and full use of GPU resources
  • The model is highly modular and can be easily adapted to different tasks by changing the input and output representations

Model Architecture

Transformer Model Architecture

The architecture above is split into two parts: the left side is the encoder and the right side is the decoder. A good way to think about this is that the encoder is like a meticulous note-taker, capturing all the critical information and its interrelations, while the decoder is like a storyteller or translator, using those notes to create a new, coherent narrative or translation that reflects the captured information accurately but in a different form.

Multi-head attention

Before we dive deeper, I want to introduce a couple of analogies to help you understand multi-head attention. Understanding the intuition behind attention heads through analogies can be a good segue into further learning. They helped me not only understand why using multiple heads helps the model perform better, but also retain the information. Each attention head is built from Query, Key and Value vectors, and that’s what we will dig into first. The overall purpose will make more sense once you go through the components section below.

Easy Analogy for QKV Mechanism

Imagine you’re at a large dinner party, trying to follow the conversations around the table:

  • Query (Q): This is like you trying to understand a comment made by someone. You’re focused on this comment and trying to figure out its context and relevance to the conversation.
  • Key (K): Think of every person at the table as holding a “key” to different pieces of information. Some of what they say will be more relevant to understanding the comment you’re focused on, and some less so.
  • Value (V): The “value” is the actual content or meaning each person (key) can contribute to help you understand the comment in question. After deciding whose input is most relevant, you’ll pay more attention to what those few people are saying.

The transformer, like you in this scenario, uses the QKV mechanism to decide which parts of the input (the conversation) to pay attention to when trying to understand each piece (word or comment).
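
In the model itself, this intuition is implemented as the scaled dot-product attention formula from the paper: Attention(Q, K, V) = softmax(QKᵀ / √d_k)V. Here is a minimal PyTorch sketch; the sizes are illustrative, and in a real model Q, K and V come from separate learned linear projections of the input rather than being the raw input itself.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """Scaled dot-product attention as in the paper:
    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # relevance of each key to each query
    weights = F.softmax(scores, dim=-1)            # attention weights sum to 1 per query
    return weights @ V                             # weighted sum of the values

# 4 tokens (the "guests"), each with an 8-dimensional representation
x = torch.randn(4, 8)
# self-attention: the sequence attends to itself (projections omitted for brevity)
out = scaled_dot_product_attention(x, x, x)
print(out.shape)  # torch.Size([4, 8])
```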

Easy Analogy for Multiple Heads

Continuing with the dinner party analogy, imagine now that you’re not just trying to understand the content of the conversation but also the emotional tone, who’s leading the conversation, and how topics are changing over time.

Having multiple heads is like having several versions of you at the party, each with a different focus. One version of you is trying to follow the main topic, another is paying attention to the emotional undercurrents, another is noting how the conversation topics shift, and so on. Each “version” of you can focus on different aspects of the conversation simultaneously, ensuring that you get a fuller understanding of what’s happening at the dinner party than you would if you were just trying to follow the main topic. In essence, the QKV mechanism with multiple heads allows the transformer to “attend” to a complex sequence (like a conversation) from multiple perspectives at once, ensuring it captures the rich, multifaceted nature of the data it’s processing.
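
In code, running several “versions of you” at once looks roughly like the sketch below, using PyTorch’s built-in nn.MultiheadAttention module. The dimensions are illustrative; the representation is split across 8 heads, matching the head count used in the paper.

```python
import torch
import torch.nn as nn

# Sketch of multi-head self-attention: the 64-dim representation is
# split across 8 heads, each attending from its own learned perspective.
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)

x = torch.randn(1, 10, 64)        # batch of 1, sequence of 10 tokens
out, attn_weights = mha(x, x, x)  # self-attention: Q, K and V all come from x
print(out.shape)                  # torch.Size([1, 10, 64])
print(attn_weights.shape)         # torch.Size([1, 10, 10]), averaged over heads
```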

(input) encoder components

encoder highlighted
  • Input Embedding: This can be thought of as a dictionary lookup: each word (or token) is mapped to a single vector. These vectors also capture the semantic information of the words and their relationships to each other, context that is important for understanding what each word means in relation to the others.
  • Positional Encodings: These track the position of each word so the model can understand its relevance based on where it sits. The words are not inherently processed in order, so keeping track of the order is important; this is a key part of what allows the model to process tokens in parallel. (A sketch of the paper’s sinusoidal encoding follows after this list.)
  • Multi-headed Attention: This is where the magic of self-attention is performed. Each head keeps track of each word’s relationship to the other words and its semantic meaning. In the paper there are 8 heads, meaning there are 8 different perspectives on each word (or token), its relationships to the other words and the greater meaning. Each head starts out different due to the random weights assigned at initialization; these weights are then updated during training so each head hones in on the perspectives that matter most with respect to the training data.
  • The feed-forward layer: Before the feed-forward layer, the heads are concatenated; the feed-forward layer then applies further transformations to combine the understanding of all the heads and improve the overall accuracy of the model.
  • Finally, the result, in its new form, is passed to the decoder part of the model to aid prediction.
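
As promised above, here is a minimal sketch of the sinusoidal positional encodings from the paper, where PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The sequence length is illustrative; d_model = 512 matches the paper.

```python
import torch

def positional_encoding(seq_len, d_model):
    """Sinusoidal positional encodings from the paper:
    sin at even dimensions, cos at odd dimensions."""
    pos = torch.arange(seq_len).unsqueeze(1).float()   # positions 0..seq_len-1
    i = torch.arange(0, d_model, 2).float()            # even dimension indices
    angles = pos / (10000 ** (i / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe

# each row is added to the embedding of the token at that position
pe = positional_encoding(seq_len=50, d_model=512)
print(pe.shape)  # torch.Size([50, 512])
```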

(output) decoder components

decoder highlighted
  • Output Embeddings: These work the same way as the input embeddings; there is no difference at this layer, only once the embeddings reach the masked layers. Again, the embeddings can be thought of as a dictionary lookup: each word (or token) is mapped to a single vector that captures the semantic information of the words and their relationships to each other.
  • Positional Encodings: As on the encoder side, these track the position of each word so the model can understand the relevance of word order while still processing tokens in parallel.
  • Masked multi-head attention: Its aim is similar to the multi-headed attention in the encoder, but it learns the relationships and semantic meanings one position at a time in order to predict the next token. For example, if the sentence to translate into French is “you are the best”, the model first sees “you”, then “you are”, then “you are the”, then “you are the best”, each time asking “given what you know so far, what’s the next token?”. The masking hides future tokens so that, despite this step-by-step objective, the whole sequence can still be processed in parallel (see the mask sketch after this list).
  • Multi-headed Attention: This again is where it gets really interesting. This layer takes the semantic meaning of the whole input sequence from the encoder, together with the output of the masked multi-head attention (up to the current position, e.g. “you are”), to predict the next token. It applies several heads to pay attention to different aspects of the information passed in.
  • The feed-forward layer: As in the encoder, the heads are concatenated and sent to the feed-forward layer for further transformations that combine the understanding of all the heads and improve the overall accuracy of the model.
  • Linear Layer: The linear layer projects the decoder output into vocabulary-sized logits to pass to the softmax function
  • Softmax function: The softmax function converts the logits into a probability distribution, and the highest-probability word is selected
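
The masking described above boils down to a simple triangular matrix applied to the attention scores before the softmax, so that position i can only attend to positions up to i. A minimal sketch:

```python
import torch

# Causal mask for masked multi-head attention: True marks positions
# that must be hidden (i.e. future tokens).
seq_len = 4  # e.g. the four tokens of "you are the best"
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

# Before the softmax, masked positions are set to -inf so they
# receive zero attention weight.
scores = torch.randn(seq_len, seq_len)
scores = scores.masked_fill(mask, float("-inf"))
```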

Further studying

These are the links and blogs that helped me understand Transformers better. If you want to dive deeper, consider your own learning style and pick the option that suits you:

The original transformer paper

A two-part blog post on creating a transformer from scratch in PyTorch - Part 1

A two-part blog post on creating a transformer from scratch in PyTorch - Part 2

Introduction to RNNs

Understanding LSTM Networks

Understanding Encoder and Decoder LLMs

The Illustrated Transformer

The Illustrated GPT-2

Visual video on transformers by 3Blue1Brown

What are Word and Sentence Embeddings? (5-10 minute read)

Video Summaries

  • 15 min: really fantastic animated summary
  • 30 min: video supplement to The Illustrated Transformer
  • 50 min: fantastic code walkthrough of the encoder
  • 40 min: fantastic code walkthrough of the decoder

StatQuest YouTube videos

  • Transformers
  • Decoder-only transformers
  • LSTMs
  • Seq2Seq
  • Attention for neural networks
  • Cosine similarity