Recurrent Neural Networks

Last time, before the midterm, we discussed using recurrent neural networks to make predictions about sequences. In particular, we treated tweets as a sequence of words. Since tweets can have a variable number of words, we needed an architecture that can take variable-sized sequences as input.

This time, we will use recurrent neural networks to generate sequences. Generating sequences is more involved comparing to making predictions about sequences. However, it is a very interesting task, and many students chose sequence-generation tasks for their projects.

Much of today's content is an adaptation of the "Practical PyTorch" github repository [1].

[1] https://github.com/spro/practical-pytorch/blob/master/char-rnn-generation/char-rnn-generation.ipynb

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Review

The recurrent neural network architecture from last time looked something like this:

The input sequence is broken down into tokens. We could choose whether to tokenize based on words, or based on characters. The representation of each token (GloVe or one-hot) is processed by the RNN one step at a time to update the hidden (or context) state.

In a predictive RNN, the value of the hidden states is a representation of all the text that was processed thus far. Similarly, in a generative RNN, The value of the hidden state will be a representation of all the text that still needs to be generated. We will use this hidden state to produce the sequence, one token at a time.

Similar to last class we will break up the problem of generating text to generating one token at a time.

We will do so with the help of two functions:

  1. We need to be able to generate the next token, given the current hidden state. In practice, we get a probability distribution over the next token, and sample from that probability distribution.
  2. We need to be able to update the hidden state somehow. To do so, we need two piece of information: the old hidden state, and the actual token that was generated in the previous step. The actual token generated will inform the subsequent tokens.

We will repeat both functions until a special "END OF SEQUENCE" token is generated. Here is a pictorial representation of what we will do:

Note that there are several tricky things that we will have to figure out. For example, how do we actually sample the actual token from the probability distribution over tokens? What would we do during training, and how might that be different from during testing/evaluation? We will answer those questions as we implement the RNN.

For now, let's start with our training data.

Data: Donald Trump's Tweets from 2018

The training set we use is a collection of Donald Trump's tweets from 2018. We will only use tweets that are 140 characters or shorter, and tweets that contains more than just a URL. Since tweets often contain creative spelling and numbers, and upper vs lower case characters are read very differently, we will use a character-level RNN.

In [2]:
import csv
tweets = list(line[0] for line in csv.reader(open('trump.csv')))
len(tweets)
Out[2]:
22402

There are over 20000 tweets in this collection. Let's look at a few of them, just to get a sense of the kind of text we're dealing with:

In [3]:
print(tweets[100])
print(tweets[1000])
print(tweets[10000])
God Bless the people of Venezuela!
It was my honor. THANK YOU! https://t.co/1LvqbRQ1bi
Nobody but Donald Trump will save Israel. You are wasting your time with these politicians and political clowns. Best! #SheldonAdelson

Generating One Tweet

Normally, when we build a new machine learn model, we want to make sure that our model can overfit. To that end, we will first build a neural network that can generate one tweet really well. We can choose any tweet (or any other text) we want. Let's choose to build an RNN that generates tweet[100].

In [4]:
tweet = tweets[100]
tweet
Out[4]:
'God Bless the people of Venezuela!'

First, we will need to encode this tweet using a one-hot encoding. We'll build dictionary mappings from the character to the index of that character (a unique integer identifier), and from the index to the character. We'll use the same naming scheme that torchtext uses (stoi and itos).

For simplicity, we'll work with a limited vocabulary containing just the characters in tweet[100], plus two special tokens:

  • <EOS> represents "End of String", which we'll append to the end of our tweet. Since tweets are variable-length, this is a way for the RNN to signal that the entire sequence has been generated.
  • <BOS> represents "Beginning of String", which we'll prepend to the beginning of our tweet. This is the first token that we will feed into the RNN.

The way we use these special tokens will become more clear as we build the model.

In [5]:
vocab = list(set(tweet)) + ["<BOS>", "<EOS>"]
vocab_stoi = {s: i for i, s in enumerate(vocab)}
vocab_itos = {i: s for i, s in enumerate(vocab)}
vocab_size = len(vocab)

Now that we have our vocabulary, we can build the PyTorch model for this problem. The actual model is not as complex as you might think. We actually already learned about all the components that we need. (Using and training the model is the hard part)

In [6]:
class TextGenerator(nn.Module):
    def __init__(self, vocab_size, hidden_size, n_layers=1):
        super(TextGenerator, self).__init__()

        # identiy matrix for generating one-hot vectors
        self.ident = torch.eye(vocab_size)

        # recurrent neural network
        self.rnn = nn.GRU(vocab_size, hidden_size, n_layers, batch_first=True)

        # a fully-connect layer that outputs a distribution over
        # the next token, given the RNN output
        self.decoder = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, inp, hidden=None):
        inp = self.ident[inp]                  # generate one-hot vectors of input
        output, hidden = self.rnn(inp, hidden) # get the next output and hidden state
        output = self.decoder(output)          # predict distribution over next tokens
        return output, hidden

model = TextGenerator(vocab_size, 64)

Training with Teacher Forcing

At a very high level, we want our RNN model to have a high probability of generating the tweet. An RNN model generates text one character at a time based on the hidden state value. At each time step, we will check whether the mdoel generated the correct character. That is, at each time step, we are trying to select the correct next character out of all the characters in our vocabulary. Recall that this problem is a multi-class classification problem, and we can use Cross-Entropy loss to train our network to become better at this type of problem.

In [7]:
criterion = nn.CrossEntropyLoss()

However, we don't just have a single multi-class classification problem. Instead, we have one classification problem per time-step (per token)! So, how do we predict the first token in the sequence? How do we predict the second token in the sequence?

To help you understand what happens durign RNN training, we'll start with a inefficient training code that shows you what happens step-by-step. We'll start with computing the loss for the first token generated, then the second token, and so on. Later on, we'll switch to a simpler and more performant version of the code.

So, let's start with the first classification problem: the problem of generating the first token (tweet[0]).

To generate the first token, we'll feed the RNN network (with an initial, empty hidden state) the "" token. Then, the output

In [8]:
bos_input = torch.Tensor([vocab_stoi["<BOS>"]]).long().unsqueeze(0)
output, hidden = model(bos_input, hidden=None)
output # distribution over the first token
Out[8]:
tensor([[[-0.0918, -0.0842,  0.0315,  0.0378, -0.0120, -0.1289, -0.0064,
          -0.1029,  0.0220,  0.0653,  0.0038,  0.0665,  0.0622, -0.0944,
           0.0047,  0.0386,  0.0756, -0.1149,  0.0396,  0.0007]]],
       grad_fn=<AddBackward0>)

We can compute the loss using criterion. Since the model is untrained, the loss is expected to be high. (For now, we won't do anything with this loss, and omit the backward pass.)

In [9]:
target = torch.Tensor([vocab_stoi[tweet[0]]]).long().unsqueeze(0)
criterion(output.reshape(-1, vocab_size), # reshape to 2D tensor
          target.reshape(-1))             # reshape to 1D tensor
Out[9]:
tensor(3.0005, grad_fn=<NllLossBackward>)

Now, we need to update the hidden state and generate a prediction for the next token. To do so, we need to provide the current token to the RNN. We already said that during test time, we'll need to sample from the predicted probabilty over tokens that the neural network just generated.

Right now, we can do something better: we can use the ground-truth, actual target token. This technique is called teacher-forcing, and generally speeds up training. The reason is that right now, since our model does not perform well, the predicted probability distribution is pretty far from the ground truth. So, it is very, very difficult for the neural network to get back on track given bad input data.

In [10]:
# Use teacher-forcing: we pass in the ground truth `target`,
# rather than using the NN predicted distribution
output, hidden = model(target, hidden)
output # distribution over the second token
Out[10]:
tensor([[[-0.1056, -0.1241,  0.0306,  0.0213, -0.0196, -0.1118,  0.0178,
          -0.1124,  0.0078,  0.0797,  0.0653,  0.1096,  0.0856, -0.1045,
           0.0404,  0.0355,  0.0577, -0.1344,  0.0471, -0.0176]]],
       grad_fn=<AddBackward0>)

Similar to the first step, we can compute the loss, quantifying the difference between the predicted distribution and the actual next token. This loss can be used to adjust the weights of the neural network (which we are not doing yet).

In [11]:
target = torch.Tensor([vocab_stoi[tweet[1]]]).long().unsqueeze(0)
criterion(output.reshape(-1, vocab_size), # reshape to 2D tensor
          target.reshape(-1))             # reshape to 1D tensor
Out[11]:
tensor(2.9567, grad_fn=<NllLossBackward>)

We can continue this process of:

  • feeding the previous ground-truth token to the RNN,
  • obtaining the prediction distribution over the next token, and
  • computing the loss,

for as many steps as there are tokens in the ground-truth tweet.

In [12]:
for i in range(2, len(tweet)):
    output, hidden = model(target, hidden)
    target = torch.Tensor([vocab_stoi[tweet[1]]]).long().unsqueeze(0)
    loss = criterion(output.reshape(-1, vocab_size), # reshape to 2D tensor
                     target.reshape(-1))             # reshape to 1D tensor
    print(i, output, loss)
2 tensor([[[-0.1138, -0.1170,  0.0208,  0.0381, -0.0625, -0.1147, -0.0110,
          -0.0882, -0.0044,  0.0436,  0.0828,  0.1008,  0.0914, -0.0980,
           0.0166,  0.0583,  0.0437, -0.1662,  0.0309, -0.0285]]],
       grad_fn=<AddBackward0>) tensor(2.9265, grad_fn=<NllLossBackward>)
3 tensor([[[-0.1209, -0.1185,  0.0153,  0.0428, -0.0865, -0.1166, -0.0276,
          -0.0776, -0.0067,  0.0237,  0.0923,  0.0935,  0.0988, -0.0966,
           0.0045,  0.0689,  0.0368, -0.1801,  0.0207, -0.0360]]],
       grad_fn=<AddBackward0>) tensor(2.9115, grad_fn=<NllLossBackward>)
4 tensor([[[-0.1262, -0.1212,  0.0122,  0.0435, -0.1004, -0.1182, -0.0370,
          -0.0729, -0.0066,  0.0122,  0.0978,  0.0889,  0.1049, -0.0970,
          -0.0026,  0.0744,  0.0342, -0.1856,  0.0152, -0.0409]]],
       grad_fn=<AddBackward0>) tensor(2.9034, grad_fn=<NllLossBackward>)
5 tensor([[[-0.1297, -0.1233,  0.0105,  0.0430, -0.1084, -0.1196, -0.0422,
          -0.0707, -0.0061,  0.0052,  0.1010,  0.0861,  0.1091, -0.0978,
          -0.0070,  0.0774,  0.0338, -0.1873,  0.0124, -0.0440]]],
       grad_fn=<AddBackward0>) tensor(2.8989, grad_fn=<NllLossBackward>)
6 tensor([[[-0.1317, -0.1246,  0.0095,  0.0423, -0.1131, -0.1207, -0.0450,
          -0.0695, -0.0057,  0.0007,  0.1028,  0.0843,  0.1118, -0.0986,
          -0.0097,  0.0792,  0.0342, -0.1875,  0.0110, -0.0458]]],
       grad_fn=<AddBackward0>) tensor(2.8963, grad_fn=<NllLossBackward>)
7 tensor([[[-0.1327, -0.1254,  0.0089,  0.0418, -0.1159, -0.1216, -0.0463,
          -0.0688, -0.0054, -0.0021,  0.1038,  0.0831,  0.1135, -0.0993,
          -0.0113,  0.0802,  0.0349, -0.1872,  0.0104, -0.0468]]],
       grad_fn=<AddBackward0>) tensor(2.8948, grad_fn=<NllLossBackward>)
8 tensor([[[-0.1332, -0.1258,  0.0086,  0.0414, -0.1175, -0.1222, -0.0470,
          -0.0683, -0.0052, -0.0039,  0.1043,  0.0823,  0.1145, -0.0998,
          -0.0122,  0.0808,  0.0354, -0.1867,  0.0102, -0.0473]]],
       grad_fn=<AddBackward0>) tensor(2.8940, grad_fn=<NllLossBackward>)
9 tensor([[[-0.1334, -0.1261,  0.0084,  0.0412, -0.1185, -0.1225, -0.0473,
          -0.0679, -0.0050, -0.0050,  0.1046,  0.0818,  0.1151, -0.1002,
          -0.0128,  0.0811,  0.0358, -0.1863,  0.0101, -0.0475]]],
       grad_fn=<AddBackward0>) tensor(2.8935, grad_fn=<NllLossBackward>)
10 tensor([[[-0.1334, -0.1262,  0.0083,  0.0410, -0.1190, -0.1228, -0.0474,
          -0.0677, -0.0049, -0.0058,  0.1048,  0.0814,  0.1154, -0.1005,
          -0.0130,  0.0813,  0.0361, -0.1861,  0.0100, -0.0476]]],
       grad_fn=<AddBackward0>) tensor(2.8933, grad_fn=<NllLossBackward>)
11 tensor([[[-0.1334, -0.1264,  0.0083,  0.0409, -0.1194, -0.1230, -0.0475,
          -0.0675, -0.0048, -0.0062,  0.1049,  0.0812,  0.1156, -0.1007,
          -0.0132,  0.0813,  0.0363, -0.1859,  0.0100, -0.0476]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
12 tensor([[[-0.1333, -0.1264,  0.0083,  0.0408, -0.1196, -0.1230, -0.0475,
          -0.0674, -0.0047, -0.0065,  0.1049,  0.0810,  0.1157, -0.1008,
          -0.0133,  0.0814,  0.0364, -0.1857,  0.0100, -0.0476]]],
       grad_fn=<AddBackward0>) tensor(2.8931, grad_fn=<NllLossBackward>)
13 tensor([[[-0.1333, -0.1265,  0.0083,  0.0408, -0.1197, -0.1231, -0.0475,
          -0.0673, -0.0047, -0.0067,  0.1049,  0.0809,  0.1158, -0.1009,
          -0.0133,  0.0814,  0.0365, -0.1856,  0.0100, -0.0475]]],
       grad_fn=<AddBackward0>) tensor(2.8931, grad_fn=<NllLossBackward>)
14 tensor([[[-0.1333, -0.1265,  0.0083,  0.0408, -0.1198, -0.1231, -0.0475,
          -0.0672, -0.0047, -0.0068,  0.1050,  0.0808,  0.1158, -0.1010,
          -0.0133,  0.0813,  0.0366, -0.1856,  0.0100, -0.0475]]],
       grad_fn=<AddBackward0>) tensor(2.8931, grad_fn=<NllLossBackward>)
15 tensor([[[-0.1332, -0.1266,  0.0083,  0.0408, -0.1198, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0068,  0.1050,  0.0808,  0.1158, -0.1010,
          -0.0133,  0.0813,  0.0366, -0.1856,  0.0100, -0.0475]]],
       grad_fn=<AddBackward0>) tensor(2.8931, grad_fn=<NllLossBackward>)
16 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1158, -0.1010,
          -0.0133,  0.0813,  0.0366, -0.1856,  0.0100, -0.0475]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
17 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1158, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0100, -0.0475]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
18 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1158, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0100, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
19 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0100, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
20 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0100, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
21 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
22 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
23 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
24 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
25 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
26 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
27 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
28 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
29 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
30 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
31 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
32 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)
33 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(2.8932, grad_fn=<NllLossBackward>)

Finally, with our final token, we should expect to output the "" token, so that our RNN learns when to stop generating characters.

In [13]:
output, hidden = model(target, hidden)
target = torch.Tensor([vocab_stoi["<EOS>"]]).long().unsqueeze(0)
loss = criterion(output.reshape(-1, vocab_size), # reshape to 2D tensor
                 target.reshape(-1))             # reshape to 1D tensor
print(i, output, loss)
33 tensor([[[-0.1332, -0.1266,  0.0083,  0.0407, -0.1199, -0.1232, -0.0475,
          -0.0672, -0.0046, -0.0069,  0.1050,  0.0807,  0.1159, -0.1011,
          -0.0133,  0.0813,  0.0366, -0.1855,  0.0099, -0.0474]]],
       grad_fn=<AddBackward0>) tensor(3.0219, grad_fn=<NllLossBackward>)

In practice, we don't really need a loop. Recall that in a predictive RNN, the nn.RNN module can take an entire sequence as input. We can do the same thing here:

In [14]:
tweet_ch = ["<BOS>"] + list(tweet) + ["<EOS>"]
tweet_indices = [vocab_stoi[ch] for ch in tweet_ch]
tweet_tensor = torch.Tensor(tweet_indices).long().unsqueeze(0)

print(tweet_tensor.shape)

output, hidden = model(tweet_tensor[:,:-1]) # <EOS> is never an input token
target = tweet_tensor[:,1:]                 # <BOS> is never a target token
loss = criterion(output.reshape(-1, vocab_size), # reshape to 2D tensor
                 target.reshape(-1))             # reshape to 1D tensor
torch.Size([1, 36])

Here, the input to our neural network model is the entire sequence of input tokens (everything from "" to the last character of the tweet). The neural network generates a prediction distribution of the next token at each step. We can compare each of these with the ground-truth target.

Our training loop (for learning to generate the single tweet) will therefore look something like this:

In [15]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for it in range(500):
    optimizer.zero_grad()
    output, _ = model(tweet_tensor[:,:-1])
    loss = criterion(output.reshape(-1, vocab_size),
                 target.reshape(-1))
    loss.backward()
    optimizer.step()

    if (it+1) % 100 == 0:
        print("[Iter %d] Loss %f" % (it+1, float(loss)))
[Iter 100] Loss 1.653741
[Iter 200] Loss 0.162499
[Iter 300] Loss 0.035650
[Iter 400] Loss 0.016482
[Iter 500] Loss 0.009848

The training loss is decreasing with training, which is what we expect.

Generating a Token

At this point, we want to see whether our model is actually learning something. So, we need to talk about how to actually use the RNN model to generate text. If we can generate text, we can make a qualitative asssessment of how well our RNN is performing.

The main difference between training and test-time (generation time) is that we don't have the ground-truth tokens to feed as inputs to the RNN. Instead, we need to actually sample a token based on the neural network's prediction distribution.

But how can we sample a token from a distribution?

On one extreme, we can always take the token with the largest probability (argmax). This has been our go-to technique in other classification tasks. However, this idea will fail here. The reason is that in practice, we want to be able to generate a variety of different sequences from the same model. An RNN that can only generate a single new Trump Tweet is fairly useless.

In short, we want some randomness. We can do so by using the logit outputs from our model to construct a multinomial distribution over the tokens, then and sample a random token from that multinomial distribution.

One natural multinomial distribution we can choose is the distribution we get after applying the softmax on the outputs. However, we will do one more thing: we will add a temperature parameter to manipulate the softmax outputs. We can set a higher temperature to make the probability of each token more even (more random), or a lower temperature to assighn more probability to the tokens with a higher logit (output). A higher temperature means that we will get a more diverse sample, with potentially more mistakes. A lower temperature means that we may see repetitions of the same high probability sequence.

In [16]:
def sample_sequence(model, max_len=100, temperature=0.8):
    generated_sequence = ""
   
    inp = torch.Tensor([vocab_stoi["<BOS>"]]).long()
    hidden = None
    for p in range(max_len):
        output, hidden = model(inp.unsqueeze(0), hidden)
        # Sample from the network as a multinomial distribution
        output_dist = output.data.view(-1).div(temperature).exp()
        top_i = int(torch.multinomial(output_dist, 1)[0])
        # Add predicted character to string and use as next input
        predicted_char = vocab_itos[top_i]
        
        if predicted_char == "<EOS>":
            break
        generated_sequence += predicted_char       
        inp = torch.Tensor([top_i]).long()
    return generated_sequence

print(sample_sequence(model, temperature=0.8))
print(sample_sequence(model, temperature=1.0))
print(sample_sequence(model, temperature=1.5))
print(sample_sequence(model, temperature=2.0))
print(sample_sequence(model, temperature=5.0))
God Bless the people of Venezuela!
God Bless the peopeoaof Venezuela!
God opess thB people of VeVezueBa!
God Bless teGoe plos teneluela!
!dunf el  heVp

Since we only trained the model on a single sequence, we won't see the effect of the temperature parameter yet.

For now, the output of the calls to the sample_sequence function assures us that our training code looks reasonable, and we can proceed to training on our full dataset!

Training the Trump Tweet Generator

For the actual training, let's use torchtext so that we can use the BucketIterator to make batches. Like in Lab 5, we'll create a torchtext.data.Field to use torchtext to read the CSV file, and convert characters into indices. The object has convient parameters to specify the BOS and EOS tokens.

In [17]:
import torchtext

text_field = torchtext.data.Field(sequential=True,      # text sequence
                                  tokenize=lambda x: x, # because are building a character-RNN
                                  include_lengths=True, # to track the length of sequences, for batching
                                  batch_first=True,
                                  use_vocab=True,       # to turn each character into an integer index
                                  init_token="<BOS>",   # BOS token
                                  eos_token="<EOS>")    # EOS token

fields = [('text', text_field), ('created_at', None), ('id_str', None)]
trump_tweets = torchtext.data.TabularDataset("trump.csv", "csv", fields)
len(trump_tweets) # should be >20,000 like before
Out[17]:
22402
In [18]:
text_field.build_vocab(trump_tweets)
vocab_stoi = text_field.vocab.stoi # so we don't have to rewrite sample_sequence
vocab_itos = text_field.vocab.itos # so we don't have to rewrite sample_sequence
vocab_size = len(text_field.vocab.itos)
vocab_size
Out[18]:
253

Let's just verify that the BucketIterator works as expected, but start with batch_size of 1.

In [19]:
data_iter = torchtext.data.BucketIterator(trump_tweets, 
                                          batch_size=1,
                                          sort_key=lambda x: len(x.text),
                                          sort_within_batch=True)
for (tweet, lengths), label in data_iter:
    print(label)   # should be None
    print(lengths) # contains the length of the tweet(s) in batch
    print(tweet.shape) # should be [1, max(length)]
    break
None
tensor([90])
torch.Size([1, 90])

To account for batching, our actual training code will change, but just a little bit. In fact, our training code from before will work with a batch size larger than one!

In [20]:
def train(model, data, batch_size=1, num_epochs=1, lr=0.001, print_every=100):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    it = 0
    
    data_iter = torchtext.data.BucketIterator(data,
                                              batch_size=batch_size,
                                              sort_key=lambda x: len(x.text),
                                              sort_within_batch=True)
    for e in range(num_epochs):
        # get training set
        avg_loss = 0
        for (tweet, lengths), label in data_iter:
            target = tweet[:, 1:]
            inp = tweet[:, :-1]
            # cleanup
            optimizer.zero_grad()
            # forward pass
            output, _ = model(inp)
            loss = criterion(output.reshape(-1, vocab_size), target.reshape(-1))
            # backward pass
            loss.backward()
            optimizer.step()

            avg_loss += loss
            it += 1 # increment iteration count
            if it % print_every == 0:
                print("[Iter %d] Loss %f" % (it+1, float(avg_loss/print_every)))
                print("    " + sample_sequence(model, 140, 0.8))
                avg_loss = 0

model = TextGenerator(vocab_size, 64)
#train(model, trump_tweets, batch_size=1, num_epochs=1, lr=0.004, print_every=100)
#train(model, trump_tweets, batch_size=32, num_epochs=1, lr=0.004, print_every=100)
#print(sample_sequence(model, temperature=0.8))
#print(sample_sequence(model, temperature=0.8))
#print(sample_sequence(model, temperature=1.0))
#print(sample_sequence(model, temperature=1.0))
#print(sample_sequence(model, temperature=1.5))
#print(sample_sequence(model, temperature=1.5))
#print(sample_sequence(model, temperature=2.0))
#print(sample_sequence(model, temperature=2.0))
#print(sample_sequence(model, temperature=5.0))
#print(sample_sequence(model, temperature=5.0))