Thus far, we have discussed several generative models. A generative model learns the structure of a set of input data. In doing so, the model learns to generate new data that it has never seen before in the training data. The generative models we discussed were:
A Generative Adversarial Network (GAN) is yet another example of a generative model.
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms
mnist_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
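A generative adversarial network (GAN) model consists of two models: a generator G, which takes random noise as input and produces an image, and a discriminator D, which takes an image and predicts whether it is real (from the training set) or fake (produced by G).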
In essence, we have two neural networks that are adversaries: the generator wants to fool the discriminator, and the discriminator wants to avoid being fooled.
Let's set up a simple generator and a discriminator to start:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(300, 100),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)   # flatten each 1x28x28 image into a 784-vector
        out = self.model(x)
        return out.view(x.size(0))  # one raw score (logit) per image; no sigmoid here
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 300),          # input: a 100-dimensional noise vector
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(300, 28*28),
            nn.Sigmoid()                  # pixel intensities in [0, 1], like ToTensor() MNIST
        )

    def forward(self, x):
        out = self.model(x).view(x.size(0), 1, 28, 28)  # reshape into a batch of images
        return out
For now, both the Discriminator and Generator are fully connected networks. One difference between these models and the previous models we've built is that we are using an nn.LeakyReLU activation.

Actually, you have seen leaky ReLU activations before, from the very beginning of the course, in assignment 1! Leaky ReLU is a variation of the ReLU activation that lets some information (and gradient) through even when its input is less than 0: instead of outputting 0 for a negative input, it outputs a small multiple of that input. The layer nn.LeakyReLU(0.2, inplace=True) performs the computation x if x > 0 else 0.2 * x.
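As a quick sanity check, here is a minimal sketch comparing nn.LeakyReLU(0.2) against the rule written out by hand with torch.where. The input values are arbitrary. The sketch leaves inplace at its default (False) so that x is not overwritten before the comparison. In the models above, inplace=True only saves memory by overwriting the preceding Linear layer's output.

```python
import torch
import torch.nn as nn

# Arbitrary inputs on both sides of zero.
x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])

# Leaky ReLU with negative slope 0.2 (not in-place, so x stays intact).
leaky = nn.LeakyReLU(0.2)
out = leaky(x)

# The same rule by hand: x if x > 0 else 0.2 * x
manual = torch.where(x > 0, x, 0.2 * x)

print(out)
print(torch.allclose(out, manual))  # True
```

Negative inputs are scaled by 0.2 rather than zeroed, so their gradient is 0.2 instead of 0.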
But what loss function should we optimize? Consider the following quantity:
P(D correctly identifies real image) + P(D correctly identifies image generated by G)
A good discriminator would want to maximize the above quantity by altering its parameters. Likewise, a good generator would want to minimize the above quantity. Actually, the only term that the generator controls is P(D correctly identifies image generated by G). So, the best thing for the generator to do is alter its parameters to generate images that can fool D.
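The discriminator is solving a binary classification problem (real vs. fake), so we will use binary cross-entropy loss. Since our discriminator outputs a raw score (logit) rather than a probability, we use nn.BCEWithLogitsLoss, which applies the sigmoid internally.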
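To make the link to the quantity above concrete, here is a minimal sketch. It checks that nn.BCEWithLogitsLoss equals the negative mean log-probability that D assigns to the correct label. Minimizing it therefore pushes D to identify each image correctly, which is a log-scale version of the quantity above. The logits are made-up values for two hypothetical real images and two hypothetical fake ones, labelled with the notes' convention (0 = real, 1 = fake).

```python
import torch
import torch.nn as nn

# Hypothetical raw discriminator outputs (logits) for 2 real and 2 fake images.
logits = torch.tensor([-2.0, -0.5, 1.0, 3.0])
labels = torch.tensor([0.0, 0.0, 1.0, 1.0])  # 0 = real, 1 = fake

# Built-in: sigmoid + binary cross-entropy, averaged over the batch.
loss = nn.BCEWithLogitsLoss()(logits, labels)

# By hand: sigmoid(logit) is D's probability of "fake"; the probability
# D gives the correct label is p for fake images and 1 - p for real ones.
p_fake = torch.sigmoid(logits)
p_correct = torch.where(labels == 1, p_fake, 1 - p_fake)
manual = -torch.log(p_correct).mean()

print(loss.item(), manual.item())
print(torch.allclose(loss, manual))  # True
```

Working from logits with BCEWithLogitsLoss is also more numerically stable than applying a sigmoid in the model and then using nn.BCELoss.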
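Here is a rudimentary training loop to train a GAN. For every minibatch of data, we train the discriminator for one iteration, and then we train the generator for one iteration. For the discriminator, we use the label 1 to represent a fake image and 0 to represent a real image. For the generator, we pass its images through the discriminator with the label 0 (real), so the generator's loss is low when the discriminator is fooled.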
def train(generator, discriminator, lr=0.001, num_epochs=5):
    criterion = nn.BCEWithLogitsLoss()
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

    train_loader = torch.utils.data.DataLoader(mnist_data, batch_size=100, shuffle=True)

    # fixed noise, so we can watch the same samples change from epoch to epoch
    num_test_samples = 16
    test_noise = torch.randn(num_test_samples, 100)

    for epoch in range(num_epochs):
        # put both models into training mode
        generator.train()
        discriminator.train()

        for images, _ in train_loader:
            batch_size = images.size(0)

            # === Train the Discriminator ===
            noise = torch.randn(batch_size, 100)
            # detach() so the discriminator's loss does not send gradients into the generator
            fake_images = generator(noise).detach()
            inputs = torch.cat([images, fake_images])
            labels = torch.cat([torch.zeros(batch_size),  # real
                                torch.ones(batch_size)])  # fake
            d_outputs = discriminator(inputs)
            d_loss = criterion(d_outputs, labels)
            d_optimizer.zero_grad()  # also clears gradients left by the previous g_loss.backward()
            d_loss.backward()
            d_optimizer.step()

            # === Train the Generator ===
            noise = torch.randn(batch_size, 100)
            fake_images = generator(noise)
            outputs = discriminator(fake_images)
            # the generator wants D to label its images as real (0)
            g_loss = criterion(outputs, torch.zeros(batch_size))
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        # D's average predicted probability of "fake" on the last minibatch
        scores = torch.sigmoid(d_outputs.detach())
        real_score = scores[:batch_size].mean().item()
        fake_score = scores[batch_size:].mean().item()
        print('Epoch [%d/%d], d_loss: %.4f, g_loss: %.4f, '
              'D(x): %.2f, D(G(z)): %.2f'
              % (epoch + 1, num_epochs, d_loss.item(), g_loss.item(), real_score, fake_score))

        # plot images generated from the fixed test noise
        generator.eval()
        discriminator.eval()
        with torch.no_grad():
            test_images = generator(test_noise)
        plt.figure(figsize=(9, 3))
        for k in range(num_test_samples):
            plt.subplot(2, 8, k+1)
            plt.imshow(test_images[k, :].numpy().reshape(28, 28), cmap='Greys')
        plt.show()
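A subtle point in this kind of alternating loop is that the fake images fed to the discriminator are outputs of the generator. Unless they are detached, d_loss.backward() also accumulates gradients in the generator's parameters. Those gradients push G toward images that are easier to flag as fake, and they leak into the next generator update unless they are cleared. The minimal sketch below illustrates this with tiny hypothetical nn.Linear stand-ins for G and D, not the models above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny hypothetical stand-ins for a generator and a discriminator.
G = nn.Linear(4, 8)
D = nn.Linear(8, 1)
criterion = nn.BCEWithLogitsLoss()

noise = torch.randn(5, 4)
fake_labels = torch.ones(5, 1)  # 1 = fake, following the notes' convention

# Without detach: the discriminator's loss also produces gradients for G.
d_loss = criterion(D(G(noise)), fake_labels)
d_loss.backward()
print(G.weight.grad is not None)  # True: G received gradients from D's loss

# Clear all gradients, then repeat with the fake images detached.
G.zero_grad(set_to_none=True)
D.zero_grad(set_to_none=True)
d_loss = criterion(D(G(noise).detach()), fake_labels)
d_loss.backward()
print(G.weight.grad is None)      # True: G's parameters are untouched
print(D.weight.grad is not None)  # True: D still gets its gradients
```

Detaching in the discriminator step and zeroing each optimizer's gradients just before its backward() call keep the two updates independent.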
Let's try training the network.
discriminator = Discriminator()
generator = Generator()
train(generator, discriminator, lr=0.001, num_epochs=20)
GANs are notoriously difficult to train. One difficulty is that a training curve is no longer as helpful as it was for a supervised learning problem! The generator and discriminator losses tend to bounce up and down, since both the generator and discriminator are changing over time. Tuning hyperparameters is also much more difficult, because we don't have the training curve to guide us. Newer GAN models like the Wasserstein GAN try to alleviate some of these issues, but they are beyond the scope of this course.
To compound the difficulty of hyperparameter tuning, GANs also take a long time to train. It is tempting to stop training early, but the effects of hyperparameters may not be noticeable until later in training.
You might have noticed in the images generated by our simple GAN that the model seems to output only a small number of digit types. This phenomenon is called mode collapse. A generator can drive down P(D correctly identifies image generated by G) by learning to generate one type of output (e.g. one digit) really well, and not learning how to generate any other digits at all!
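One rough way to spot mode collapse is to classify a batch of generated samples and count how many distinct digits appear. The sketch below assumes you already have such predictions from some MNIST classifier, which is a hypothetical component not built in these notes. The prediction tensors are made-up illustrations, not real results. The helpers digit_histogram and num_modes_covered are likewise hypothetical names.

```python
import torch

def digit_histogram(predicted_digits: torch.Tensor) -> torch.Tensor:
    """Count how many generated samples fall into each of the 10 digit classes."""
    return torch.bincount(predicted_digits, minlength=10)

def num_modes_covered(predicted_digits: torch.Tensor) -> int:
    """Number of distinct digit classes that appear at least once."""
    return int((digit_histogram(predicted_digits) > 0).sum().item())

# Made-up predictions for 12 samples from a collapsed generator (mostly 1s and 7s)...
collapsed = torch.tensor([1, 1, 1, 7, 1, 1, 7, 1, 1, 1, 1, 7])
# ...and for 20 samples from a generator that covers every digit.
diverse = torch.arange(10).repeat(2)

print(digit_histogram(collapsed))
print(num_modes_covered(collapsed), num_modes_covered(diverse))  # 2 10
```

A histogram concentrated on one or two classes, as in the collapsed case, is the signature described above. The count is only as reliable as the classifier producing the predictions.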