% Train an RBM with one-step contrastive divergence (CD-1), momentum and
% L2 weight decay on the weights.
%
% This is a SCRIPT: it reads and writes variables in the caller's workspace.
%
% Workspace inputs (must exist before running):
%   batchdata  - numcases x numdims x numbatches array; each page
%                batchdata(:,:,b) is one mini-batch with one case per row.
%                NOTE(review): values are compared against logistic
%                reconstructions, so presumably they lie in [0,1] -- confirm
%                against the data-preparation code.
%   numhid     - number of hidden units.
%   maxepoch   - last epoch to train.
%   restart    - 1 to (re)initialize all parameters and start at epoch 1
%                (restart is reset to 0 afterwards); any other value resumes
%                from the existing vishid, hidbiases, visbiases, *inc and
%                epoch variables.
%   weightcost - (optional) L2 weight-decay coefficient. Normally set outside
%                this script to make experiments easier; defaults to 0.0001
%                when it is not already defined.
%
% Workspace outputs: vishid, hidbiases, visbiases, the momentum increments
% vishidinc / hidbiasinc / visbiasinc, epoch, and errsum (summed squared
% reconstruction error over the last epoch trained).
%
% NOTE: when resuming (restart ~= 1) the loop starts at the current value of
% `epoch`. After a completed run that value equals maxepoch, so that epoch
% number is trained again.

epsilonw  = 0.05;   % learning rate for weights
epsilonvb = 0.05;   % learning rate for visible biases
epsilonhb = 0.05;   % learning rate for hidden biases

% Fall back to the previously hard-coded weight cost if none was supplied.
if ~exist('weightcost', 'var'),
  weightcost = 0.0001;
end

initialmomentum = 0.5;   % momentum for epochs 1..5
finalmomentum   = 0.9;   % momentum from epoch 6 onwards

[numcases numdims numbatches] = size(batchdata);

if restart == 1,
  restart = 0;
  epoch   = 1;

  % Preallocated per-batch statistics (overwritten on every mini-batch).
  poshidprobs = zeros(numcases, numhid);
  neghidprobs = zeros(numcases, numhid);
  posprods    = zeros(numdims, numhid);
  negprods    = zeros(numdims, numhid);

  % Model parameters: small random weights, zero biases.
  vishid    = 0.1*randn(numdims, numhid);
  hidbiases = zeros(1, numhid);
  visbiases = zeros(1, numdims);

  % Momentum-smoothed parameter increments.
  vishidinc  = zeros(numdims, numhid);
  hidbiasinc = zeros(1, numhid);
  visbiasinc = zeros(1, numdims);
end

for epoch = epoch:maxepoch,
  errsum = 0;

  % Momentum depends only on the epoch, so choose it once per epoch
  % rather than once per mini-batch.
  if epoch > 5,
    momentum = finalmomentum;
  else
    momentum = initialmomentum;
  end

  for batch = 1:numbatches,
    %%%%% POSITIVE PHASE: statistics driven by the data %%%%%
    data = batchdata(:,:,batch);
    poshidprobs = 1./(1 + exp(-data*vishid - repmat(hidbiases, numcases, 1)));
    % batchposhidprobs(:,:,batch) = poshidprobs;
    posprods  = data' * poshidprobs;
    poshidact = sum(poshidprobs);
    posvisact = sum(data);

    % Sample binary hidden states from their probabilities.
    poshidstates = poshidprobs > rand(numcases, numhid);

    %%%%% NEGATIVE PHASE: one reconstruction step %%%%%
    % Visible reconstruction uses probabilities (no sampling); the hidden
    % probabilities are then recomputed from that reconstruction.
    negdata     = 1./(1 + exp(-poshidstates*vishid' - repmat(visbiases, numcases, 1)));
    neghidprobs = 1./(1 + exp(-negdata*vishid - repmat(hidbiases, numcases, 1)));
    negprods  = negdata' * neghidprobs;
    neghidact = sum(neghidprobs);
    negvisact = sum(negdata);

    % Squared reconstruction error: monitoring only, not used in the update.
    err    = sum(sum((data - negdata).^2));
    errsum = err + errsum;
    % fprintf(1, '%6.1f \n', err);
    % fprintf(1, '%6.3f \n', max(data(:)));

    %%%%% PARAMETER UPDATE %%%%%
    % Statistics are averaged over the mini-batch; weight decay is applied
    % to the weights only, not to the biases.
    vishidinc  = momentum*vishidinc + ...
                 epsilonw*((posprods - negprods)/numcases - weightcost*vishid);
    visbiasinc = momentum*visbiasinc + (epsilonvb/numcases)*(posvisact - negvisact);
    hidbiasinc = momentum*hidbiasinc + (epsilonhb/numcases)*(poshidact - neghidact);

    vishid    = vishid + vishidinc;
    visbiases = visbiases + visbiasinc;
    hidbiases = hidbiases + hidbiasinc;
  end

  % Every 100 epochs run the external showrbmweights script (defined
  % elsewhere) and flush pending graphics.
  if rem(epoch, 100) == 0,
    showrbmweights;
    drawnow;
  end

  % Within this loop epoch <= maxepoch, so this reports at the final epoch.
  if rem(epoch, maxepoch) == 0,
    fprintf(1, 'numhid %4.0i epoch %4.0i error %6.1f \n', numhid, epoch, errsum);
  end
end