Maximum Likelihood for Bernoulli with PyTorch

Let's say that we have 100 samples from a Bernoulli distribution:

In [1]:
import torch
import numpy as np

from torch.autograd import Variable

sample = np.array([ 1.,  1.,  0.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  0.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  0.,  1.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  1.,
        1.,  0.,  1.,  0.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  0.,  0.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,
        0.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,  1.,  1.,  1.,  0.,  1.,
        0.,  1.,  1.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  1.,  1.,  0.,
        0.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,
        0.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  0.,  0.,  1.,  1.,  1.,
        0.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  0.,  0.,  1.,  1.,  1.,
        1.,  1.,  0.,  0.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  0.,  1.,
        1.,  1.,  1.,  0.,  1.,  0.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,
        0.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  0.,  1.,
        1.,  1.,  1.,  0.,  1.,  1.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,
        1.,  0.,  1.,  0.,  1.])
In [2]:
np.mean(sample)
Out[2]:
0.72499999999999998

Let's now define the probability p of generating 1, and put the sample into a PyTorch Variable:

In [3]:
x = Variable(torch.from_numpy(sample)).type(torch.FloatTensor)
p = Variable(torch.rand(1), requires_grad=True)

We are ready to learn the model using maximum likelihood:

In [4]:
learning_rate = 0.00002
for t in range(1000):
    NLL = -torch.sum(torch.log(x*p + (1-x)*(1-p)) )
    NLL.backward()

    if t % 100 == 0:
        print("loglik  =", NLL.data.numpy(), "p =", p.data.numpy(), "dL/dp = ", p.grad.data.numpy())

    
    p.data -= learning_rate * p.grad.data
    p.grad.data.zero_()
loglik  = [ 190.32196045] p = [ 0.30976322] dL/dp =  [-388.41665649]
loglik  = [ 120.36953735] p = [ 0.64752221] dL/dp =  [-67.89233398]
loglik  = [ 117.7012558] p = [ 0.71330011] dL/dp =  [-11.442276]
loglik  = [ 117.63499451] p = [ 0.72342253] dL/dp =  [-1.576828]
loglik  = [ 117.63378143] p = [ 0.72479147] dL/dp =  [-0.20907593]
loglik  = [ 117.63375854] p = [ 0.72497255] dL/dp =  [-0.02752686]
loglik  = [ 117.63375092] p = [ 0.72499627] dL/dp =  [-0.0037384]
loglik  = [ 117.63375854] p = [ 0.72499853] dL/dp =  [-0.00146484]
loglik  = [ 117.63375854] p = [ 0.72499853] dL/dp =  [-0.00146484]
loglik  = [ 117.63375854] p = [ 0.72499853] dL/dp =  [-0.00146484]