
RNNs

September 7, 2021

RNN

Given sequence data as input and/or output, an RNN is a network that takes the input x_t at time step t together with the previous hidden state h_{t-1} and produces a new hidden state h_t.

The important point is that a new model does not appear at every time step: a single parameter set is shared and reused across all time steps.

The compressed representation on the left is called the rolled diagram; the version that lays out the individual time steps is the unrolled diagram.

The symbols in the diagram are:

  • h_t: new hidden state vector.
  • f_W: the RNN function, parameterized by W.
    • W: linear transformation matrix.
  • y_t: output vector at time step t.
    • Computed from h_t.
    • Can be computed at every step, or only at the end; it's flexible.
    • e.g., POS tagging requires an output at every step, while sentiment analysis only needs one at the final step.

f_W is defined as a non-linear function:

h_t = f_W(h_{t-1}, x_t) = tanh(W_{hh} h_{t-1} + W_{xh} x_t)

W_{hh} and W_{xh} are written as separate matrices in the formula, but they can be understood as slices of a single matrix W, as in the figure below.

Since h_t's dimension is a hyperparameter, let's set it to 2.

Suppose x_t is 3-dimensional. To take x_t and h_{t-1} as input and produce the 2-dimensional h_t as output, W must have shape (2, 5): multiplying the (5, 1) concatenation of x_t and h_{t-1} by W yields a (2, 1) result. Instead of keeping W as a single (2, 5) matrix, we can split it at the boundary between the red and green circles in the figure. That is, x_t and h_{t-1} each get their own weight matrix, and adding the two matrix-vector products yields h_t.

So W_{hh} transforms h_{t-1} into h_t, and W_{xh} transforms x_t into h_t.
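The split described above can be checked numerically. This is a minimal sketch, assuming the dimensions from the example (x_t is 3-dim, h_{t-1} is 2-dim, so the combined W is (2, 5)):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((2, 5))   # single combined weight matrix
x_t = rng.standard_normal(3)      # input at time t
h_prev = rng.standard_normal(2)   # previous hidden state h_{t-1}

# One matrix applied to the concatenation [x_t; h_{t-1}] ...
h_combined = np.tanh(W @ np.concatenate([x_t, h_prev]))

# ... is identical to splitting W into W_xh (2, 3) and W_hh (2, 2)
W_xh, W_hh = W[:, :3], W[:, 3:]
h_split = np.tanh(W_xh @ x_t + W_hh @ h_prev)

print(np.allclose(h_combined, h_split))  # True
```

Splitting W this way is purely a notational convenience; the computation is unchanged.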

By the same logic, W_{hy} transforms h_t into y_t.

For binary classification, y_t is a 1-dimensional vector (a scalar), and a sigmoid turns it into a predicted probability. For multi-class classification, y_t's dimension equals the number of classes, and a softmax turns it into a probability distribution.
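The multi-class output step can be sketched as follows (the 4-class size is an illustrative assumption, with the 2-dim hidden state from the earlier example):

```python
import numpy as np

rng = np.random.default_rng(0)
h_t = rng.standard_normal(2)        # hidden state (2-dim, as above)
W_hy = rng.standard_normal((4, 2))  # maps hidden state to 4 class scores

logits = W_hy @ h_t                 # raw class scores
probs = np.exp(logits - logits.max())
probs /= probs.sum()                # numerically stable softmax

print(probs.shape)  # (4,)
```

Subtracting the max before exponentiating is a standard trick to avoid overflow; it does not change the resulting distribution.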

Types of RNN

RNNs can handle cases where the input, the output, or both are sequence data. ref: http://karpathy.github.io/2015/05/21/rnn-effectiveness/

  • one to one (standard neural network)
    • Neither input nor output is sequence data, with a single time step.
    • Same structure as a standard DNN.
  • one to many
    • Input is not sequence data, but sequence data is output across multiple time steps.
    • Only the first step has real input; remaining steps receive all-zero tensors.
    • e.g., Image captioning
  • many to one
    • Input occurs at each time step, with a single output at the end.
    • e.g., Sentiment classification
  • many to many
    • Sequence data is input per time step, then output per time step.
      • e.g., Machine translation
    • Input and output at every time step.
      • e.g., Video classification on frame level
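The many-to-one case above can be sketched as a single loop that consumes the whole sequence and emits one output from the final hidden state (all dimensions here are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
W_xh = rng.standard_normal((2, 3)) * 0.1  # input -> hidden
W_hh = rng.standard_normal((2, 2)) * 0.1  # hidden -> hidden
W_hy = rng.standard_normal((1, 2)) * 0.1  # hidden -> output

xs = rng.standard_normal((5, 3))  # sequence of 5 input vectors
h = np.zeros(2)                   # initial hidden state

for x_t in xs:                    # same weights reused at every step
    h = np.tanh(W_xh @ x_t + W_hh @ h)

y = 1 / (1 + np.exp(-(W_hy @ h)))  # single sigmoid output at the end
print(y.shape)  # (1,)
```

The other variants differ only in where inputs enter and where outputs are read off; the recurrence itself is identical.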

Character-level language model

A language model predicts the next word or character based on the sequence so far. It can be built at either the word level or the character level.

Building a character-level language model proceeds as follows:

Example of training sequence: “hello”

  1. Build a unique vocabulary at the character level. [h, e, l, o]
  2. Characters in the vocabulary are represented as one-hot vectors, as in word embedding. h = [1,0,0,0]
  3. Feed the characters of “hell” into the RNN one step at a time.

The key point is that the next character must be predicted at every time step, so the RNN is set up as many-to-many.

At each step the output is computed as y_t = W_{hy} h_t. These raw scores are called logits because a softmax is applied to them for multi-class classification over the vocabulary.
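The steps above can be sketched end to end for the “hello” example: build the vocabulary, one-hot encode, and produce one logit vector per step (the hidden size of 3 and untrained random weights are illustrative assumptions):

```python
import numpy as np

vocab = ['h', 'e', 'l', 'o']
char_to_idx = {c: i for i, c in enumerate(vocab)}

def one_hot(c):
    v = np.zeros(len(vocab))
    v[char_to_idx[c]] = 1.0
    return v

rng = np.random.default_rng(0)
hidden = 3
W_xh = rng.standard_normal((hidden, len(vocab))) * 0.1
W_hh = rng.standard_normal((hidden, hidden)) * 0.1
W_hy = rng.standard_normal((len(vocab), hidden)) * 0.1

h = np.zeros(hidden)
logits_per_step = []
for c in "hell":                      # inputs; the targets would be "ello"
    h = np.tanh(W_xh @ one_hot(c) + W_hh @ h)
    logits_per_step.append(W_hy @ h)  # one logit vector per step

print(len(logits_per_step), logits_per_step[0].shape)  # 4 (4,)
```

Training would apply a softmax cross-entropy loss between each logit vector and the next character in the sequence.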

Inference

Since the model is recurrent, each time step's output can be fed back in as the next time step's input. So you give only ‘h’ as the first input and let the model generate the rest automatically.
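The sampling loop can be sketched as follows; since the weights here are random and untrained, the generated string is gibberish, but the feedback mechanism is the same:

```python
import numpy as np

vocab = ['h', 'e', 'l', 'o']
rng = np.random.default_rng(0)
hidden = 3
W_xh = rng.standard_normal((hidden, len(vocab))) * 0.1
W_hh = rng.standard_normal((hidden, hidden)) * 0.1
W_hy = rng.standard_normal((len(vocab), hidden)) * 0.1

h = np.zeros(hidden)
idx = vocab.index('h')                  # seed character
generated = ['h']
for _ in range(4):
    x = np.eye(len(vocab))[idx]         # one-hot of the previous output
    h = np.tanh(W_xh @ x + W_hh @ h)
    logits = W_hy @ h
    p = np.exp(logits - logits.max())
    p /= p.sum()                        # softmax over the vocabulary
    idx = int(rng.choice(len(vocab), p=p))  # sample the next character
    generated.append(vocab[idx])

print(''.join(generated))  # 5 characters starting with 'h'
```

Sampling from the softmax (rather than always taking the argmax) is what lets a trained model produce varied text.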

Training Shakespeare’s plays

The method used on “hello” can also be applied to a large text corpus. Build the vocabulary from every character, including all punctuation: commas, ‘\n’, spaces, everything. This way you can build a simple language model with an RNN.

As training progresses, the sentences generated from a given first character become more natural.

Other examples

  • Learning plays to distinguish characters and dialogue.
  • Training on LaTeX papers to generate new papers during inference.
  • Training on C code to generate code.

BPTT (Backpropagation through time)

Ideally the loss from the entire sequence would be backpropagated at once, but sequences are usually too long for that. So the whole sequence is still processed, but it is cut into chunks and the loss is backpropagated only within each chunk. This is called truncated BPTT.
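The chunking can be sketched as below (forward pass only; the chunk size of 25 is an illustrative assumption). The hidden state is carried across chunk boundaries, but gradients would only flow within each chunk:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, chunk = 2, 25
W_xh = rng.standard_normal((hidden, 3)) * 0.1
W_hh = rng.standard_normal((hidden, hidden)) * 0.1

xs = rng.standard_normal((100, 3))  # long sequence: 100 steps
h = np.zeros(hidden)
n_chunks = 0
for start in range(0, len(xs), chunk):
    # h enters each chunk as a constant (gradient "detached" at the boundary)
    for x_t in xs[start:start + chunk]:
        h = np.tanh(W_xh @ x_t + W_hh @ h)
    # here: compute the loss on this chunk and backpropagate within it only
    n_chunks += 1

print(n_chunks)  # 4
```

In a framework like PyTorch this corresponds to detaching the hidden state from the graph at each chunk boundary before continuing.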

How RNN works

We can trace what an RNN learns. The hidden state summarizes all information from before time t, so tracking how individual cells of the hidden state change over a sequence reveals what the network has picked up.

The results below are from an LSTM and a GRU (not a vanilla RNN), visualizing how the hidden state changes.

Red means a specific cell in the hidden state is becoming more negative; blue means more positive. One visualization tracks a cell responsible for quote detection, which switches state inside quoted text; another shows a cell that tracks whether the code is inside an if statement.

Vanishing/Exploding gradient in RNN

The RNN itself is sound, but problems arise during backpropagation. The recurrence repeatedly multiplies by W_{hh} and passes through activation functions, so the gradient gets multiplied by the same factor at every step: it grows without bound if that factor is greater than 1, and shrinks toward zero if it is less than 1.

For a simple example, treat the weight as a scalar, say w_{hh} = 3, and ignore the activation. Backpropagating from h_3 to h_1 applies the chain rule step by step, and w_{hh} is multiplied in at every step, so the gradient is proportional to a power of 3. For a longer sequence the power grows even larger; if w_{hh} were instead less than 1, the gradient would shrink dramatically.

The result is that the gradient signal from h_3 should propagate back to h_1 intact, but instead it explodes toward infinity or vanishes toward zero.
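The scalar example can be checked numerically. With h_t = w * h_{t-1} and no activation, the gradient dh_T/dh_1 is w ** (T - 1):

```python
w_explode, w_vanish = 3.0, 0.5  # one weight above 1, one below

for T in (4, 11, 51):
    steps = T - 1
    # gradient of h_T with respect to h_1 is w ** steps
    print(steps, w_explode ** steps, w_vanish ** steps)
```

After 50 steps the first factor is astronomically large while the second is around 1e-15, which is exactly the exploding/vanishing behavior described above.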
