Sequence to Sequence Model


This post introduces the Sequence to Sequence model and its main building blocks.

As the name suggests, a Sequence to Sequence model takes in a sequence and produces another sequence. A sequence can be anything: text, images, audio, video, etc.

Generally, every sequence-to-sequence model has three main components: an encoder, a decoder, and an intermediate state. The encoder and decoder can be any type of network: an RNN, a CNN, a Transformer, etc.

[Figure: overview of a sequence-to-sequence model (encoder, intermediate state, decoder)]

There are many applications that use the sequence-to-sequence architecture. A few of them are:

  • Language Translation
  • Text Summarization
  • Chatbots
  • Parts of Speech Tagging
  • Image Captioning
  • Image Modification
  • etc.

To better understand sequence-to-sequence models, let's take language translation (English -> German) as an example and look at the components.

[Figure: encoder-decoder architecture for English to German translation]

Usually, we use an RNN to encode the source (input) sentence into a context vector. This context vector can be thought of as an abstract representation of the entire input sentence. The vector is then decoded by the RNN decoder, which learns to output the target (output) sentence by generating one word at a time.

Encoder

At each time-step, the input to the encoder RNN is both the current word, $x_t$, as well as the hidden state from the previous time-step, $h_{t-1}$, and the encoder RNN outputs a new hidden state, $h_t$. You can think of the hidden state as a vector representation of the sentence so far. The RNN can be represented as a function of both $x_t$ and $h_{t-1}$:

$$h_t = \text{EncoderRNN}(x_t, h_{t-1})$$

We're using the term RNN generally here; it could be any recurrent architecture, such as an LSTM (Long Short-Term Memory) or a GRU (Gated Recurrent Unit).

Here, we have $X = \{x_1, x_2, \dots, x_T\}$, where $x_1 = \text{<sos>}$, $x_2$ is the first word of the source sentence, etc. The initial hidden state, $h_0$, is usually either initialized to zeros or a learned parameter.

Once the final word, $x_T$, has been passed into the RNN, we use the final hidden state, $h_T$, as the context vector, i.e. $z = h_T$. This is a vector representation of the entire source sentence.
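
To make this concrete, here is a minimal PyTorch sketch of such an encoder, assuming a GRU as the recurrent unit and a word-level embedding layer. The class and parameter names are illustrative, not the exact implementation we will build in the next post.

```python
import torch.nn as nn

class Encoder(nn.Module):
    """Encodes a source sentence into a single context vector z = h_T."""

    def __init__(self, input_vocab_size, emb_dim, hid_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim)  # could equally be nn.LSTM or plain nn.RNN

    def forward(self, src):
        # src: [src_len, batch_size] tensor of token indices (with <sos>/<eos> appended)
        embedded = self.embedding(src)        # [src_len, batch_size, emb_dim]
        outputs, hidden = self.rnn(embedded)  # hidden: [1, batch_size, hid_dim]
        return hidden                         # final hidden state h_T, i.e. the context vector z
```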

Decoder

Now that we have our context vector, $z$, we can start decoding it to get the target sentence, "Hallo Guten Morgen". Again, we append start and end of sequence tokens to the target sentence. At each time-step, the input to the decoder RNN is the current word, $y_t$, as well as the hidden state from the previous time-step, $s_{t-1}$, where the initial decoder hidden state, $s_0$, is the context vector, $z$, i.e. $s_0 = z = h_T$: the initial decoder hidden state is the final encoder hidden state. Thus, similar to the encoder, we can represent the decoder as:

$$s_t = \text{DecoderRNN}(y_t, s_{t-1})$$

In the decoder, we need to go from the hidden state to an actual word; therefore, at each time-step we use $s_t$ to predict (by passing it through a Linear layer) what we think is the next word in the sequence, $\hat{y}_{t+1}$.
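
A matching sketch of the decoder, again with a GRU standing in for the generic RNN: it handles a single time-step, taking the previous word and hidden state and returning the new hidden state $s_t$ plus a score for every word in the target vocabulary (the Linear layer mentioned above). Names and dimensions are illustrative.

```python
import torch.nn as nn

class Decoder(nn.Module):
    """One decoding step: previous word + hidden state -> next-word scores + new hidden state."""

    def __init__(self, output_vocab_size, emb_dim, hid_dim):
        super().__init__()
        self.embedding = nn.Embedding(output_vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim)
        self.fc_out = nn.Linear(hid_dim, output_vocab_size)  # hidden state -> vocabulary scores

    def forward(self, input_token, hidden):
        # input_token: [batch_size], hidden: [1, batch_size, hid_dim]
        embedded = self.embedding(input_token.unsqueeze(0))  # [1, batch_size, emb_dim]
        output, hidden = self.rnn(embedded, hidden)          # new hidden state s_t
        prediction = self.fc_out(output.squeeze(0))          # [batch_size, output_vocab_size]
        return prediction, hidden
```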

We always use <sos> for the first input to the decoder, $y_1$, but for subsequent inputs, $y_{t>1}$, we will sometimes use the actual, ground-truth next word in the sequence, $y_t$, and sometimes use the word predicted by our decoder, $\hat{y}_{t-1}$. This is called teacher forcing, and you can read more about it here.

When training/testing our model, we always know how many words are in our target sentence, so we stop generating words once we hit that many. During inference (i.e. real-world usage) it is common to keep generating words until the model outputs an <eos> token or a certain number of words have been generated.
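
A rough sketch of how the pieces fit together during training, reusing the Encoder and Decoder sketched above: the decoding loop runs for exactly the known target length, and a teacher_forcing_ratio (0.5 here is just an example value) decides at each step whether to feed the ground-truth word or the decoder's own prediction.

```python
import random
import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        # src: [src_len, batch_size], trg: [trg_len, batch_size]
        trg_len, batch_size = trg.shape
        vocab_size = self.decoder.fc_out.out_features
        outputs = torch.zeros(trg_len, batch_size, vocab_size, device=src.device)

        hidden = self.encoder(src)   # context vector z = h_T
        input_token = trg[0]         # first decoder input is always <sos>

        for t in range(1, trg_len):  # we know the target length during training
            prediction, hidden = self.decoder(input_token, hidden)
            outputs[t] = prediction
            # teacher forcing: sometimes feed the ground-truth next word,
            # sometimes feed the word the decoder just predicted
            use_teacher_forcing = random.random() < teacher_forcing_ratio
            input_token = trg[t] if use_teacher_forcing else prediction.argmax(1)

        return outputs
```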

Once we have our predicted target sentence, $\hat{Y} = \{\hat{y}_1, \hat{y}_2, \dots, \hat{y}_T\}$, we compare it against our actual target sentence, $Y = \{y_1, y_2, \dots, y_T\}$, to calculate our loss. We then use this loss to update all of the parameters in our model.
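
In code, this comparison is typically a cross-entropy loss over the predicted word scores, skipping the first (<sos>) position. In the sketch below, model is the Seq2Seq sketched above, while optimizer, pad_idx, src, and trg are assumed to come from the surrounding training script.

```python
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)  # ignore padding tokens in the loss

outputs = model(src, trg)  # [trg_len, batch_size, vocab_size]
# drop the first position (it corresponds to <sos>) and flatten the time/batch dimensions
loss = criterion(outputs[1:].reshape(-1, outputs.shape[-1]),
                 trg[1:].reshape(-1))

optimizer.zero_grad()
loss.backward()
optimizer.step()
```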

References

Understanding LSTM

Sequence to Sequence Learning Paper

In the next post, we will see how to implement language translation in PyTorch.

THANK YOU !!!