Generative Adversarial Networks

Thus far, we have discussed several generative models. A generative model learns the structure of a set of input data, and in doing so learns to generate new examples that resemble, but do not appear in, the training data. The generative models we discussed were:

  • an autoencoder
  • an RNN used to generate text

A Generative Adversarial Network (GAN) is yet another example of a generative model. Unlike an autoencoder, whose main purpose is to reconstruct its input, the main purpose of a GAN is to generate convincing new examples.

To motivate the GAN, let's first discuss the drawbacks of an autoencoder.

In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms

mnist_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())

Autoencoder Review

Here is the code that we wrote back in the autoencoder lecture. The autoencoder model consists of an encoder that maps images to a vector embedding, and a decoder that reconstructs images from an embedding.

In [2]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

We trained an autoencoder model on the reconstruction loss: the difference in pixel intensities between a real image and its reconstruction. We won't run the entire training code today. Instead, we will load a model that was trained earlier.

In [3]:
def train(model, num_epochs=5, batch_size=64, learning_rate=1e-3):
    torch.manual_seed(42)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    train_loader = torch.utils.data.DataLoader(mnist_data, batch_size=batch_size, shuffle=True)
    outputs = []
    for epoch in range(num_epochs):
        for data in train_loader:
            img, label = data
            recon = model(img)
            loss = criterion(recon, img)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print('Epoch:{}, Loss:{:.4f}'.format(epoch+1, float(loss)))
        outputs.append((epoch, img, recon),)
    return outputs

model = Autoencoder()
#outputs = train(model, num_epochs=5)
#torch.save(model.state_dict(), "autoencoder.pt")
ckpt = torch.load("autoencoder.pt")
model.load_state_dict(ckpt)

Let's take a look at one MNIST image from training, and its autoencoder reconstruction:

In [4]:
original = mnist_data[0][0].unsqueeze(0)
emb = model.encoder(original)
recon_img = model.decoder(emb).detach().numpy()[0,0,:,:]

# plot the original image
plt.subplot(1,2,1)
plt.title("original")
plt.imshow(original[0][0], cmap='gray')

# plot the reconstructed
plt.subplot(1,2,2)
plt.title("reconstruction")
plt.imshow(recon_img, cmap='gray')
Out[4]:
<matplotlib.image.AxesImage at 0x7fdae604bf28>

The reconstruction is reasonable, but notice that it is much blurrier than the original image. If we perturb the embedding to generate a new image, we should still see this blurriness:

In [5]:
# Run this a few times
x = emb + 10 * torch.randn(1, 64, 1, 1) # add a random perturbation

# reconstruct image and plot
img = model.decoder(x)[0,0,:,:]
img = img.detach().numpy()
plt.title("perturbed reconstruction")
plt.imshow(img, cmap='gray')
Out[5]:
<matplotlib.image.AxesImage at 0x7fdb524edeb8>

Autoencoders tend to generate blurry images because of the loss function they use. The MSELoss (mean squared error loss) has an averaging effect: if the model learns that two plausible values for a pixel are 0 and 1, then it will predict a value of 0.5 for that pixel to minimize the mean squared error. However, none of our training images might have a pixel intensity of 0.5 at that pixel! A human could easily tell the difference between such a generated image and a real image.
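To see this averaging effect concretely, here is a small numerical sketch (ours, not part of the original lecture code). Suppose a pixel is 0 in half of the training images and 1 in the other half; the single value with the lowest mean squared error is 0.5:

targets = torch.tensor([0., 1., 0., 1.])   # observed values of one pixel across images
for guess in [0.0, 0.5, 1.0]:
    pred = torch.full_like(targets, guess)
    print(guess, F.mse_loss(pred, targets).item())
# 0.0 0.5
# 0.5 0.25   <- the "blurry" average value gives the lowest error
# 1.0 0.5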

But what would be a more appropriate loss function than MSELoss? People have tried to come up with better loss functions, but it is difficult to hand-design a loss function that is general enough for all kinds of generation tasks. What we really want to do is learn a loss function!

The main idea is that a generated image that fails to fool a human should also fail to fool a neural network trained to differentiate between real and fake images. We can then use the predictions of this discriminator network to guide the training of our generator network.

Generative Adversarial Network

A generative adversarial network (GAN) model consists of two models:

  • A Generator network $G$ that takes in a latent embedding (usually random noise) and generates an image like those that exist in the training data
  • A Discriminator network $D$ that tries to distinguish between real images from the training data, and fake images produced by the generator

In essence, we have two neural networks that are adversaries: the generator wants to fool the discriminator, and the discriminator wants to avoid being fooled. This setup is known as a min-max (minimax) game.
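In math, this min-max game is often written as a single objective, where $D(x)$ is interpreted as the probability that $x$ is a real image:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

(Our code below will flip the label convention and treat $D$'s output as the probability that the input is fake; this describes the same game.)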

Let's set up a simple generator and a discriminator to start:

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(300, 100),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        out = self.model(x)
        return out.view(x.size(0))

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 300),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(300, 28*28),
            nn.Sigmoid()
        )

    def forward(self, x):
        out = self.model(x).view(x.size(0), 1, 28, 28)
        return out

For now, both the Discriminator and the Generator are fully-connected networks. One difference between these models and the previous models we've built is that we are using an nn.LeakyReLU activation.

Actually, you have seen leaky ReLU activations before, from the very beginning of the course, in assignment 1! Leaky ReLU is a variation of the ReLU activation that lets some information through, even when its input is less than 0. The layer nn.LeakyReLU(0.2, inplace=True) performs the computation: x if x > 0 else 0.2 * x.
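Here is a quick sketch (not part of the original notebook) of what nn.LeakyReLU(0.2) does to a few values:

leaky = nn.LeakyReLU(0.2)
x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(leaky(x))
# tensor([-0.4000, -0.1000,  0.0000,  0.5000,  2.0000])
# negative inputs are scaled by 0.2 instead of being zeroed out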

But what loss function should we optimize? Consider the following quantity:

P(D correctly identifies real image) + P(D correctly identifies image generated by G)

A good discriminator would want to maximize the above quantity by altering its parameters.

Likewise, a good generator would want to minimize the above quantity. Actually, the only term that the generator controls is P(D correctly identifies image generated by G). So, the best thing for the generator to do is to alter its parameters to generate images that can fool D.

Since we are looking at class probabilities, we will use binary cross entropy loss.

Here is a rudimentary training loop to train a GAN. For every minibatch of data, we train the discriminator for one iteration, and then we train the generator for one iteration.

For the discriminator, we use the label 1 to represent a fake image, and 0 to represent a real image.
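Note that our Discriminator outputs a raw score (a logit) rather than a probability, so we use nn.BCEWithLogitsLoss, which applies the sigmoid internally before computing the binary cross entropy. Here is a tiny sketch (ours, not from the original notes) of how the loss behaves under this label convention:

criterion = nn.BCEWithLogitsLoss()
logits = torch.tensor([3.0, -3.0])  # D is confident the first image is fake, the second is real
labels = torch.tensor([1.0, 0.0])   # 1 = fake, 0 = real
print(criterion(logits, labels).item())      # ~0.05: D agrees with the labels
print(criterion(logits, 1 - labels).item())  # ~3.05: D disagrees with the labels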

In [7]:
def train(generator, discriminator, lr=0.001, num_epochs=5):
    criterion = nn.BCEWithLogitsLoss()
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

    train_loader = torch.utils.data.DataLoader(mnist_data, batch_size=100, shuffle=True)

    num_test_samples = 16
    test_noise = torch.randn(num_test_samples, 100)

    for epoch in range(num_epochs):
        # put both models in training mode
        generator.train()
        discriminator.train()


        for n, (images, _) in enumerate(train_loader):
            # === Train the Discriminator ===

            noise = torch.randn(images.size(0), 100)
            fake_images = generator(noise)
            inputs = torch.cat([images, fake_images])
            labels = torch.cat([torch.zeros(images.size(0)), # real
                                torch.ones(images.size(0))]) # fake

            discriminator.zero_grad()
            d_outputs = discriminator(inputs)
            d_loss = criterion(d_outputs, labels)
            d_loss.backward()
            d_optimizer.step()

            # === Train the Generator ===
            noise = torch.randn(images.size(0), 100)
            fake_images = generator(noise)
            outputs = discriminator(fake_images)

            generator.zero_grad()
            g_loss = criterion(outputs, torch.zeros(images.size(0)))
            g_loss.backward()
            g_optimizer.step()

        scores = torch.sigmoid(d_outputs)
        real_score = scores[:images.size(0)].data.mean()
        fake_score = scores[images.size(0):].data.mean()


        print('Epoch [%d/%d], d_loss: %.4f, g_loss: %.4f, ' 
              'D(x): %.2f, D(G(z)): %.2f' 
              % (epoch + 1, num_epochs, d_loss.item(), g_loss.item(), real_score, fake_score))
        
        # plot images
        generator.eval()
        discriminator.eval()
        test_images = generator(test_noise)
        plt.figure(figsize=(9, 3))
        for k in range(16):
            plt.subplot(2, 8, k+1)
            plt.imshow(test_images[k,:].data.numpy().reshape(28, 28), cmap='Greys')
        plt.show()

Let's try training the network.

In [8]:
discriminator = Discriminator()
generator = Generator()
train(generator, discriminator, lr=0.001, num_epochs=50)
Epoch [1/50], d_loss: 0.0039, g_loss: 28.1038, D(x): 0.01, D(G(z)): 1.00
Epoch [2/50], d_loss: 0.0355, g_loss: 9.0176, D(x): 0.02, D(G(z)): 0.99
Epoch [3/50], d_loss: 0.0504, g_loss: 11.3650, D(x): 0.04, D(G(z)): 1.00
Epoch [4/50], d_loss: 0.0520, g_loss: 5.6159, D(x): 0.00, D(G(z)): 0.93
Epoch [5/50], d_loss: 0.0410, g_loss: 8.8102, D(x): 0.04, D(G(z)): 0.99
Epoch [6/50], d_loss: 0.0372, g_loss: 8.0096, D(x): 0.03, D(G(z)): 0.98
Epoch [7/50], d_loss: 0.0213, g_loss: 6.7164, D(x): 0.02, D(G(z)): 0.99
Epoch [8/50], d_loss: 0.0890, g_loss: 7.7729, D(x): 0.08, D(G(z)): 0.97
Epoch [9/50], d_loss: 0.1391, g_loss: 6.0553, D(x): 0.14, D(G(z)): 0.98
Epoch [10/50], d_loss: 0.1436, g_loss: 5.5058, D(x): 0.11, D(G(z)): 0.95
Epoch [11/50], d_loss: 0.2125, g_loss: 4.7517, D(x): 0.18, D(G(z)): 0.95
Epoch [12/50], d_loss: 0.1158, g_loss: 3.3886, D(x): 0.05, D(G(z)): 0.88
Epoch [13/50], d_loss: 0.2743, g_loss: 4.0402, D(x): 0.06, D(G(z)): 0.76
Epoch [14/50], d_loss: 0.2167, g_loss: 3.7580, D(x): 0.13, D(G(z)): 0.87
Epoch [15/50], d_loss: 0.2183, g_loss: 4.1856, D(x): 0.17, D(G(z)): 0.91
Epoch [16/50], d_loss: 0.2104, g_loss: 2.2845, D(x): 0.07, D(G(z)): 0.80
Epoch [17/50], d_loss: 0.1449, g_loss: 4.1655, D(x): 0.11, D(G(z)): 0.93
Epoch [18/50], d_loss: 0.1908, g_loss: 3.2261, D(x): 0.09, D(G(z)): 0.85
Epoch [19/50], d_loss: 0.2504, g_loss: 3.4088, D(x): 0.15, D(G(z)): 0.83
Epoch [20/50], d_loss: 0.2384, g_loss: 3.2706, D(x): 0.18, D(G(z)): 0.92
Epoch [21/50], d_loss: 0.1728, g_loss: 3.5893, D(x): 0.16, D(G(z)): 0.93
Epoch [22/50], d_loss: 0.1870, g_loss: 4.6195, D(x): 0.12, D(G(z)): 0.92
Epoch [23/50], d_loss: 0.1690, g_loss: 3.6551, D(x): 0.13, D(G(z)): 0.91
Epoch [24/50], d_loss: 0.2366, g_loss: 3.4471, D(x): 0.12, D(G(z)): 0.83
Epoch [25/50], d_loss: 0.2217, g_loss: 3.1741, D(x): 0.11, D(G(z)): 0.86
Epoch [26/50], d_loss: 0.2127, g_loss: 2.8849, D(x): 0.13, D(G(z)): 0.86
Epoch [27/50], d_loss: 0.2367, g_loss: 3.6114, D(x): 0.08, D(G(z)): 0.80
Epoch [28/50], d_loss: 0.2102, g_loss: 2.8371, D(x): 0.16, D(G(z)): 0.91
Epoch [29/50], d_loss: 0.2079, g_loss: 3.2904, D(x): 0.12, D(G(z)): 0.82
Epoch [30/50], d_loss: 0.2162, g_loss: 4.0277, D(x): 0.16, D(G(z)): 0.92
Epoch [31/50], d_loss: 0.1791, g_loss: 3.0529, D(x): 0.11, D(G(z)): 0.86
Epoch [32/50], d_loss: 0.1730, g_loss: 2.8293, D(x): 0.11, D(G(z)): 0.86
Epoch [33/50], d_loss: 0.1867, g_loss: 3.6447, D(x): 0.12, D(G(z)): 0.89
Epoch [34/50], d_loss: 0.2027, g_loss: 3.1301, D(x): 0.12, D(G(z)): 0.88
Epoch [35/50], d_loss: 0.2233, g_loss: 2.8957, D(x): 0.09, D(G(z)): 0.81
Epoch [36/50], d_loss: 0.1618, g_loss: 2.9497, D(x): 0.10, D(G(z)): 0.87
Epoch [37/50], d_loss: 0.2352, g_loss: 3.0068, D(x): 0.17, D(G(z)): 0.89
Epoch [38/50], d_loss: 0.2614, g_loss: 3.0805, D(x): 0.13, D(G(z)): 0.82
Epoch [39/50], d_loss: 0.2401, g_loss: 3.0561, D(x): 0.09, D(G(z)): 0.79
Epoch [40/50], d_loss: 0.2314, g_loss: 2.2865, D(x): 0.15, D(G(z)): 0.87
Epoch [41/50], d_loss: 0.2277, g_loss: 2.9825, D(x): 0.15, D(G(z)): 0.89
Epoch [42/50], d_loss: 0.2554, g_loss: 3.8359, D(x): 0.11, D(G(z)): 0.84
Epoch [43/50], d_loss: 0.2205, g_loss: 2.7953, D(x): 0.15, D(G(z)): 0.85
Epoch [44/50], d_loss: 0.3212, g_loss: 3.4167, D(x): 0.21, D(G(z)): 0.90
Epoch [45/50], d_loss: 0.2503, g_loss: 3.7404, D(x): 0.11, D(G(z)): 0.83
Epoch [46/50], d_loss: 0.1682, g_loss: 3.3715, D(x): 0.13, D(G(z)): 0.91
Epoch [47/50], d_loss: 0.1615, g_loss: 2.3880, D(x): 0.11, D(G(z)): 0.88
Epoch [48/50], d_loss: 0.1618, g_loss: 2.8510, D(x): 0.08, D(G(z)): 0.87
Epoch [49/50], d_loss: 0.3033, g_loss: 3.5591, D(x): 0.24, D(G(z)): 0.92
Epoch [50/50], d_loss: 0.2146, g_loss: 2.4731, D(x): 0.11, D(G(z)): 0.85

GANs are notoriously difficult to train. One difficulty is that a training curve is no longer as helpful as it was for a supervised learning problem! The generator and discriminator losses tend to bounce up and down, since both the generator and the discriminator are changing over time. Tuning hyperparameters is also much more difficult, because we don't have the training curve to guide us. Newer GAN models like the Wasserstein GAN try to alleviate some of these issues, but they are beyond the scope of this course.

To compound the difficulty of hyperparameter tuning, GANs also take a long time to train. It is tempting to stop training early, but the effects of a hyperparameter choice may not be noticeable until later in training.

You might have noticed in the images generated by our simple GAN that the model seems to output only a small number of digit types. This phenomenon is called mode collapse. A generator can optimize P(D correctly identifies image generated by G) by learning to generate one type of input (e.g. one digit) really well, and not learning how to generate any other digits at all!

To prevent mode collapse, newer variations of GANs provide the discriminator with a small set of either real or fake images, rather than one image at a time. The discriminator can then use the variety (or lack of variety) within the set as a feature to determine whether the entire set is real or fake.
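For example, the discriminator can be given a "minibatch standard deviation" feature, similar in spirit to the minibatch statistics used in some later GANs. The sketch below is our own illustration of the idea (not tested, and not the exact technique from any particular paper):

class MinibatchStdDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classify = nn.Linear(300 + 1, 1)   # +1 for the batch-variety feature

    def forward(self, x):
        h = self.features(x.view(x.size(0), -1))
        # a single scalar summarizing how much this batch varies,
        # copied onto every example in the batch
        batch_std = h.std(dim=0).mean()
        std_feat = batch_std * torch.ones(h.size(0), 1)
        out = self.classify(torch.cat([h, std_feat], dim=1))
        return out.view(x.size(0))

If the generator collapses to producing nearly identical images, this extra feature will be close to zero for fake batches, giving the discriminator an easy way to tell them apart.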

GAN Implementations

Batch Normalization

Since GANs take so much longer to train, GAN implementations tend to use more advanced layers than the ones we have seen so far. Leaky ReLU is one of them; another is batch normalization.

We discussed why normalizing the input is generally a good idea. If each input neuron is roughly of the same scale, then we can use the same method to initialize all of our weights and biases.

But what about the hidden activations? Normalizing those is the main idea behind batch normalization: we normalize the hidden activations to have mean 0 and standard deviation 1 (or some other learned scale). At training time, we perform the normalization across each mini-batch, while also keeping running estimates of the means and standard deviations of the incoming activations. At test time, we use these running statistics, rather than the statistics of the test batch, to normalize the activations.
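As a quick illustration (a sketch of our own, not part of the original notebook), nn.BatchNorm1d normalizes each feature across the batch at training time, while tracking running statistics for use at test time:

bn = nn.BatchNorm1d(3)              # one mean/std tracked per feature
x = torch.randn(8, 3) * 5 + 10      # a batch whose features have mean ~10, std ~5

bn.train()
y = bn(x)
print(y.mean(dim=0), y.std(dim=0))  # roughly 0 and 1: normalized with this batch's statistics
print(bn.running_mean)              # running estimate of the mean, updated a little each batch

bn.eval()
print(bn(x).mean(dim=0))            # now normalized with the running statistics instead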

Here are example Discriminator and Generator models that use some of these ideas. (Not tested.)

In [9]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16, momentum=0.1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 8, 3, stride=2, padding=1),
            nn.BatchNorm2d(8, momentum=0.1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(8, 1, 7)
        )
    def forward(self, x):
        out = self.model(x)
        out = out.view(out.size(0))
        return out

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(100, 32, 7),
            nn.BatchNorm2d(32, momentum=0.1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16, momentum=0.1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8, momentum=0.1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(8, 1, 3, padding=1),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(x.size(0), 100, 1, 1)
        out = self.model(x)
        return out