
Attention

September 8, 2021

Seq2Seq with Attention

Seq2Seq with LSTM

Seq2Seq falls under the many-to-many category of RNN architectures. Both input and output are word-level sequences.

The diagram above shows a dialogue system (e.g., a chatbot). The part that reads the input sentence is the encoder; the part that generates the output sentence is the decoder. The RNN used here is an LSTM. The hidden state from the encoder's last step becomes the decoder's initial hidden state.

SoS (Start of Sentence): A special token that marks the start of the sentence to be generated. SoS is kept as its own entry in the vocabulary and is fed to the decoder as its first input.

EoS (End of Sentence): A special token that marks the end of the generated sentence. Generation stops when the decoder outputs it.

Problem

All information from the encoder must be compressed into a single fixed-dimensional hidden state. So even though the LSTM mitigates the long-term dependency problem, the longer the sequence, the more likely it is that early information is lost or distorted.

For example, when translating a sentence like ‘I go home,’ the decoder needs the subject first. But subjects typically appear at the beginning of the sentence, so that information may have degraded by the time the encoder reaches the end, and the decoder may fail to generate the subject correctly.

Workaround: Reverse the input sentence. Turning ‘I go home’ into ‘home go I’ places the important information near the end. This is not a fundamental solution, however.
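The reversal trick only changes the encoder input; the target sentence keeps its normal order. A minimal sketch, assuming the sentence is already split into word tokens (`reverse_source` is a hypothetical helper name):

```python
def reverse_source(tokens):
    """Return the source tokens in reverse order (the input-reversal trick).

    Only the encoder input is reversed; the decoder's target sentence is left as-is.
    Slicing returns a new list, so the caller's list is not modified.
    """
    return tokens[::-1]


src = ["I", "go", "home"]
print(reverse_source(src))  # ['home', 'go', 'I']
```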

Solution: Use the hidden states from every encoder step, not just the last one.

Seq2Seq with Attention

The example task is translating French sentences into English.

The encoder generates a hidden state at each step, as in regular Seq2Seq. The hidden state from the encoder's last step becomes the input hidden state for the first decoder step.

To determine which encoder hidden states are needed, the dot product is computed between each encoder hidden state $h_n^{(e)}$ and the decoder hidden state $h_1^{(d)}$. In the diagram above, four dot products are computed, one per encoder step. Each result can be thought of as a similarity score between two hidden states.
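The scoring step is just one dot product per encoder step. The sketch below uses toy 2-dimensional hidden states with made-up values (not taken from the diagram) to show that four encoder states yield four scores:

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Toy hidden states (made-up values for illustration).
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]  # h_1^(e) .. h_4^(e)
decoder_state = [2.0, 1.0]                                          # h_1^(d)

# One similarity score per encoder step.
scores = [dot(h_e, decoder_state) for h_e in encoder_states]
print(scores)  # [2.0, 1.0, 3.0, -1.5]
```

The third encoder state gets the highest score, so it is the one most similar to the current decoder state.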

Applying softmax to the dot products (treating them as logits) turns them into probabilities. These probabilities are used as the weights for the encoder hidden states $h_n^{(e)}$.

Attention vector: The resulting weight vector, whose entries sum to 1, is called the attention vector.

Taking the weighted average of the $h_n^{(e)}$ with these weights produces a single attention output vector. This result is also called the context vector.
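Softmax is what turns raw scores into an attention vector. A minimal sketch, reusing the toy scores from a four-step encoder (made-up values); subtracting the maximum before `math.exp` is a standard numerical-stability trick and does not change the result:

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1 (numerically stable)."""
    m = max(logits)                             # subtract the max to avoid overflow in exp
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 3.0, -1.5]                  # toy dot-product scores
attention = softmax(scores)                     # the attention vector
print([round(a, 3) for a in attention])
print(sum(attention))                           # 1.0 up to floating-point rounding
```

The highest score keeps the largest weight, but every encoder step still receives a nonzero share.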

In summary, the information the current decoder step needs is selected from the encoder hidden states and combined into a single vector.

Attention module: The part outlined in green in the diagram above is called the attention module. It takes the encoder hidden states (together with the current decoder hidden state) as input and computes a single attention output.
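Putting the three steps together gives the whole attention module for one decoder step. This is a minimal dot-product sketch in plain Python; `dot_attention` is a hypothetical name, and a real model would use a tensor library and batch the computation:

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot_attention(decoder_state, encoder_states):
    """Dot-product attention for one decoder step.

    Returns (attention_vector, context_vector).
    """
    # 1. Similarity score between the decoder state and every encoder state.
    scores = [sum(d * e for d, e in zip(decoder_state, h)) for h in encoder_states]
    # 2. Softmax turns the scores into weights that sum to 1.
    weights = softmax(scores)
    # 3. The weighted average of the encoder states is the context vector.
    dim = len(encoder_states[0])
    context = [sum(w * h[i] for w, h in zip(weights, encoder_states)) for i in range(dim)]
    return weights, context


enc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # toy encoder states
weights, context = dot_attention([2.0, 1.0], enc)
print(weights, context)
```

The context vector has the same dimension as a single encoder hidden state, however long the input sentence is.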


The decoder hidden state and the context vector (attention output) are concatenated and fed to the output layer, which predicts the next word.
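A minimal sketch of this prediction step, assuming the output layer is a single linear layer (no bias) followed by softmax over the vocabulary; some formulations insert an extra tanh layer first. The vocabulary, the weights `W_out`, and `predict_next_word` are hypothetical, untrained values chosen only for illustration:

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_next_word(decoder_state, context, W_out, vocab):
    """Concatenate [decoder_state; context], apply a linear layer, then softmax."""
    features = decoder_state + context           # list concatenation = vector concatenation
    logits = [sum(w * x for w, x in zip(row, features)) for row in W_out]
    probs = softmax(logits)
    best = max(range(len(vocab)), key=lambda i: probs[i])
    return vocab[best], probs


vocab = ["the", "house", "<EoS>"]
W_out = [                                        # one row of weights per vocabulary word
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, 1.0],
    [-1.0, -1.0, -1.0, -1.0],
]
word, probs = predict_next_word([0.9, 0.1], [0.8, 0.2], W_out, vocab)
print(word)  # the
```

Because the context vector is part of the input, the same decoder state can lead to different predictions depending on where attention points.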


The second decoder step repeats the same process. The decoder takes $h_1^{(d)}$ as its input hidden state and the word ‘the’ as its input, producing $h_2^{(d)}$.


This repeats until the decoder outputs the end token (EoS).
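The generation loop, from SoS to EoS, can be sketched independently of the network itself. Here `step(prev_token, state)` is a hypothetical callback standing in for one full decoder step (LSTM cell, attention, output layer). Picking the single best word at each step is greedy decoding; beam search is a common alternative:

```python
def greedy_decode(step, initial_state, sos="<SoS>", eos="<EoS>", max_len=20):
    """Generate tokens one at a time until EoS (or max_len tokens) is produced.

    step(prev_token, state) -> (next_token, new_state) is a hypothetical decoder step.
    """
    tokens = []
    prev, state = sos, initial_state             # the first input is always SoS
    for _ in range(max_len):
        prev, state = step(prev, state)
        if prev == eos:                          # stop as soon as EoS is generated
            break
        tokens.append(prev)
    return tokens


# Toy step that replays a fixed script; its "state" is just a position counter.
script = ["I", "go", "home", "<EoS>"]


def toy_step(prev_token, state):
    return script[state], state + 1


print(greedy_decode(toy_step, 0))  # ['I', 'go', 'home']
```

The `max_len` cap guards against a model that never emits EoS.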

Decoder’s Hidden State

The decoder’s hidden state vector must serve two roles.

  • Determine which encoder hidden states to focus on.
    • That is, it must contain the information needed to compute the attention vector.
  • Serve as input to the output layer for prediction.

The decoder is trained to perform both roles at the same time.

Therefore, backpropagation follows the purple path shown in the diagram above.

Teacher Forcing

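With teacher forcing, the decoder's input at each training step is the ground-truth previous word rather than the model's own prediction. In other words, even if the model mispredicts a word during training, the next step still receives the correct word, so one mistake does not derail the rest of the sentence.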
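Concretely, teacher forcing means the decoder inputs are the target sentence shifted right by one position, with SoS in front. A minimal sketch (`teacher_forcing_pairs` is a hypothetical helper name):

```python
def teacher_forcing_pairs(target, sos="<SoS>", eos="<EoS>"):
    """Build (decoder_input, expected_output) pairs for one training sentence.

    The input at step t is the ground-truth word from step t-1 (SoS at t=0),
    regardless of what the model actually predicted.
    """
    inputs = [sos] + target                      # shifted right by one
    outputs = target + [eos]                     # what the decoder should predict
    return list(zip(inputs, outputs))


print(teacher_forcing_pairs(["I", "go", "home"]))
# [('<SoS>', 'I'), ('I', 'go'), ('go', 'home'), ('home', '<EoS>')]
```

At inference time there is no ground truth, so the decoder's own previous output is fed back instead.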

Similarity Measurement

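Beyond the simple dot product, similarity can be computed in several ways:

$$
\mathrm{score}(h_t, \bar h_s) =
\begin{cases}
h_t^\top \bar h_s & \text{dot} \\
h_t^\top W_a \bar h_s & \text{general} \\
v_a^\top \tanh\left(W_a [h_t; \bar h_s]\right) & \text{concat}
\end{cases}
$$

  • $\mathrm{score}$: similarity function
  • $h_t$: decoder hidden state
  • $\bar h_s$: encoder hidden state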
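The three score functions can be written side by side. A minimal sketch in plain Python, assuming matrices are lists of rows; note that the `general` and `concat` variants each have their own $W_a$ with a different shape ($d \times d$ versus $k \times 2d$):

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [dot(row, x) for row in M]


def score_dot(h_t, h_s):
    # h_t^T h_s
    return dot(h_t, h_s)


def score_general(h_t, h_s, W_a):
    # h_t^T W_a h_s, with W_a a learnable d x d matrix
    return dot(h_t, matvec(W_a, h_s))


def score_concat(h_t, h_s, W_a, v_a):
    # v_a^T tanh(W_a [h_t; h_s]); here W_a is k x 2d and v_a has k entries
    hidden = [math.tanh(z) for z in matvec(W_a, h_t + h_s)]
    return dot(v_a, hidden)
```

Setting $W_a$ to the identity matrix in `score_general` recovers `score_dot`, which shows that `general` strictly extends the plain dot product.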

general: A weight matrix $W_a$ is inserted into the dot product. Think of it as giving the model the ability to weight the individual terms of a matrix multiplication.

$$
\begin{pmatrix} a & b \\ c & d \end{pmatrix}\begin{pmatrix} x & y \\ z & v \end{pmatrix}
$$

For example, the entries of the matrix product above are terms such as $ax+bz$ and $ay+bv$.

Each entry then gets its own weight, as in $w_0(ax+bz)$, $w_1(ay+bv)$, and so on, so every term of the product carries a tunable variable. In deep learning, these variables are learnable parameters.
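For the vector case actually used in attention, the effect of $W_a$ can be written out exactly. Assuming $h_t, \bar h_s \in \mathbb{R}^d$ and writing $w_{ij}$ for the entries of $W_a$, every pairwise product of an element of $h_t$ with an element of $\bar h_s$ gets its own weight:

$$
h_t^\top W_a \bar h_s = \sum_{i=1}^{d}\sum_{j=1}^{d} w_{ij}\, h_{t,i}\, \bar h_{s,j},
\qquad
h_t^\top \bar h_s = \sum_{i=1}^{d} h_{t,i}\, \bar h_{s,i}.
$$

For $d = 2$:

$$
h_t^\top W_a \bar h_s = w_{11} h_{t,1} \bar h_{s,1} + w_{12} h_{t,1} \bar h_{s,2} + w_{21} h_{t,2} \bar h_{s,1} + w_{22} h_{t,2} \bar h_{s,2}.
$$

Setting $W_a = I$ keeps only the diagonal terms ($i = j$, each with weight 1) and recovers the plain dot product $h_t^\top \bar h_s$.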

concat: In $[h_t; \bar h_s]$, the semicolon denotes concatenation of the two vectors. Looking at the formula, the term inside tanh resembles a neural network, and in fact it is one.

If $h_t = [1, 3]$ and $\bar h_s = [2, -5]$, the network is built as shown in the diagram above, with W1 and W2 as the weights of its fully connected layers.

In the formula, W1 corresponds to $W_a$ and W2 is written as $v_a$. W2 is a vector because the final layer must output a scalar score: in the diagram, the hidden layer is 3-dimensional, so W2 must also be a 3-dimensional vector to reduce it to a scalar.
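The same walk-through with the numbers from the text: $[h_t; \bar h_s]$ is 4-dimensional, W1 maps it to a 3-dimensional hidden layer, and $v_a$ (W2) reduces that to one scalar. The weight values below are hypothetical, chosen only so the arithmetic is easy to follow:

```python
import math

h_t = [1.0, 3.0]       # decoder hidden state (from the text)
h_s = [2.0, -5.0]      # encoder hidden state (from the text)

# Hypothetical weights: W1 is 3 x 4 (concatenation -> 3-dim hidden layer),
# v_a (W2) has 3 entries (3-dim hidden layer -> scalar score).
W1 = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]
v_a = [1.0, 1.0, 1.0]

x = h_t + h_s                                            # [h_t; h_s] = [1, 3, 2, -5]
pre = [sum(w * xi for w, xi in zip(row, x)) for row in W1]  # [1, 5, -5]
hidden = [math.tanh(z) for z in pre]                     # 3-dim hidden layer
score = sum(v * h for v, h in zip(v_a, hidden))          # single scalar

print(x)            # [1.0, 3.0, 2.0, -5.0]
print(len(hidden))  # 3
print(score)
```

With these weights, $\tanh(5)$ and $\tanh(-5)$ cancel, so the score is approximately $\tanh(1)$; trained weights would of course give a different value.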


Why use other similarity functions? Compared to the plain dot product, they introduce learnable parameters that are trained along with the rest of the model, and these parameters directly shape the attention vector.

In other words, by adding learnable parameters to the similarity function, the model also learns how to compute the attention vector.

Advantages of Attention

  • Dramatically improved machine translation performance.
    • Unlike plain Seq2Seq, the decoder can focus on specific parts of the source sentence.
    • It addressed the poor translation quality on long sentences.
  • Attention solves the bottleneck problem.
    • It removes the need to cram all source information into a single hidden state.
    • The decoder can directly access source information.
  • Attention helps with the vanishing gradient problem.
    • Previously, the loss was backpropagated sequentially through the decoder and then the encoder (red path in the diagram above), so the bottleneck appears here too. In particular, gradients that update the hidden states of early encoder steps must pass through a very long chain of steps.
    • With attention, this path is much shorter (blue path in the diagram): the attention output gives the gradient a shortcut directly to each encoder hidden state.
  • Attention provides some interpretability.
    • Inspecting the attention vector for a given input shows which source information the decoder is focusing on.
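Reading off what the decoder focuses on amounts to taking, for each output word, the source word with the largest attention weight. The attention matrix below is made up for illustration, not the output of a trained model, and `most_attended` is a hypothetical helper:

```python
# Toy attention matrix: row t is the attention vector at decoder step t
# (made-up numbers for illustration, not results from a trained model).
source = ["je", "suis", "étudiant"]
output = ["I", "am", "a", "student"]
attention = [
    [0.90, 0.05, 0.05],
    [0.10, 0.80, 0.10],
    [0.05, 0.35, 0.60],
    [0.02, 0.08, 0.90],
]


def most_attended(source, attention):
    """For each decoder step, return the source token with the largest attention weight."""
    return [source[max(range(len(source)), key=lambda j: row[j])] for row in attention]


for word, src_word in zip(output, most_attended(source, attention)):
    print(f"{word!r} attends most to {src_word!r}")
```

Plotting the full matrix as a heatmap shows the same information at a glance, including reorderings.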

Attention Examples

The example above translates French to English using attention. Words that keep the same order are translated in order, and for phrases whose word order differs between the two languages, the attention mechanism picks up the reordering on its own. No hand-crafted alignment is needed; the whole translation is learned end to end.
