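/* leapfrog.c: Perform leapfrog iterations for a set of differential equations.
 *
 * Do leapfrog iterations numbered 'start' through 'num' and save every 'mod'-th
 * one to the log-file 'df'. If 'mod' is negative, saving is instead triggered
 * by elapsed time rather than by iteration count.
 * If 'num' is negative then just keep going until TIMEOUT is set.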
 *
 * (c) Copyright 1996 Carl Edward Rasmussen */

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "util.h"
#include "rand.h"

extern int  no_wts, TIMEOUT;
extern real *dw1, *z, *w;
extern struct itimerval timer;
extern struct exampleset train;

extern real prior(real *dw);
extern real fgeval(real *dw, real (*prior)(real *dw));

void leapfrog(FILE *df, real RHO, real EPSILON, int num, int mod, int start)
{
  int  i, j, k = 0, reject = 0;
  long nexttime = elapsedTime(&timer);
  real E_old, E_kin, E_pot, E_pot_old, *oz, *ow;                /* old state */

  EPSILON /= sqrt((real) train.num);                       /* scale stepsize */

  oz = (real*) malloc((size_t) no_wts*sizeof(real));
  ow = (real*) malloc((size_t) no_wts*sizeof(real));

  E_pot = fgeval(dw1, prior);

  E_kin = 0; for (j=0; j<no_wts; j++) E_kin += z[j]*z[j]; E_kin *= 0.5;
  E_old = E_pot+E_kin;

  for (i=start; i<=num || num<0; i++) {
    for (j=0; j<no_wts; j++)
      { oz[j] = z[j]; ow[j] = w[j]; } E_pot_old = E_pot;   /* save old state */

    for (j=0; j<no_wts; j++)                    /* initial 2/3 leapfrog step */
      { z[j] -= 0.5*EPSILON*dw1[j]; w[j] += EPSILON*z[j]; }

    E_pot = fgeval(dw1, prior);    
    E_kin = 0.0; for (j=0; j<no_wts; j++)     /* remaining 1/3 leapfrog step */
      { z[j] -= 0.5*EPSILON*dw1[j]; E_kin += z[j]*z[j]; } 
    E_kin *= 0.5;

    if (exp(E_old-E_pot-E_kin) < rand_uniform()) {                 /* reject */
      reject++;
      for (j=0; j<no_wts; j++)       /* restore old state and negate momenta */
       { z[j] = -oz[j]; w[j] = ow[j]; } E_pot = E_pot_old;
    }

    for (j=0; j<no_wts; j++)                              /* replace momenta */
      z[j] = RHO*z[j]+sqrt(1.0-RHO*RHO)*rand_gaussian();
    E_kin = 0.0; for (j=0; j<no_wts; j++) E_kin += z[j]*z[j]; E_kin *= 0.5;
    E_old = E_pot+E_kin;
    k++;
    if (((mod>0) && !(i % mod)) ||
        ((mod<0) && (elapsedTime(&timer)>nexttime))) { 
      fprintf(df, "%6d %8d %10.6f %10.6f %10.6f", i,
                   elapsedTime(&timer),
                   E_pot, E_kin, (real) reject/k);
      for (j=0; j<no_wts; j++) fprintf(df, " %10.6f", w[j]);
      for (j=0; j<no_wts; j++) fprintf(df, " %10.6f", z[j]);
      fprintf(df, "\n"); fflush(df);
      nexttime -= 1000*abs(num)/mod;                /* set next time to save */
      k = reject = 0;                  /* reset iteration and reject counter */
    }
    if (TIMEOUT) exit(0);  /* TIMEOUT will be set when a SIGVTALRM is caught */
  }
  free(oz); free(ow);
}










