function [mu, sigma, g, m, a, b, cost] = mcvq_mpi(X, K, J, iterations, varargin)
%[mu, sigma, g, m, a, b, cost] = mcvq_mpi(X, K, J, iterations, ...)
%
% Train an MCVQ model on data set X, with the specified parameters,
% distributing the training cases over MPI workers.  Each worker keeps
% every comm_size-th column of X (starting at column my_rank+1).  Partial
% statistics are summed onto rank SOURCE (= 0) with MPI_Reduce, SOURCE
% updates the shared parameters, and MPI_Bcast sends them back to all
% workers.
%
% Note that X can have missing entries, which are handled by simply
% not including the missing values in the EM updates.  Missing entries
% are indicated by making X a sparse matrix, and setting the missing
% values to 0.  (So in a sparse X an observed value of exactly 0 is also
% treated as missing.)
%
% Required Arguments:
%   X          = data set, data vectors in columns
%   K          = number of VQ's
%   J          = number of vectors per VQ
%   iterations = number of EM iterations to perform
%
% Optional Arguments:
%   'tie r'        = tie the selection of VQ's across all data vectors
%                    (also sets 'tie g')
%   'tie g'        = tie the (approx.) posterior probabilities of VQ's
%   'adapt priors' = use non-uniform priors, that are updated
%                    each M-step, for the g's and m's
%   'hyper priors' = use Dirichlet hyper-priors, to encourage
%                    the a's and b's to have low and high entropy
%                    respectively (also sets 'adapt priors')
%
% Return Values:
%   mu    = J x K x N array of means of the VQ vectors
%   sigma = J x K x N array of standard deviations of the VQ vectors
%           (each at least sigma_min)
%   g     = approx. posterior prob. of each VQ: N x K when g is tied,
%           otherwise N x K x C for THIS worker's C training cases
%   m     = approx. posterior prob. of each vector of each VQ,
%           J x K x C for THIS worker's training cases only
%           (also saved to m<rank>.mat)
%   a     = N x K prior probabilities of VQ's
%   b     = J x K prior probabilities for each vector of each VQ
%   cost  = value of the cost function (free energy) after each
%           iteration of EM; the reduced value is only guaranteed to be
%           correct on rank SOURCE
%
% NOTE(review): MPI_Init / MPI_Comm_* / MPI_Bcast / MPI_Finalize are
% MatlabMPI-style calls; MPI_Reduce is a helper taking (source, tag, comm,
% op, arrays...) whose result is only relied upon at SOURCE -- confirm
% against the helper's implementation.
%
% Author: David Ross
% $Id: mcvq_mpi.m,v 1.3 2006/11/09 23:11:09 dross Exp $

%---------------------------------------------------
% Check the input
%---------------------------------------------------
error(nargchk(4,4,nargin - length(varargin)));

if ndims(X) ~= 2
    error('X matrix must be 2 dimensional')
end
if K <= 0 || J <= 0
    error('numbers of VQ''s and components must be positive')
end
if iterations <= 0
    error('number of iterations must be positive')
end

% now check the options
g_tied = 0;
r_tied = 0;
adapt_priors = 0;
hyper_priors = 0;
for i = 1:length(varargin)
    if strcmpi(varargin{i}, 'tie g')
        g_tied = 1;
    elseif strcmpi(varargin{i}, 'tie r')
        r_tied = 1;
        g_tied = 1;
    elseif strcmpi(varargin{i}, 'adapt priors')
        adapt_priors = 1;
    elseif strcmpi(varargin{i}, 'hyper priors')
        hyper_priors = 1;
        adapt_priors = 1;
    else
        disp('unknown argument:');
        disp(varargin{i});
        error('unknown argument');
    end
end

% if cost is not a desired output, then don't calculate it
calculate_cost = (nargout >= 7);

% if the data set, X, is sparse, then the 0 entries are assumed
% to represent missing values (the mask itself is built after X has been
% split between the workers, so that its columns line up with X's)
missing_data = issparse(X);

% display the options that MCVQ is running with
disp('====================');
disp(sprintf('J=%d, K=%d',J,K));
if r_tied; disp('r tied across cases'); end
if g_tied; disp('g tied across cases'); end
if adapt_priors; disp('adapting priors'); end
if hyper_priors; disp('using hyper-priors'); end

%---------------------------------------------------
% Initialization
%---------------------------------------------------

%MPI: Initialize MPI.
MPI_Init;                          % Initialize MPI.
comm = MPI_COMM_WORLD;             % Create communicator.
comm_size = MPI_Comm_size(comm);   % Get size and rank.
my_rank = MPI_Comm_rank(comm);
SOURCE = 0;                        % rank that owns the parameter updates

%MPI: take only a fraction of the data set
C_total = size(X,2);
X = X(:,(1+my_rank):comm_size:end);

% mask of observed entries, column-aligned with this worker's slice of X
if missing_data
    missing_mask = (X ~= 0);
end

N = size(X,1);   % dimensionality of input vectors
C = size(X,2);   % number of training examples held by this worker

% an alternative way to compute C_total
%C_total = MPI_Reduce(SOURCE, intmax, comm, @MPI_SUM, C);
%C_total = MPI_Bcast(SOURCE, intmax-1, comm, C_total);

sigma_min = 0.01;   % the minimum possible value of sigma
disp(sprintf('sigma_min = %g',sigma_min));
disp('====================');

% hyper priors
alpha = 0.5 * ones(K,1);   % low entropy
beta = 6 * ones(J,1);      % high entropy

%MPI: only the source initializes the parameters
if my_rank == SOURCE
    % priors are initialized to samples from the hyper-priors,
    % or set to uniform if hyper-priors are not used
    if hyper_priors
        a = zeros(N,K);
        for i = 1:N
            a(i,:) = dirichletrnd(alpha)';
        end
        b = zeros(J*K,1);
        for k = 1:K
            b((k-1)*J+1:k*J) = dirichletrnd(beta);
        end
    else
        a = ones(N,K) / K;
        b = ones(J*K,1) / J;
    end

    % mu's are randomly chosen from SOURCE's share of the data:
    % selection holds the integers 1..C in random order, and its first
    % J*K entries pick the training examples used as initial mu's.
    % If C is less than J*K, selection first gets ceil(J*K/C) copies of
    % each integer from 1 to C before permuting.
    selection = repmat(1:C, [1 ceil(J*K/C)]);
    selection = selection(randperm(length(selection)));
    mu = full(X(:,selection(1:J*K)));
    clear selection;

    % sigma initialized randomly in the interval [0.5 1]
    sigma = rand(N,J*K) / 2 + 0.5;
    % now make sure sigma is at least sigma_min
    sigma(sigma < sigma_min) = sigma_min;
else
    a = [];
    b = [];
    mu = [];
    sigma = [];
end

% fill in the missing entries in mu, if necessary
if missing_data
    % per-dimension mean over the OBSERVED entries of the whole data set
    % (dividing by C_total would count every missing 0 as an observation)
    [data_sum, data_count] = MPI_Reduce(SOURCE, 0, comm, @MPI_SUM, ...
        full(sum(X,2)), full(sum(missing_mask,2)));
    if my_rank == SOURCE
        data_mean = data_sum ./ max(data_count, 1);
        for jk = 1:(J*K)
            mu(:,jk) = mu(:,jk) + ...
                (data_mean + randn(N,1)*sigma_min) .* (mu(:,jk)==0);
        end
        clear data_mean;
    end
    clear data_sum data_count;
end

%MPI make sure all workers have the params
[a, b, mu, sigma] = MPI_Bcast(SOURCE, 0, comm, a, b, mu, sigma);

% m's are updated first, so we can initialize them to 0
m = zeros(J*K,C);

% g's @TODO these should be sampled from the priors
if g_tied
    g = a;
else
    g = repmat(a,[1 1 C]);
end

% empty cost vector
cost = [];

% (x * kr) sums, for every VQ k, the J columns of x that belong to k
kr = kron(eye(K),ones(J,1));

%---------------------------------------------------
% Main Loop
%---------------------------------------------------
for it = 1:iterations
    tic;

    %-------------------------------------------------------------------
    % E-Step
    %-------------------------------------------------------------------
    g_old = g;
    if g_tied && my_rank ~= SOURCE
        % a tied g is SUM-reduced over all workers below, so only SOURCE
        % seeds it with the log-prior; otherwise log(a) would be counted
        % comm_size times
        g = zeros(N,K);
    else
        g = repmat(log(a), [1 1 C^(~g_tied)]);
    end
    log_b = log(b);
    log_sigma = log(sigma);
    two_sigma_sq = 2 * sigma.^2;
    % a tied (but not r-tied) g averages the per-case terms over all cases
    g_scale = (1/C_total)^(~r_tied & g_tied);

    for c = 1:C
        % d(i,jk) = log(sigma) + (mu - x_c)^2 / (2 sigma^2)
        d = log_sigma + (mu - repmat(X(:,c),[1 J*K])).^2 ./ two_sigma_sq;
        if missing_data
            d = d .* repmat(missing_mask(:,c), [1 J*K]);
        end

        %---- Update m ----%
        % big_G is G (N by K) replicated to make it of size (N by JK).
        % If g's are tied, then there is only one big_G for all c's,
        % otherwise we need to update big_G for each c
        if (~g_tied) || (c == 1)
            big_G = replicate_g(g_old(:,:,c),J);
        end
        m(:,c) = log_b - dot(big_G, d, 1)';
        m_reshaped = reshape(m(:,c), [J K]);
        m_reshaped = softmax(m_reshaped);
        m(:,c) = reshape(m_reshaped, [J*K 1]);

        %---- Update g ----%
        % c_for_g = 1 if g is tied, c if g is untied
        c_for_g = c^(~g_tied);
        % @TODO if this is really the E-step, then we shouldn't
        % recompute d, otherwise if it's the M-step, then this is okay
        g(:,:,c_for_g) = g(:,:,c_for_g) - g_scale * ...
            (repmat(m(:,c)',[N 1]) .* d) * kr;

        if ~mod(c,10), fprintf(1, 'done E-step %d of %d\r', c, C); end
    end

    % ('tie r' also sets g_tied, so this covers both tied cases)
    if g_tied
        g = MPI_Reduce(SOURCE, 1e5+it, comm, @MPI_SUM, g);
    end

    % What we want to do is "g = softmax(g')'", but we can't because
    % we have a third index.  The following code does the same thing.
    g = softmax(permute(g,[2 1 3]));
    g = ipermute(g,[2 1 3]);

    if g_tied
        g = MPI_Bcast(SOURCE, 2e5+it, comm, g);
    end

    %-------------------------------------------------------------------
    % M-Step
    %-------------------------------------------------------------------
    if adapt_priors
        %MPI: we'll need sum(g,3) from all workers for a and sum(m,2) for b
        [g_sum, m_sum] = MPI_Reduce(SOURCE, 3e5+it, comm, @MPI_SUM, ...
            sum(g,3), sum(m,2));
        if g_tied
            % a tied g is identical on every worker (it was just
            % broadcast), so its reduced sum would be comm_size copies of it
            g_sum = g;
        end

        %---- Update a ----%
        % a = mean(g,3); is basically what we're doing, unless
        % hyperpriors are used, which makes things messy
        a = C_total^(~r_tied & g_tied) * g_sum + ...
            hyper_priors * repmat(alpha'-1, [N 1]);
        a = a / (C_total^(~r_tied) + hyper_priors * (sum(alpha) - K));

        % if any a's are going to zero, add a little bit and renormalize
        a_floor = 1/(10*K);
        if any(a(:) < a_floor)
            a = a + (a < a_floor) * a_floor;
            a = a ./ repmat(sum(a,2), [1 K]);
        end

        %---- Update b ----%
        % b = mean(m,2); again this is basically it...
        b = m_sum + hyper_priors * repmat(beta-1,[K 1]);
        b = b / (C_total + hyper_priors * (sum(beta) - J));

        % if any b's are going to zero, add a little bit and renormalize
        % each VQ's block of J entries
        b_floor = 1/(10*J);
        if any(b < b_floor)
            b = b + (b < b_floor) * b_floor;
            b = b ./ (kron(eye(K), ones(J,J)) * b);
        end

        % make sure all workers get a and b
        [a, b] = MPI_Bcast(SOURCE, 4e5+it, comm, a, b);
    end

    %---- Update mu and sigma ----%
    mu_numer = zeros(N, J*K);
    sigma_numer = zeros(N, J*K);
    denom = zeros(N, J*K);
    % when g is tied it is the same for every case, so it cancels between
    % numerator and denominator and a multiplier of 1 can be used
    big_g = 1;
    for c = 1:C
        big_X = repmat(X(:,c), [1 J*K]);
        big_m = repmat(m(:,c)', [N 1]);
        if missing_data
            big_m = big_m .* repmat(missing_mask(:,c), [1 J*K]);
        end
        if ~g_tied
            big_g = replicate_g(g(:,:,c),J);
        end
        weight = big_g .* big_m;

        % accumulate mu
        mu_numer = mu_numer + weight .* big_X;
        % accumulate sigma (around the PREVIOUS iteration's mu)
        sigma_numer = sigma_numer + weight .* (big_X - mu) .^ 2;
        % accumulate denominator
        denom = denom + weight;

        if ~mod(c,10), fprintf(1, 'done M-step %d of %d\r', c, C); end
    end

    [mu_numer, sigma_numer, denom] = MPI_Reduce(SOURCE, 5e5+it, comm, ...
        @MPI_SUM, mu_numer, sigma_numer, denom);

    % replace 0's in the denominator with 1's, which is okay since the
    % corresponding numerator terms should also be 0's
    denom(denom == 0) = 1;
    mu = mu_numer ./ denom;
    sigma = sqrt(sigma_numer ./ denom);
    % now make sure sigma is at least sigma_min
    sigma(sigma < sigma_min) = sigma_min;

    % make sure all the children know mu and sigma
    [mu, sigma] = MPI_Bcast(SOURCE, 6e5+it, comm, mu, sigma);

    %-------------------------------------------------------------------
    % Calculate Free Energy
    %-------------------------------------------------------------------
    if calculate_cost
        % NOTE: the data term is only computed for every SKIP-th case,
        % to save time, and rescaled to all C cases
        SKIP = 100;

        % Weight of the g-dependent terms on this worker.  A tied (not
        % r-tied) g stands in for each of the C local cases; an r-tied g
        % is a single global quantity, so only SOURCE counts it (F is
        % SUM-reduced over the workers at the end).
        if r_tied
            g_weight = (my_rank == SOURCE);
        elseif g_tied
            g_weight = C;
        else
            g_weight = 1;
        end

        F = g_weight * sum(sum(sum(g .* log(g + (g==0))))) ...
            + sum(sum(sum(m .* log(m + (m==0)))));

        % d_ijk^c term
        log_sigma = log(sigma);
        one_over_2_sigma_sq = 1./(2* sigma.^2);
        cost_cases = 1:SKIP:C;
        cost_scale = C / max(length(cost_cases), 1);
        for c = cost_cases
            % NOTE(review): faster_sparse is assumed to return
            % (mu - x_c).^2 .* one_over_2_sigma_sq -- confirm
            d = log_sigma + faster_sparse(mu, X(:,c), one_over_2_sigma_sq);
            if missing_data
                d = d .* repmat(missing_mask(:,c),[1 J*K]);
            end
            % replace d with sum_j (m * d)
            d = (repmat(m(:,c)',[N 1]) .* d) * kr;
            F = F + sum(sum(g(:,:,c^(~g_tied)) .* d)) * cost_scale;
        end

        % if we're adapting the priors, include these terms
        % @TODO - we could probably include these even if we aren't
        % adapting the priors
        if adapt_priors
            F = F - g_weight * sum(sum(sum(g,3) .* log(a))) ...
                - sum(sum(m,2) .* log(b));
        end

        % Dirichlet hyper-prior terms, log p(a) + log p(b) up to constants:
        % sum_i sum_k (alpha_k-1) log a_ik + sum_k sum_j (beta_j-1) log b_jk.
        % They are global, so only SOURCE adds them.
        if hyper_priors && my_rank == SOURCE
            F = F - sum((alpha-1) .* sum(log(a),1)') ...
                - sum((beta-1) .* sum(log(reshape(b, [J K])),2));
        end

        F = MPI_Reduce(SOURCE, 7e5+it, comm, @MPI_SUM, F);
        %MPI: warning only SOURCE knows what F is now.
        cost = [cost F];
        if my_rank == SOURCE
            fprintf(1,'iteration %d, cost=%g \n', it, F);
        end
    else
        if my_rank == SOURCE
            fprintf(1,'iteration %d \n',it);
        end
    end

    %-------------------------------------------------------------------
    % backup
    %-------------------------------------------------------------------
    if my_rank == SOURCE
        save temporary;
    end

    % Print out the expected time to completion
    elapsed_time = toc;
    time_to_finish = (iterations - it) * elapsed_time;
    fprintf(1, 'should be done learning at %s\n', ...
        datestr(now+time_to_finish/(60*60*24)));
end

%---- Reshape the matrices ----%
% so that j and k *aren't* collapsed into a single matrix dimension
m = reshape(m, [J K C]);
save(['m' num2str(my_rank) '.mat'], 'm');
mu = permute(reshape(mu, [N J K]), [2 3 1]);
sigma = permute(reshape(sigma, [N J K]), [2 3 1]);
b = reshape(b, [J K]);

MPI_Finalize;   % Finalize Matlab MPI.
disp('SUCCESS');

%-----------------------------------------------------------------------
%-----------------------------------------------------------------------
function result = replicate_g(G,J)
% REPLICATE_G return G with each of its K columns repeated J times, so
% that column (k-1)*J+j of the result is G(:,k) (J*K columns in total)
K = size(G,2);
result = G(:,kron(1:K,ones(1,J)));

%-----------------------------------------------------------------------
%-----------------------------------------------------------------------
function result = getIndex(j,k,J)
% GETINDEX for a vector of length J*K, this function returns the
% index corresponding to element j,k (k-major blocks of length J)
result = (k-1)*J + j;