LSTM and GRU -- Formula Summary

Introduction

The Long Short-Term Memory (LSTM) unit and the Gated Recurrent Unit (GRU) are among the most widely used RNN models in deep learning for NLP today. Both LSTM (1997) and GRU (2014) are designed to combat the vanishing gradient problem, which prevents standard RNNs from learning long-term dependencies, through a gating mechanism.

Note that this article relies heavily on the following two articles, Understanding LSTM Networks and Recurrent Neural Network Tutorial. I summarize the formula definitions and explanations from them to enhance my understanding of LSTM and GRU, as well as their similarities and differences.

GRU is a simpler variant of the LSTM that shares many of the same properties. It combines the forget and input gates into a single "update gate", merges the cell state and hidden state, and makes some other changes. The resulting model is simpler than the standard LSTM, with comparable performance on sequence modeling but fewer parameters, which makes it easier to train.

[Figure: LSTM and GRU]

LSTM

Denoting $\ast$ as elementwise multiplication and ignoring bias terms, an LSTM computes its hidden state $h_{t}$ as

$$
\begin{aligned}
i_{t} & =\sigma\big(x_{t}U^{i}+h_{t-1}W^{i}\big)\newline
f_{t} & =\sigma\big(x_{t}U^{f}+h_{t-1}W^{f}\big)\newline
o_{t} & =\sigma\big(x_{t}U^{o}+h_{t-1}W^{o}\big)\newline
\tilde{C}_{t} & =\tanh\big(x_{t}U^{g}+h_{t-1}W^{g}\big)\newline
C_{t} & =f_{t}\ast C_{t-1}+i_{t}\ast\tilde{C}_{t}\newline
h_{t} & =\tanh(C_{t})\ast o_{t}
\end{aligned}
$$

Here, $i$, $f$, $o$ are called the input, forget and output gates, respectively. Note that they have exactly the same form, just with different parameter matrices ($W$ is the recurrent weight matrix connecting the previous hidden state to the current hidden layer, and $U$ is the weight matrix connecting the inputs to the current hidden layer). They are called gates because the sigmoid function squashes the values of these vectors between 0 and 1, and by multiplying them elementwise with another vector you define how much of that other vector you want to "let through". The input gate defines how much of the newly computed state for the current input you want to let through. The forget gate defines how much of the previous state you want to let through. Finally, the output gate defines how much of the internal state you want to expose to the external network (higher layers and the next time step). All the gates have the same dimension $d_h$, the size of your hidden state.
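To make the equations concrete, here is a minimal NumPy sketch of a single LSTM step following the formulas above, using the same row-vector convention $x_{t}U + h_{t-1}W$ and still ignoring biases. The function name, the dict-of-matrices layout and the toy dimensions are my own choices for illustration, not anything from the referenced articles.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, C_prev, U, W):
    """One LSTM step. U[k] has shape (input_dim, hidden_dim),
    W[k] has shape (hidden_dim, hidden_dim), keyed by gate name."""
    i_t = sigmoid(x_t @ U["i"] + h_prev @ W["i"])      # input gate
    f_t = sigmoid(x_t @ U["f"] + h_prev @ W["f"])      # forget gate
    o_t = sigmoid(x_t @ U["o"] + h_prev @ W["o"])      # output gate
    C_tilde = np.tanh(x_t @ U["g"] + h_prev @ W["g"])  # candidate state
    C_t = f_t * C_prev + i_t * C_tilde                 # new internal memory
    h_t = np.tanh(C_t) * o_t                           # exposed hidden state
    return h_t, C_t

# Toy usage: input_dim = 3, hidden_dim = 4 (d_h), one example as a row vector.
rng = np.random.default_rng(0)
U = {k: rng.standard_normal((3, 4)) * 0.1 for k in "ifog"}
W = {k: rng.standard_normal((4, 4)) * 0.1 for k in "ifog"}
h, C = np.zeros(4), np.zeros(4)
for x in rng.standard_normal((5, 3)):   # a length-5 input sequence
    h, C = lstm_step(x, h, C, U, W)
print(h.shape)  # (4,)
```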

$\tilde{C}_{t}$ is a "candidate" hidden state that is computed based on the current input and the previous hidden state. $C_{t}$ is the internal memory of the unit. It is a combination of the previous memory, multiplied by the forget gate, and the newly computed candidate state, multiplied by the input gate. Intuitively, it expresses how we want to combine the previous memory with the new input. We could choose to ignore the old memory completely (forget gate all 0's) or ignore the newly computed state completely (input gate all 0's), but most likely we want something in between these two extremes. $h_{t}$ is the output hidden state, computed by multiplying the memory with the output gate; not all of the internal memory may be relevant to the hidden state used by other units in the network.

Intuitively, plain RNNs can be considered a special case of LSTMs. If we fix the input gate to all 1's, the forget gate to all 0's (say, always forget the previous memory) and the output gate to all 1's (say, expose the whole memory), we almost get a standard RNN, except for an extra $\tanh$ that squashes the output.
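Spelling this out with the equations above, substituting $i_{t}=1$, $f_{t}=0$ and $o_{t}=1$ gives

$$
\begin{aligned}
C_{t} & =0\ast C_{t-1}+1\ast\tilde{C}_{t}=\tanh\big(x_{t}U^{g}+h_{t-1}W^{g}\big)\newline
h_{t} & =\tanh(C_{t})\ast 1=\tanh\Big(\tanh\big(x_{t}U^{g}+h_{t-1}W^{g}\big)\Big)
\end{aligned}
$$

which matches the standard RNN update $h_{t}=\tanh\big(x_{t}U+h_{t-1}W\big)$ up to the extra outer $\tanh$.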

GRU

For a GRU, the hidden state $h_{t}$ is computed as

$$
\begin{aligned}
z_{t} & =\sigma\big(x_{t}U^{z}+h_{t-1}W^{z}\big)\newline
r_{t} & =\sigma\big(x_{t}U^{r}+h_{t-1}W^{r}\big)\newline
\tilde{h}_{t} & =\tanh\big(x_{t}U^{h}+(r_{t}\ast h_{t-1})W^{h}\big)\newline
h_{t} & =(1-z_{t})\ast h_{t-1}+z_{t}\ast\tilde{h}_{t}
\end{aligned}
$$

Here $r$ is a reset gate and $z$ is an update gate. Intuitively, the reset gate determines how to combine the new input with the previous memory, and the update gate defines how much of the previous memory to keep around. If we set both the reset gate and the update gate to all 1's, we arrive at the vanilla RNN model.
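As a rough sketch under the same conventions as the LSTM example above (names and dimensions are again my own, biases omitted), a single GRU step in NumPy looks like this:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, U, W):
    """One GRU step. U[k]: (input_dim, hidden_dim), W[k]: (hidden_dim, hidden_dim)."""
    z_t = sigmoid(x_t @ U["z"] + h_prev @ W["z"])              # update gate
    r_t = sigmoid(x_t @ U["r"] + h_prev @ W["r"])              # reset gate
    h_tilde = np.tanh(x_t @ U["h"] + (r_t * h_prev) @ W["h"])  # candidate state
    h_t = (1.0 - z_t) * h_prev + z_t * h_tilde                 # interpolate old and new
    return h_t

# Toy usage mirroring the LSTM example: input_dim = 3, hidden_dim = 4.
rng = np.random.default_rng(0)
U = {k: rng.standard_normal((3, 4)) * 0.1 for k in "zrh"}
W = {k: rng.standard_normal((4, 4)) * 0.1 for k in "zrh"}
h = np.zeros(4)
for x in rng.standard_normal((5, 3)):
    h = gru_step(x, h, U, W)
print(h.shape)  # (4,)
```

Note that the GRU step carries a single state vector and three pairs of weight matrices, versus the LSTM's two state vectors and four pairs, which is where the parameter saving mentioned in the introduction comes from.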

Reference

- Understanding LSTM Networks, Christopher Olah
- Recurrent Neural Network Tutorial, Denny Britz (WildML)
