In [1]:
from torch.autograd import Variable
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat

%matplotlib inline

M = loadmat("mnist_all.mat")

def get_digit_matrix(M, key_prefix):
    # Stack the ten digit classes (key_prefix + "0" ... key_prefix + "9")
    # into one design matrix, scaled to [0, 1], with one-hot label rows.
    batch_xs = np.zeros((0, 28*28))
    batch_y_s = np.zeros((0, 10))

    keys = [key_prefix + str(i) for i in range(10)]
    for k in range(10):
        batch_xs = np.vstack((batch_xs, np.array(M[keys[k]]) / 255.))
        one_hot = np.zeros(10)
        one_hot[k] = 1
        batch_y_s = np.vstack((batch_y_s, np.tile(one_hot, (len(M[keys[k]]), 1))))
    return batch_xs, batch_y_s


def get_train(M):
    return get_digit_matrix(M, "train")


def get_test(M):
    return get_digit_matrix(M, "test")
        

train_x, train_y = get_train(M)
test_x, test_y = get_test(M)

dim_x = 28*28
dim_h = 20
dim_out = 10

dtype_float = torch.FloatTensor
dtype_long = torch.LongTensor



################################################################################
# Subsample the training set for faster training

train_idx = np.random.permutation(range(train_x.shape[0]))[:1000]
x = Variable(torch.from_numpy(train_x[train_idx]), requires_grad=False).type(dtype_float)
y_classes = Variable(torch.from_numpy(np.argmax(train_y[train_idx], 1)), requires_grad=False).type(dtype_long)
#################################################################################

x and y_classes are PyTorch Variables. We will now define the neural network model, using torch.nn.Sequential

In [2]:
model = torch.nn.Sequential(
    torch.nn.Linear(dim_x, dim_h),
    torch.nn.ReLU(),
    torch.nn.Linear(dim_h, dim_out),
)

model can now be applied directly to Variable inputs to run the network.
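For a quick check that the shapes line up (using the 1000-example subsample defined above), we can run a forward pass:

In [ ]:
y_out = model(x)            # forward pass: one row of logits per example
print(y_out.data.size())    # torch.Size([1000, 10])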

We'll define a loss function using torch.nn.CrossEntropyLoss(), which computes the cross entropy from the logits (i.e., the inputs to the softmax, not the class probabilities themselves) and the integer class labels. It combines a log-softmax with a negative log-likelihood loss, so the result is the average negative log-probability of the correct answer.

In [3]:
loss_fn = torch.nn.CrossEntropyLoss()
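As a sanity check on that description, here is a minimal sketch (assuming the Variable-era PyTorch API used throughout this notebook) that reproduces the same quantity by hand from the log-softmax of the logits:

In [ ]:
import torch.nn.functional as F

logits = model(x)                            # raw class scores (pre-softmax)
log_probs = F.log_softmax(logits, dim=1)     # log class probabilities
# Average negative log-probability of the correct class -- this should
# match loss_fn(logits, y_classes).
nll = -log_probs.gather(1, y_classes.view(-1, 1)).mean()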

Training the model

We will now use Adam, a variant of gradient descent that adapts the step size for each parameter, to optimize the model.

In [4]:
learning_rate = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(10000):
    y_pred = model(x)
    loss = loss_fn(y_pred, y_classes)
    
    model.zero_grad()  # Zero out the previous gradient computation
    loss.backward()    # Compute the gradient
    optimizer.step()   # Use the gradient information to take a step
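
The loop above runs silently; to watch the loss decrease, a variant with periodic printing (the same computation, just instrumented) would look like this:

In [ ]:
for t in range(10000):
    y_pred = model(x)
    loss = loss_fn(y_pred, y_classes)

    model.zero_grad()
    loss.backward()
    optimizer.step()

    if t % 1000 == 0:
        print(t, loss.data[0])   # loss is a Variable; .data[0] extracts the float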

Now, let's make predictions for the test set:

In [5]:
x = Variable(torch.from_numpy(test_x), requires_grad=False).type(dtype_float)
In [6]:
y_pred = model(x).data.numpy()

Let's now look at the performance:

In [7]:
np.mean(np.argmax(y_pred, 1) == np.argmax(test_y, 1))
Out[7]:
0.871
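The overall accuracy hides how the errors are distributed across digits; a short sketch (an illustration, not part of the original run) computes the accuracy for each class separately:

In [ ]:
pred = np.argmax(y_pred, 1)
true = np.argmax(test_y, 1)
for digit in range(10):
    idx = (true == digit)                    # test examples whose label is this digit
    print(digit, np.mean(pred[idx] == true[idx]))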

Let's explore the weights a little bit. model[0] is the first Linear layer of the Sequential container, and we can access its weights like this:

In [8]:
model[0].weight
Out[8]:
Parameter containing:
-2.9768e-02 -6.8537e-03  7.4565e-03  ...  -1.6135e-02 -3.2539e-02 -4.8748e-03
 1.0848e-02 -1.5763e-03 -9.5682e-03  ...   3.1350e-02  2.1901e-02 -2.8717e-02
-1.1450e-02  5.4038e-03 -7.4841e-03  ...   1.8107e-02 -1.0602e-02  2.8305e-02
                ...                   ⋱                   ...                
-2.9840e-02  1.2695e-02  2.2651e-02  ...   3.2983e-02 -2.7947e-02  7.9728e-03
 2.1712e-02  4.0420e-03 -3.1576e-02  ...   2.1061e-02 -3.3500e-02  2.9000e-02
 2.7542e-02 -1.5136e-04 -4.1566e-03  ...  -3.2553e-02  1.8849e-02  1.5562e-02
[torch.FloatTensor of size 20x784]

Let's look at the incoming weights of hidden unit 10:

In [9]:
model[0].weight.data.numpy()[10, :].shape
Out[9]:
(784,)
In [10]:
plt.imshow(model[0].weight.data.numpy()[10, :].reshape((28, 28)), cmap=plt.cm.coolwarm)
Out[10]:
<matplotlib.image.AxesImage at 0x7fa9d7033ba8>
In [11]:
plt.imshow(model[0].weight.data.numpy()[12, :].reshape((28, 28)), cmap=plt.cm.coolwarm)
Out[11]:
<matplotlib.image.AxesImage at 0x7fa9d6f54c88>
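
With only 20 hidden units, it is easy to display all of the incoming weight vectors at once; a sketch follows (the 4x5 grid layout is an arbitrary choice):

In [ ]:
W = model[0].weight.data.numpy()             # shape (20, 784): one row per hidden unit
fig, axes = plt.subplots(4, 5, figsize=(10, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(W[i].reshape((28, 28)), cmap=plt.cm.coolwarm)
    ax.set_title("unit %d" % i)
    ax.axis("off")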