
#include <math.h>
#include <stdlib.h>
#include "util.h"

/* Objects supplied by the rest of the program (defined outside this file). */
extern int  no_wts;                               /* length of weight vector */
extern real *w;                                             /* weight vector */
extern real fgeval(real *dw);   /* evaluate function and partial derivatives */

/*
 * Working storage shared by lns() and conj().  All three arrays are indexed
 * 0..no_wts-1.  Nothing in this file allocates them -- presumably the caller
 * sets them up before conj() is first used (TODO confirm against caller).
 *   dw1 - gradient at the start point of the current line search
 *   dw2 - gradient at the most recent trial point (filled by fgeval(dw2))
 *   s   - direction along which lns() moves the weights
 */
real *dw1, *dw2,                /* two arrays of derivatives for all weights */
     *s;                           /* search direction used for linesearches */
/* conj() resets this to 0 on entry; lns() uses it as its evaluation budget. */
int  nfgeval;  /* number of func/grad evals. so far; incremented by fgeval() */

/*
 * lns - line search along the global direction s, starting from the current
 * weights w.
 *
 * On entry  *f1 is the function value at w, d1 is the slope along s at w
 *           (expected to be negative) and *z is the initial step guess.
 * On return w has been moved to (start + *z * s), *z is the total step
 *           taken, *f1 is the function value there and dw2 holds the
 *           gradient there.
 *
 * A point is accepted when all of
 *     f2 <= *f1 + *z*RHO*d1        (sufficient decrease)
 *     SIG*d1 < d2 <= -SIG*d1       (slope small enough in magnitude)
 * hold.  Interpolation (quadratic/cubic) is used while the point violates
 * the decrease or upper slope condition, cubic extrapolation otherwise.
 * At most MAX further function evaluations are intended per call (counted
 * through the global nfgeval).
 *
 * Returns 1 on success, 0 if the evaluation budget ran out first.  Even on
 * failure w, *z, *f1 and dw2 describe the last point that was evaluated.
 */
int lns(f1, z, d1)
  real *f1,                                        /* current function value */
       *z,                                         /* guess for initial step */
       d1;                                                          /* slope */
{
  real RHO = 0.25, SIG = 0.5, INT = 0.1, EXT = 3.0;
  int  MAX = 20;
  /* max is the largest admissible total step; -1.0 means "no bound yet"   */
  real d2, d3, f2, f3, z2, z3, A, B, max = -1.0;
  int  i, k;

  for (i=0; i<no_wts; i++) w[i] += *z*s[i];                /* update weights */
  f2 = fgeval(dw2);
  d2 = 0.0; for (i=0; i<no_wts; i++) d2 += dw2[i]*s[i];
  f3=*f1; d3=d1; z3=-*z;              /* initialize point 3 equal to point 1 */
  k = nfgeval + MAX; while (nfgeval < k) { /* allow limited amount of search */
    while (((f2 > *f1+*z*RHO*d1) || (d2 > -SIG*d1)) && (nfgeval < k)) {
      max=*z;                                         /* tighten the bracket */
      if (f2 > *f1) z2=z3-(0.5*d3*z3*z3)/(d3*z3+f2-f3);     /* quadratic fit */
      else {                                                    /* cubic fit */
        A = 6.0*(f2-f3)/z3+3.0*(d2+d3);
        B = 3.0*(f3-f2)-z3*(d3+2.0*d2);
        z2 = (sqrt(B*B-A*d2*z3*z3)-B)/A;   /* numerical error possible - ok! */
      }
      if (z2 != z2) z2 = z3/2;                /* if z2 is NaN then bisection */
      if (z2 > INT*z3) z2 = INT*z3;      /* bound solution away from current */
      if (z2 < (1.0-INT)*z3) z2 = (1.0-INT)*z3;        /* bound away from z3 */
      *z += z2;                                  /* update absolute stepsize */
      for (i=0; i<no_wts; i++) w[i] += z2*s[i];            /* update weights */
      f2 = fgeval(dw2);
      d2 = 0.0; for (i=0; i<no_wts; i++) d2 += dw2[i]*s[i];
      z3 -= z2;                  /* z3 is now relative to the location of z2 */
    }
    /* The loop above can also end because the evaluation budget ran out.   */
    /* In that case the point may still violate the conditions it was       */
    /* trying to meet; that must be reported as a failure, not a success.   */
    if ((f2 > *f1+*z*RHO*d1) || (d2 > -SIG*d1)) break;
    if (d2 > SIG*d1) { *f1 = f2; return 1; }                      /* SUCCESS */
    A = 6.0*(f2-f3)/z3+3.0*(d2+d3);              /* make cubic extrapolation */
    B = 3.0*(f3-f2)-z3*(d3+2.0*d2);
    z2 = -d2*z3*z3/(B+sqrt(B*B-A*d2*z3*z3));    /* num. error possible - ok! */
    if (z2 != z2)                                              /* z2 is NaN? */
      z2 = (max < -0.5) ? *z*(EXT-1.0) : 0.5*(max-*z);          /* bisection */
    else if (z2 < 0.0)                 /* minimum is to the left of current? */
      z2 = (max < -0.5) ? *z*(EXT-1.0) : 0.5*(max-*z);          /* bisection */
    else if ((max > -0.5) && (z2+*z > max))           /* extrap. beyond max? */
      z2 = 0.5*(max-*z);                                        /* bisection */
    else if ((max < -0.5) && (z2+*z > *z*EXT))        /* extrap. beyond EXT? */
      z2 = *z*(EXT-1.0);                             /* set to extrap. limit */
    else if (z2<-z3*INT)                      /* too close to current point? */
      z2 = -z3*INT;
    else if ((max > -0.5) && (z2 < (max-*z)*(1.0-INT))) /* too close to max? */
      z2 = (max-*z)*(1.0-INT);
    f3=f2; d3=d2; z3=-z2;                              /* swap point 2 and 3 */
    *z += z2;
    for (i=0; i<no_wts; i++) w[i] += z2*s[i];              /* update weights */
    f2 = fgeval(dw2);
    d2 = 0.0; for (i=0; i<no_wts; i++) d2 += dw2[i]*s[i];
  }
  *f1 = f2;                    /* report the value at the point we ended on */
  return 0;                                             /* linesearch failed */
}

/*
 * conj - minimise the function evaluated by fgeval() over the global weight
 * vector w, using conjugate gradients with the line search lns().
 *
 * iter, epoch - in:  maximum number of line searches / function evaluations
 *                    for this call; 0 means "no limit".
 *               out: the numbers actually used.
 * restart     - if *restart is non-zero the search starts from the steepest
 *               descent direction and *restart is cleared.  The function
 *               value, slope and step size are kept in statics between
 *               calls, so the first call should pass *restart != 0.
 * costvalue   - out: function value at the final weights.
 *
 * The new direction is s = beta*s - g_new with
 *     beta = (g_new.g_new - g_old.g_new) / (g_old.g_old),
 * falling back to steepest descent if that direction is not downhill.
 *
 * Returns 1 if the loop ended because a limit or TIMEOUT was reached, and 0
 * if two consecutive line searches failed.
 *
 * NOTE(review): the divisions by q and by slope2 are unguarded, so an
 * exactly zero gradient produces non-finite values.
 */
int conj(iter, epoch, restart, costvalue)
  int  *iter,       /* "iter" and "epoch" indicates the maximum number of... */
       *epoch,  /* iterations or epochs allowed. Actual numbers are returned */
       *restart;      /* if (*restart) then restart CG with steepest descent */
  real *costvalue;                   /* return the value of the costfunction */
{
  static int  ls_failed = 0;               /* set to 1 if line search failed */
  int         j,                                    /* miscellaneous counter */
              cur_iter = 0;                      /* counts current iteration */
  static real fun, slope1, step;     /* value, slope along s, and step guess */
  real        *tmp, y, z, q, slope2;
  extern int  TIMEOUT;             /* is set to one when SIGVTALRM is caught */

  nfgeval = 0;   /* global int "number of function and gradient evaluations" */
  if (*restart) {            /* start by using direction of steepest descent */
    fun = fgeval(dw1);
    slope1 = 0.0;
    for (j=0; j<no_wts; j++) { s[j] = -dw1[j]; slope1 -= s[j]*s[j]; }
    step = -1.0/(slope1-1.0);        /* set initial step-size to 1/(|s|^2+1) */
    *restart = 0;          /* probably we won't want to restart on next call */
  }
  while ((!TIMEOUT) && ((*epoch == 0) || (nfgeval < *epoch)) &&
         ((*iter == 0) || (cur_iter < *iter))) {
    cur_iter++;
    if (lns(&fun, &step, slope1)) {              /* if line search succeeded */
      /* y = g_new.g_new, z = g_old.g_new, q = g_old.g_old */
      y = z = q = 0.0; for (j=0; j<no_wts; j++)
        { y += dw2[j] * dw2[j]; z += dw1[j] * dw2[j]; q += dw1[j] * dw1[j]; }
      y = (y-z)/q;
      for (j=0; j<no_wts; j++) s[j] = y*s[j]-dw2[j];        /* new direction */
      tmp = dw2; dw2 = dw1; dw1 = tmp;                   /* swap derivatives */
      slope2 = 0.0; for (j=0; j<no_wts; j++) slope2 += dw1[j]*s[j];
      if (slope2 > 0.0) {     /* must be negative, else use steepest descent */
        slope2 = 0.0;
        for (j=0; j<no_wts; j++) { s[j] = -dw1[j]; slope2 -= s[j]*s[j]; }
      }
      /* scale the step guess by the slope ratio, growing at most 100-fold */
      step *= (slope1/slope2 > 100.0) ? 100.0 : slope1/slope2; slope1 = slope2;
      ls_failed = 0;
    } else {                                           /* line search failed */
      if (ls_failed)          /* break if previous failed, else try steepest */
        { *epoch = nfgeval; *iter = cur_iter; *costvalue = fun; return 0; }
      /* lns() left the gradient of the current point in dw2.  Exchange the */
      /* two buffers (rather than just copying the pointer) so that dw1 and */
      /* dw2 keep referring to distinct arrays in later iterations.         */
      tmp = dw2; dw2 = dw1; dw1 = tmp;                   /* swap derivatives */
      slope1 = 0.0;
      for (j=0; j<no_wts; j++) { s[j] = -dw1[j]; slope1 -= s[j]*s[j]; }
      step = -1.0/(slope1-1.0);   /* set new step-size guess to 1/(|s|^2+1) */
      ls_failed = 1;
    }
  }
  *epoch = nfgeval; *iter = cur_iter; *costvalue = fun; return 1;
}
