In [1]:
%matplotlib inline
    
from numpy import *
from matplotlib.pyplot import *
import scipy.stats



def p_star(w, x, y, sigma, sigma_w):
    loglik = sum(-.5*log(2*pi*sigma**2)-(dot(w, x)-y)**2/(2*sigma**2))
    log_pw = sum(-.5*log(2*pi*sigma_w**2)-w**2/(2*sigma_w**2))
    return exp(loglik+log_pw)


random.seed(0)
w = array([-3, 1.5])


sigma_w = 6

N = 20
sigma = 2.2
x_raw = 10*(random.random((N))-.5)
x = vstack((    ones_like(x_raw),
                    x_raw,
                 ))
                
y = dot(w, x) + scipy.stats.norm.rvs(scale= sigma,size=N)
#scatter(x_raw,y)


sigma_jump = 1.0


n_iter = 2000
n_iter_burn_in = 1000
ws = zeros((n_iter, w.shape[0]))
w = random.random(w.shape)
for i in range(n_iter):
    figure(1)
    w_prime = w + scipy.stats.norm.rvs(scale= sigma_jump,size=w.shape[0])
    
    p_accept = p_star(w_prime, x, y, sigma, sigma_w)/p_star(w, x, y, sigma, sigma_w)
    
    if p_accept > 1:
        w = w_prime
    else:
        if random.random() < p_accept:
            w = w_prime
    
    ws[i,:] = w
    
    scatter(ws[:i-2,0], ws[:i-2,1], c = "b")
    scatter(ws[i-1,0], ws[i-1,1], c = "y")
    scatter(w_prime[0], w_prime[1], c = "r")
    
    xlim(-6, 6)
    ylim(-6, 6)
    title("p_accept = %f" % (p_accept))
    savefig("%.6d.png" % (i))
    close('all')
/usr/lib/python2.7/dist-packages/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.
  warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')
In [2]:
scatter(ws[:,0], ws[:,1], c = "b")
n_burn_in = 1000
scatter(ws[n_burn_in:2000,0], ws[n_burn_in:2000,1], c = "y") 
Out[2]:
<matplotlib.collections.PathCollection at 0x7f482e690f10>
In [3]:
hist(ws[logical_and(ws[:, 1]> 2, ws[:,1]<2.2),0])    
Out[3]:
(array([  5.,   0.,   0.,  16.,  33.,   4.,   0.,   1.,   1.,  10.]),
 array([-3.9468499 , -3.8067327 , -3.66661551, -3.52649831, -3.38638112,
        -3.24626392, -3.10614672, -2.96602953, -2.82591233, -2.68579513,
        -2.54567794]),
 <a list of 10 Patch objects>)
In [4]:
pred_y = mean(dot(ws, x), 0)
scatter(x_raw, pred_y)    
Out[4]:
<matplotlib.collections.PathCollection at 0x7f482e546310>