{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Transpose Convolutions and Autoencoders\n",
"\n",
"In our discussions of convolutional networks, we always started with an image, then reduced\n",
"the \"resolution\" of the image until we made an image-level prediction. Specifically, we focused\n",
"on image level classification problems: is the image of a cat or a dog? Which of the 10 digits\n",
"does this image represent? We always went from a large image, to a lower-dimensional representation\n",
"of the image.\n",
"\n",
"Some tasks require us to go in the opposite direction. For example, we may wish to make pixel-wise\n",
"predictions about the content of each pixel in an image. Is this pixel part of the foreground\n",
"or the background? Is this pixel a part of a car or a pedestrian? Problems that require us to\n",
"label each pixel is called a pixel-wise prediction problem. These problems require us to\n",
"produce an high-resolution \"image\" from a low-dimensional representation of its contents.\n",
"\n",
"A similar task is the task of *generating* an image given a low-dimensional *embedding* of the\n",
"image. For example, we may wish to produce a neural network model that *generates* images of\n",
"hand-written digits not in the MNIST data set. A neural network model that learns to\n",
"generate new examples of data is called a **generative model**.\n",
"\n",
"In both cases, we need a way to *increase* the resolution of our hidden units.\n",
"We need something akin to convolution, but that goes in the *opposite* direction.\n",
"We will use something called a **transpose convolution**. Transpose convolutions were first\n",
"called *deconvolutions*, since it is the ``inverse'' of a convolution operation. However,\n",
"the terminology was confusing since it has nothing to do with the mathematical notion of\n",
"deconvolution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import matplotlib.pyplot as plt\n",
"from torchvision import datasets, transforms\n",
"\n",
"mnist_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())\n",
"mnist_data = list(mnist_data)[:4096]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Convolution Transpose\n",
"\n",
"\n",
"We can make a tranpose convolution layer in PyTorch like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"convt = nn.ConvTranspose2d(in_channels=16,\n",
" out_channels=8,\n",
" kernel_size=5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's an example of this convolution transpose operation in action:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(32, 16, 64, 64)\n",
"y = convt(x)\n",
"y.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that the width and height of `y` is `68x68`, because the `kernel_size` is 5\n",
"and we have not added any padding. You can verify that if we start with a tensor\n",
"with resolution `68x68` and applied a `5x5` convolution, we would end up with\n",
"a tensor with resolution `64x64`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"conv = nn.Conv2d(in_channels=8,\n",
" out_channels=16,\n",
" kernel_size=5)\n",
"y = torch.randn(32, 8, 68, 68)\n",
"x = conv(y)\n",
"x.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, we can add a padding to our convolution transpose, just like we added\n",
"padding to our convolution operations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"convt = nn.ConvTranspose2d(in_channels=16,\n",
" out_channels=8,\n",
" kernel_size=5,\n",
" padding=2)\n",
"x = torch.randn(32, 16, 64, 64)\n",
"y = convt(x)\n",
"y.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"More interestingly, we can add a stride to the convolution to increase our resolution!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"convt = nn.ConvTranspose2d(in_channels=16,\n",
" out_channels=8,\n",
" kernel_size=5,\n",
" stride=2,\n",
" output_padding=1, # needed because stride=2\n",
" padding=2)\n",
"x = torch.randn(32, 16, 64, 64)\n",
"y = convt(x)\n",
"y.shape"
]
},
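{
"cell_type": "markdown",
"metadata": {},
"source": [
"To double-check the doubling, we can compute the output size by hand. For a\n",
"convolution transpose, the output size follows\n",
"`out = (in - 1) * stride - 2 * padding + kernel_size + output_padding`\n",
"(this matches the shape rule in the PyTorch `ConvTranspose2d` documentation).\n",
"The small helper below is our own addition, included just for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Our own sanity-check helper (not part of PyTorch):\n",
"def convt_out_size(in_size, kernel_size, stride=1, padding=0, output_padding=0):\n",
"    return (in_size - 1) * stride - 2 * padding + kernel_size + output_padding\n",
"\n",
"# (64 - 1) * 2 - 2 * 2 + 5 + 1 = 128, matching y.shape above\n",
"convt_out_size(64, kernel_size=5, stride=2, padding=2, output_padding=1)"
]
},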
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our resolution has doubled.\n",
"\n",
"But what is actually happening? Essentially, we are adding a padding of zeros\n",
"in between every row and every column of `x`. In the picture below,\n",
"the blue image below represents the input `x` to the convolution transpose, and \n",
"the green image above represents the output `y`.\n",
"\n",
"\n",
"\n",
"You can check https://github.com/vdumoulin/conv_arithmetic for more pictures\n",
"and links to descriptions.\n",
"\n",
"## Autoencoder\n",
"\n",
"To demonstrate the use of convolution transpose operations,\n",
"we will build something called an **autoencoder**.\n",
"\n",
"An autoencoder is a network that learns an alternate\n",
"*representations* of some data, for example a set of images.\n",
"It contains two components:\n",
"\n",
"- An **encoder** that takes an image as input, and \n",
" outputs a low-dimensional embedding (representation)\n",
" of the image.\n",
"- A **decoder** that takes the low-dimensional embedding,\n",
" and reconstructs the image.\n",
"\n",
"Beyond *dimension reduction*, an autoencoder is\n",
"a **generative model**. It can **generate** new images\n",
"not in the training set!\n",
"\n",
"An autoencoder is typically shown like below:\n",
"(image from https://hackernoon.com/how-to-autoencode-your-pok%C3%A9mon-6b0f5c7b7d97 )\n",
"\n",
"\n",
"\n",
"Here is an example of a *convolutional* autoencoder:\n",
"an autoencoder that uses solely convolutional layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Autoencoder(nn.Module):\n",
" def __init__(self):\n",
" super(Autoencoder, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(1, 16, 3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(16, 32, 3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 64, 7)\n",
" )\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(64, 32, 7),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),\n",
" nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.encoder(x)\n",
" x = self.decoder(x)\n",
" return x"
]
},
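{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a cell we added, assuming `28x28` MNIST inputs), we can\n",
"trace the shapes through the network: the encoder maps a `1x28x28` image down to\n",
"a `64x1x1` embedding, and the decoder maps that embedding back up to `1x28x28`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"check_model = Autoencoder()\n",
"x = torch.randn(1, 1, 28, 28)        # a batch containing one MNIST-sized image\n",
"e = check_model.encoder(x)\n",
"print(e.shape)                       # torch.Size([1, 64, 1, 1])\n",
"print(check_model.decoder(e).shape)  # torch.Size([1, 1, 28, 28])"
]
},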
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that the final activation on the decoder is a *sigmoid*\n",
"activation. The reason is that we normalize the pixels\n",
"to be between 0 and 1. A sigmoid gives us results in the \n",
"same range.\n",
"\n",
"## Training an Autoencoder\n",
"\n",
"How do we train an autoencoder? How do we know what\n",
"kind of \"encoder\" and \"decoder\" we want?\n",
"\n",
"For one, if we pass an image through the encoder,\n",
"then pass the result through the decoder, we should get\n",
"roughly the same image back. Ideally, reducing the \n",
"dimensionality and then generating the image should\n",
"give us the same result!\n",
"\n",
"We use a loss function called `MSELoss`, which\n",
"computes the square error at every pixel.\n",
"\n",
"Beyond using a different loss function, the training \n",
"scheme is roughly the same. Note that in the code below,\n",
"we are using a new optimizer called `Adam`.\n",
"\n",
"We switched to this optimizer not because it is specifically\n",
"used for autoencoders, but because this is the optimizer that\n",
"people tend to use in practice for convolutional neural\n",
"networks. Feel free to use Adam for your other convolutional\n",
"networks.\n",
"\n",
"We are also saving the reconstructed images of the last\n",
"iteration in every epoch. We want to look at these reconstructions\n",
"at the end of training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(model, num_epochs=5, batch_size=64, learning_rate=1e-3):\n",
" torch.manual_seed(42)\n",
" criterion = nn.MSELoss()\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)\n",
" train_loader = torch.utils.data.DataLoader(mnist_data, batch_size=batch_size, shuffle=True)\n",
" outputs = []\n",
" for epoch in range(num_epochs):\n",
" for data in train_loader:\n",
" img, label = data\n",
" recon = model(img)\n",
" loss = criterion(recon, img)\n",
" loss.backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" print('Epoch:{}, Loss:{:.4f}'.format(epoch+1, float(loss)))\n",
" outputs.append((epoch, img, recon),)\n",
" return outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can train this network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = Autoencoder()\n",
"max_epochs = 20\n",
"outputs = train(model, num_epochs=max_epochs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss goes down as we train, meaning that our reconstructed images look more\n",
"and more like the actual images!\n",
"\n",
"Let's look at the training progression: that is, the reconstructed images at\n",
"various points of training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for k in range(0, max_epochs, 5):\n",
" plt.figure(figsize=(9, 2))\n",
" imgs = outputs[k][1].detach().numpy()\n",
" recon = outputs[k][2].detach().numpy()\n",
" for i, item in enumerate(imgs):\n",
" if i >= 9: break\n",
" plt.subplot(2, 9, i+1)\n",
" plt.imshow(item[0])\n",
" \n",
" for i, item in enumerate(recon):\n",
" if i >= 9: break\n",
" plt.subplot(2, 9, 9+i+1)\n",
" plt.imshow(item[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At first, the reconstructed images look nothing like the originals. Rather, the\n",
"reconstructions look more like the average of some training images.\n",
"As training progresses, our reconstructions are clearer.\n",
"\n",
"## Structure in the Embeddings\n",
"\n",
"Since we are drastically reducing the dimensionality of the image, there has to be\n",
"some kind of structure in the embedding space. That is, the network should be able\n",
"to \"save\" space by mapping similar images to similar embeddings.\n",
"\n",
"We will demonstrate the structure of the embedding space by hving\n",
"some fun with our autoencoders. Let's begin with two images in our training set.\n",
"For now, we'll choose images of the same digit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"imgs = outputs[max_epochs-1][1].detach().numpy()\n",
"plt.subplot(1, 2, 1)\n",
"plt.imshow(imgs[0][0])\n",
"plt.subplot(1, 2, 2)\n",
"plt.imshow(imgs[8][0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will then compute the **low-dimensional embeddings** of both images,\n",
"by applying the **encoder**:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x1 = outputs[max_epochs-1][1][0,:,:,:] # first image\n",
"x2 = outputs[max_epochs-1][1][8,:,:,:] # second image\n",
"x = torch.stack([x1,x2]) # stack them together so we only call `encoder` once\n",
"embedding = model.encoder(x)\n",
"e1 = embedding[0] # embedding of first image\n",
"e2 = embedding[1] # embedding of second image"
]
},
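{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note how small these embeddings are (this shape check is our addition): each\n",
"`28x28` image, with 784 pixels, has been compressed down to just 64 numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"e1.shape  # torch.Size([64, 1, 1]): 64 values per image, down from 784 pixels"
]
},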
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we will do something interesting. Not only are we goign to run the\n",
"decoder on those two embeddings `e1` and `e2`, we are also going to **interpolate**\n",
"between the two embeddings and decode those as well!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embedding_values = []\n",
"for i in range(0, 10):\n",
" e = e1 * (i/10) + e2 * (10-i)/10\n",
" embedding_values.append(e)\n",
"embedding_values = torch.stack(embedding_values)\n",
"\n",
"recons = model.decoder(embedding_values)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot the reconstructions of each interpolated values.\n",
"The original images are shown below too:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 2))\n",
"for i, recon in enumerate(recons.detach().numpy()):\n",
" plt.subplot(2,10,i+1)\n",
" plt.imshow(recon[0])\n",
"plt.subplot(2,10,11)\n",
"plt.imshow(imgs[8][0])\n",
"plt.subplot(2,10,20)\n",
"plt.imshow(imgs[0][0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that there is a smooth transition between the two images!\n",
"The middle images are likely new, in that there are no training images\n",
"that are exactly like any of the generated images.\n",
"\n",
"As promised, we can do the same thing with two images containing\n",
"different digits. There should be a smooth transition between\n",
"the two digits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def interpolate(index1, index2):\n",
" x1 = mnist_data[index1][0]\n",
" x2 = mnist_data[index2][0]\n",
" x = torch.stack([x1,x2])\n",
" embedding = model.encoder(x)\n",
" e1 = embedding[0] # embedding of first image\n",
" e2 = embedding[1] # embedding of second image\n",
"\n",
"\n",
" embedding_values = []\n",
" for i in range(0, 10):\n",
" e = e1 * (i/10) + e2 * (10-i)/10\n",
" embedding_values.append(e)\n",
" embedding_values = torch.stack(embedding_values)\n",
"\n",
" recons = model.decoder(embedding_values)\n",
"\n",
" plt.figure(figsize=(10, 2))\n",
" for i, recon in enumerate(recons.detach().numpy()):\n",
" plt.subplot(2,10,i+1)\n",
" plt.imshow(recon[0])\n",
" plt.subplot(2,10,11)\n",
" plt.imshow(x2[0])\n",
" plt.subplot(2,10,20)\n",
" plt.imshow(x1[0])\n",
"\n",
"interpolate(0, 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interpolate(1, 10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interpolate(4, 5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interpolate(20, 30)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}