% [P,Pi,llh,code,it,grad_Pi,grad_P] = mix_bernoulli(X,M[,tol,max_it,P0,Pi0])
%
% EM-estimation of mixture of Bernoulli distributions
%
% In:
%   X: NxD binary data matrix.
%   M: number of mixture components.
%   tol: minimum relative increase in log-likelihood to keep iterating
%      (default 0.0001 = 4 decimal places).
%   max_it: maximum number of iterations (default 400).
%   P0: MxD matrix containing the initial Bernoulli probabilities.
%   Pi0: Mx1 vector containing the initial mixing parameters.
% Out:
%   P: MxD matrix containing the Bernoulli probabilities.
%   Pi: Mx1 vector containing the mixing parameters.
%   llh: log-likelihood curve of the parameters P and Pi given the
%      data as a function of the iteration number.
%   code: stopping reason (0: tolerance achieved; 1: maximum number of
%      iterations reached; 3: log-likelihood decreases during learning,
%      probably due to precision loss near maximum).
%   it: iterations performed.
%   grad_Pi: gradient of the log-likelihood wrt the Pi parameters.
%   grad_P: gradient of the log-likelihood wrt the P parameters.
%
% See also mix_bernoulli_llh_error, mix_bernoulli_sample,
% mix_bernoulli_distrib.

% Copyright (c) 1997 by Miguel A. Carreira-Perpinan

function [P,Pi,llh,code,it,grad_Pi,grad_P] = mix_bernoulli(X,M,tol,max_it,P0,Pi0)
% EM estimation of a mixture of M multivariate Bernoulli distributions
% from the rows of the binary matrix X.
%
% Optional arguments (tol, max_it, P0, Pi0) may be omitted or passed as
% [] to request their default value.
%
% Stopping codes: 0 = relative change in log-likelihood below tol,
% 1 = max_it iterations reached, 3 = log-likelihood decreased.

[N,D] = size(X);

% Single component case (M=1) has an analytic solution: the sample mean.
if M==1
  Pi = 1;
  % Explicit dimension: mean(X) on a single row (N=1) would average
  % along the row and collapse P to a scalar instead of 1xD.
  P = mean(X,1);
  % log(P.^P) is used instead of P.*log(P) so that P=0 contributes
  % log(0^0)=log(1)=0 rather than 0*(-Inf)=NaN.
  llh = N*sum(log(P.^P)+log((1-P).^(1-P)));
  code = 0;
  it = 0;
  grad_Pi = 0;
  grad_P = 0;
  return;
end

% Argument defaults (an empty argument also selects the default)
if nargin<3 || isempty(tol), tol = 0.0001; end
if nargin<4 || isempty(max_it), max_it = 400; end

% Starting point
if nargin<5 || isempty(P0)   % Bernoulli probabilities P
  % NOTE(review): legacy seeding syntax, kept so the global generator is
  % reseeded from the clock exactly as before; recent MATLAB releases
  % discourage it.
  rand('seed',sum(100*clock));
  % Random values restricted to [0.25,0.75]:
  %  - P = rand(M,D) could (very unlikely) hit a point of minus infinity
  %    log-likelihood for the given sample.
  %  - P = ones(M,D)/2 converges to a mixture of one Bernoulli with
  %    p = E{X} (the sample mean), mixing proportions equal to the
  %    initial ones and log-likelihood N*log p(E{X}|p=E{X}), which is a
  %    suboptimal maximum.
  P = rand(M,D)*0.5+0.25;
else
  P = P0;
end
if nargin<6 || isempty(Pi0)  % Mixing proportions Pi
  Pi = ones(M,1)/M;
else
  Pi = Pi0(:);               % accept a row vector as well as Mx1
end

% With max_it<1 no EM iteration is run: only the log-likelihood of the
% starting point is returned.
if max_it<1
  code = 1;
  it = 0;
else
  code = -1;                 % negative code means "keep iterating"
  it = 1;
end

% B(n,i) = p(x_n | i), the likelihood of the nth data point under the
% ith mixture component.
% Note that B(n,i) is of the order 0.5^D which can be very small for
% large D, perhaps leading to cancellation.
B = mix_bernoulli_component_lik(X,P);

% Log-likelihood of starting point
% llh = log(prod(B*Pi)) is faster but leads to llh=-Inf due to precision loss
llh = sum(log(B*Pi));

while code<0
  llh_old = llh(end);

  % E step: posterior probabilities or "responsibilities" R(n,i)=p(i|x_n).
  % cprod/cdiv are helpers defined outside this file; presumably they
  % multiply/divide each column of a matrix elementwise by a column
  % vector -- TODO confirm.
  % sum(R,2) (rather than sum(R')') keeps the row sums Nx1 for any N, M.
  R = cprod(B',Pi)'; R = cdiv(R,sum(R,2));

  % M step: new parameter values.
  % sum(R,1) (rather than sum(R)) keeps Pi Mx1 even when N=1.
  Pi = sum(R,1)'/N;
  P = cdiv(R'*X/N,Pi);

  % New log-likelihood
  B = mix_bernoulli_component_lik(X,P);
  llh(end+1) = sum(log(B*Pi));

  % Check whether exit condition is met
  if llh(end)<llh_old
    code = 3;                % Log-likelihood not monotonic
  elseif abs(llh(end)-llh_old)<tol*abs(llh(end))
    code = 0;                % Relative error < tol => Tolerance achieved
  elseif it>=max_it
    code = 1;                % Max. no. iterations reached
  else
    it = it + 1;             % Continue iterating
  end
end

% Gradient of the log-likelihood. For random Pi and P, the gradient
% wrt Pi is of order N and the gradient wrt P is of order N/M. Thus,
% near a stationary point of the log-likelihood norm(grad_Pi)/N and
% norm(grad_P)*M/N must be much smaller than 1.
% Note that the computation of the gradient wrt P involves dividing by
% P_id or 1-P_id, which is ill-conditioned if any P_id =~ 0 or 1.
% To avoid this, those terms are disregarded, since the corresponding
% numerator should vanish as well (although in some cases it won't due
% to precision loss).
% The gradients are only computed when the caller requests them.
if nargout > 5
  R = cprod(B',Pi)'; R = cdiv(R,sum(R,2));
  Rsum = sum(R,1)';          % Mx1: total responsibility of each component
  grad_Pi = Rsum./Pi - N;
  if nargout > 6
    num = R'*X - P.*(Rsum*ones(1,D));
    den = P.*(1-P);
    % Entries with a (nearly) vanishing denominator get gradient 0.
    tiny = den<1e-4;
    den(tiny) = 1;
    num(tiny) = 0;
    grad_P = num./den;
  end
end


function B = mix_bernoulli_component_lik(X,P)
% B(n,i) = prod_d P(i,d)^X(n,d) * (1-P(i,d))^(1-X(n,d)), i.e. the
% likelihood of row n of X under the Bernoulli parameters in row i of P.
% Powers are used (not logs) so that 0^0 evaluates to 1.

N = size(X,1);
M = size(P,1);
B = zeros(N,M);
for n=1:N
  for i=1:M
    B(n,i) = prod(P(i,:).^X(n,:)) * prod((1-P(i,:)).^(1-X(n,:)));
  end
end
