# Attention

## Movie Reviews

Let's start with a simple example: predicting whether a movie review is positive or negative. For example, both of the reviews "My kid likes this movie so much" and "My little kid likes this animated movie" imply a positive review. How do we make a machine understand this sentiment?

First of all, we need to represent each natural-language sentence as something computable, e.g. a collection of vectors. For simplicity, we consider a word-level model, meaning that we decompose each sentence into separate words. As a result, the original sentences look like the following:

[My, kid, likes, this, movie, so, much]

[My, little, kid, likes, this, animated, movie]

Afterwards, each word is transformed into a vector (or embedding) by a lookup table.

Now, the second question is: what about the model? If we use the CNN model that we introduced a few weeks ago, we use pooling layers to aggregate information across neighboring words. For example, the embeddings of [kid] and [little, kid] should make little difference to the final sentiment prediction. However, pooling layers can introduce errors when certain words matter. For example, the sentence "My kid does not like this movie much" implies a negative review. Applying a simple pooling operation over [not, like] can eliminate the signal indicating negative sentiment.

If we take a different perspective on the above example, we can observe that the words [likes] and [so, much] are strong indications of a positive sentiment, while the word [not] implies a negative one. Therefore, we wish our model to pay attention to these informative words. This leads us to the attention mechanism in deep models.

## General Attention Formulation

An attention function is a mapping that takes as input a query and a set of keys and values, and generates an output vector. The output is a weighted sum of the values, where the weight assigned to each value is computed by a function of the query and the corresponding key.

**General Attention.** Given an input query $\mathbf{q}\in\mathbb{R}^{C}$, a set of keys $\mathbf{K}=\begin{bmatrix}\mathbf{k}_1, \cdots, \mathbf{k}_N \end{bmatrix}$, where $\mathbf{k}_i \in\mathbb{R}^{C}$, and a set of values $\mathbf{V}=\begin{bmatrix}\mathbf{v}_1, \cdots, \mathbf{v}_N \end{bmatrix}$, where $\mathbf{v}_i \in\mathbb{R}^{C}$, the output vector $\mathbf{o}\in\mathbb{R}^{C}$ is computed by
$$
\mathbf{o} = \sum_i \alpha_i \mathbf{v}_i, \quad\text{where } \alpha_i = \frac{e^{h(\mathbf{q}, \mathbf{k}_i)}}{\sum_j e^{h(\mathbf{q}, \mathbf{k}_j)}}.
$$

Here $h(\mathbf{q}, \mathbf{k})$ is a scoring function that determines how well $\mathbf{q}$ aligns with $\mathbf{k}$, or more generally a kernel function [1]. In our movie review example, the query is a single vector that represents the review task we are dealing with, and both the keys and the values are the word embeddings. The attention function outputs a sentence-level embedding, which is then passed through a binary classifier to predict the review sentiment.

There are many ways to implement $h(\mathbf{q}, \mathbf{k})$. We are particularly interested in a simple dot-product form, namely $h(\mathbf{q}, \mathbf{k}) = \mathbf{q}^\top \mathbf{k}$, which was used in the famous "Attention is all you need" paper [2]. When we want to compute the attention of multiple queries simultaneously, the dot-product operation allows efficient parallel computation.
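To make this concrete, here is a minimal NumPy sketch of single-query attention with the dot-product score $h(\mathbf{q}, \mathbf{k}) = \mathbf{q}^\top \mathbf{k}$. The toy embeddings, the `task_query` vector, and the dimensions are illustrative stand-ins, not values from any particular model.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend(q, K, V):
    """Single-query attention with dot-product scoring h(q, k) = q^T k.

    q: (C,) query vector, K: (N, C) keys, V: (N, C) values.
    Returns the output o = sum_i alpha_i * v_i, shape (C,).
    """
    scores = K @ q            # (N,)  h(q, k_i) for every key
    alpha = softmax(scores)   # (N,)  attention weights, sum to 1
    return alpha @ V          # (C,)  weighted sum of values

# Toy movie-review setup: 7 word embeddings of dimension C = 8 play the role
# of both keys and values; the task query would be learned (random here).
rng = np.random.default_rng(0)
N, C = 7, 8                   # e.g. "My kid likes this movie so much"
word_embeddings = rng.normal(size=(N, C))
task_query = rng.normal(size=(C,))

sentence_embedding = attend(task_query, word_embeddings, word_embeddings)
print(sentence_embedding.shape)  # (8,) -> fed to a binary sentiment classifier
```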
We pack multiple queries together into a matrix $\mathbf{Q}$ and pack the sets of keys and values into matrices $\mathbf{K}$ and $\mathbf{V}$. The attention can then be written in the following matrix form.

**Scaled Dot-product Attention (Matrix Form).** Given queries $\mathbf{Q}\in\mathbb{R}^{M\times C}$, keys $\mathbf{K}\in\mathbb{R}^{N\times C}$, and values $\mathbf{V}\in\mathbb{R}^{N\times C}$, the scaled dot-product attention computes the output as
$$
\mathbf{O} = \mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{C}}\right)\mathbf{V},
$$
where $\mathrm{softmax}(\cdot)$ is computed row-wise so that each row sums up to 1.

Why do we need the scaling factor $\frac{1}{\sqrt{C}}$? Assume that each element of $\mathbf{Q}$ and $\mathbf{K}$ is an independent random variable with mean 0 and variance 1. Then their dot product, $\mathbf{q}^\top \mathbf{k}=\sum_{i=1}^{C} q_i k_i$, has mean 0 and variance $C$. For large $C$, the dot products grow large in magnitude, which makes the output of the softmax function peakier and pushes it into regions with extremely small gradients. To alleviate this issue, we scale the dot products by $\frac{1}{\sqrt{C}}$, which brings the variance back to 1.

## Self-Attention vs Cross-Attention

The choice of queries, keys, and values has been arbitrary so far. In practice, the two most common categories of attention are self-attention and cross-attention.

### Self-Attention

In self-attention, all of the queries, keys, and values come from the same set of inputs $\mathbf{X}\in\mathbb{R}^{N\times C}$. In particular, we use three linear projection matrices $\mathbf{W}_{Q}\in\mathbb{R}^{C\times d}$, $\mathbf{W}_{K}\in\mathbb{R}^{C\times d}$, and $\mathbf{W}_{V}\in\mathbb{R}^{C\times d}$ and compute the queries, keys, and values by $\mathbf{Q} = \mathbf{X}\mathbf{W}_{Q}$, $\mathbf{K} = \mathbf{X}\mathbf{W}_{K}$, and $\mathbf{V} = \mathbf{X}\mathbf{W}_{V}$.

### Cross-Attention

In cross-attention, the keys and values come from one set of inputs $\mathbf{X}\in\mathbb{R}^{N\times C}$, namely $\mathbf{K} = \mathbf{X}\mathbf{W}_{K}$ and $\mathbf{V} = \mathbf{X}\mathbf{W}_{V}$, while the queries come from another set of inputs $\mathbf{Y}\in\mathbb{R}^{M \times C}$, namely $\mathbf{Q} = \mathbf{Y}\mathbf{W}_{Q}$. In general, the number of query vectors differs from the number of key/value vectors, i.e. $M\neq N$.

### Cross-Attention Examples

First, when the number of input elements $N$ is huge, cross-attention serves as a bottleneck that compresses the input vectors into a few latent codes. In this case, we typically have $M \ll N$. Perceiver [3] is a good example. Second, in the encoder-decoder style Transformer, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This cross-attention module allows every position in the decoder to attend over all positions in the input sequence.
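Below is a minimal NumPy sketch of the matrix form above, together with the self-attention and cross-attention variants just described. The random projection matrices and the shapes ($N=16$, $M=4$, $C=d=32$) are illustrative stand-ins for learned parameters and real data.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable row-wise softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """O = softmax(Q K^T / sqrt(C)) V with a row-wise softmax.

    Q: (M, C) queries, K: (N, C) keys, V: (N, C) values -> O: (M, C).
    """
    C = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(C)       # (M, N) scaled dot products
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V                  # (M, C)

rng = np.random.default_rng(0)
N, M, C, d = 16, 4, 32, 32

# Self-attention: queries, keys, and values all come from the same inputs X.
X = rng.normal(size=(N, C))
W_Q, W_K, W_V = [rng.normal(size=(C, d)) for _ in range(3)]
self_out = scaled_dot_product_attention(X @ W_Q, X @ W_K, X @ W_V)   # (N, d)

# Cross-attention: keys/values come from X, queries from another input Y
# (here M << N, as in a Perceiver-style latent bottleneck).
Y = rng.normal(size=(M, C))
cross_out = scaled_dot_product_attention(Y @ W_Q, X @ W_K, X @ W_V)  # (M, d)

print(self_out.shape, cross_out.shape)
```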