/*
 * NelderMeadMinimizer.java
 *
 * Created on July 7, 2005, 5:23 PM
 *
 *  Copyright 2005 Daniel Wachsstock
 *  The contents of this file are subject to the Sun Public License
 *  Version 1.0 (the License); you may not use this file except in
 *  compliance with the License. A copy of the License is available at
 *  http://www.sun.com/ or http://www.geocities.com/tenua4java/license.html
 */
package nr.minimizer;
import nr.*;

/** The downhill simplex algorithm of Nelder and Mead. Based on
 *  Numerical Recipes, chapter 10.4.
 *
 *  <p>The simplex starts with {@code _n+1} vertices: the starting point
 *  {@code _x} and, for each axis, a copy of {@code _x} with that coordinate
 *  doubled (or set to 1 if it was 0). Each iteration moves the highest
 *  (worst) vertex along the line through it and the centroid of the other
 *  vertices (reflection, expansion or contraction); if that does not help,
 *  every vertex is shrunk halfway toward the lowest (best) one.
 *
 *  @author Daniel Wachsstock
 */
public class NelderMeadMinimizer extends VecMinimizerImp {
    /** The vertices of the simplex. */
    protected Vec[] _simplex;
    /** The number of vertices in the simplex. It should be {@code _n+1} for
     *  the algorithm to work well: with fewer vertices it only searches a
     *  subspace of possible minima, and with more it is very inefficient. */
    protected int _numPoints;
    /** {@code _fs[i]} is the function value at vertex {@code _simplex[i]}. */
    protected double[] _fs;
    /** Indices of the lowest, second-highest and highest vertices, as set by
     *  {@link #getLoHi()}. */
    protected int _lo, _nhi, _hi;
    /** {@code _sum[i]} is the sum of coordinate {@code i} over all vertices.
     *  For vertex {@code j}, the point
     *  <pre>
     *  _sum*(1-p)/_n + _simplex[j]*((_n+1)*p-1)/_n
     *  </pre>
     *  lies on the line through vertex {@code j} and the centroid of the
     *  other vertices: p == 1 is the vertex itself, p == 0 is that centroid,
     *  p == 0.5 is halfway between them, and p == -1 is the vertex reflected
     *  through the centroid. */
    protected double[] _sum;

    /** Creates a new instance of NelderMeadMinimizer.
     *  @param f the function to minimize (handed to the superclass)
     */
    public NelderMeadMinimizer (ScalarFunction f) {
        super (f);
    } // constructor

    /** Runs the downhill simplex search starting from {@code _x}. On return,
     *  {@code _x} holds the lowest vertex found and {@code _fx} its value.
     *
     *  <p>NOTE(review): there is no iteration cap here; if the inherited
     *  {@code converged} test never accepts (e.g. a function unbounded below)
     *  this loops forever unless the superclass stops it elsewhere — confirm.
     */
    void doMinimize(){
        if (_n < 1){
            // A zero-dimensional problem has just one point. The simplex
            // bookkeeping below needs at least two vertices (getLoHi reads
            // _fs[1]), so evaluate the single point directly.
            _fx = eval (_x);
            return;
        } // if
        initSimplex();
        for (;;){
            getLoHi();
            if (converged (_simplex[_lo], _simplex[_hi])){
                _x.set (_simplex[_lo]);
                _fx = _fs[_lo];
                return;
            } // if
            double ftry = tryIt (-1d); // reflect the high point through the centroid
            if (ftry < _fs[_lo]){
                // The reflection beat the best vertex (and replaced the high
                // one), so try expanding to twice that distance from the centroid.
                tryIt (2d);
            }else if (ftry >= _fs[_nhi]){
                // The reflection is still among the worst: try moving the high
                // point halfway toward the centroid, and if that doesn't beat
                // the old high value, shrink the whole simplex toward the low point.
                double oldFHi = _fs[_hi];
                if (tryIt (0.5d) >= oldFHi) shrink();
            } // if
        } // forever
    } // doMinimize

    /** Builds the initial {@code _n+1}-vertex simplex around {@code _x},
     *  evaluates the function at every vertex, and fills in {@code _sum}. */
    private void initSimplex(){
        _numPoints = _n + 1;
        _simplex = new Vec[_numPoints];
        _fs = new double [_numPoints];
        for (int i = 0; i < _numPoints; ++i){
            // vertex 0 is _x; vertex i>0 is _x with coordinate i-1 doubled,
            // or set to 1 when it is 0 (doubling 0 would not move the vertex)
            _simplex[i] = _x.copy();
            if (i > 0){
                double d = _x.get(i-1);
                if (d != 0d){
                    _simplex[i].set (i-1, 2*d);
                }else{
                    _simplex[i].set (i-1, 1);
                } // if d
            } // if i
            _fs[i] = eval(_simplex[i]);
        } // for
        _sum = new double[_n];
        getSum();
    } // initSimplex

    /** Recomputes {@code _sum} from scratch: for each coordinate, the sum of
     *  that coordinate over all {@code _numPoints} vertices. */
    protected void getSum(){
        for (int i = 0; i < _n; ++i){
            double total = 0d;
            for (int point = 0; point < _numPoints; ++point){
                total += _simplex[point].get(i);
            } // for point
            _sum[i] = total;
        } // for i
    } // getSum

    /** Sets {@code _lo}, {@code _nhi} and {@code _hi} to the indices of the
     *  vertices with the lowest, second-highest and highest function values.
     *  Requires at least two vertices. */
    protected void getLoHi(){
        // seed from the first two vertices: the lower one is both _lo and _nhi
        _lo = _nhi = 0;
        _hi = 1;
        if (_fs[_nhi] > _fs[_hi]){
            _hi = 0;
            _lo = _nhi = 1;
        } // if
        for (int i = 2; i < _numPoints; ++i){
            if (_fs[i] <= _fs[_lo]) _lo = i;
            if (_fs[i] > _fs[_hi]){
                // new worst: the previous worst becomes second-worst
                _nhi = _hi;
                _hi = i;
            }else if (_fs[i] > _fs[_nhi]){
                _nhi = i;
            } // if
        } // for
    } // getLoHi

    /** Tries a new position for the highest vertex {@code h} on the line
     *  through it and the centroid {@code c} of the other vertices. The
     *  trial point is {@code c + factor*(h - c)}, so -1 reflects {@code h}
     *  through {@code c} and 0.5 moves it halfway toward {@code c}. If the
     *  trial value is lower than the highest vertex's value, the trial point
     *  replaces that vertex and {@code _fs} and {@code _sum} are updated.
     *
     *  @param factor signed position of the trial point relative to {@code h}
     *  @return the function value at the trial point (whether kept or not)
     */
    protected double tryIt (double factor){
        // c + factor*(h-c) == factor1*_sum - factor2*h, with c == (_sum-h)/_n
        double factor1 = (1d-factor)/_n;
        double factor2 = factor1 - factor;
        Vec hi = _simplex[_hi];
        Vec newPoint = new Vec_array(_n);
        for (int i = 0; i < _n; ++i){
            newPoint.set (i, factor1*_sum[i] - factor2*hi.get(i));
        } // for
        double ftry = eval (newPoint);
        if (ftry < _fs[_hi]){ // better than the worst vertex: replace it
            for (int i = 0; i < _n; ++i){
                _sum[i] += newPoint.get(i) - hi.get(i);
            } // for
            _fs[_hi] = ftry;
            _simplex[_hi] = newPoint;
        } // if
        return ftry;
    } // tryIt

    /** Moves every vertex except the lowest halfway toward the lowest,
     *  re-evaluates the function at each moved vertex, and recomputes
     *  {@code _sum}. The vertices are modified in place. */
    protected void shrink(){
        Vec lo = _simplex[_lo];
        for (int point = 0; point < _numPoints; ++point){
            if (point == _lo) continue; // the best vertex stays put
            Vec p = _simplex[point];
            for (int i = 0; i < _n; ++i){
                p.set (i, 0.5*(p.get(i) + lo.get(i)));
            } // for i
            _fs[point] = eval (p);
        } // for point
        getSum();
    } // shrink

    /** Test driver: minimizes a simple piecewise objective from a distant
     *  starting point, calling {@code minimize} three times on the same
     *  vector and printing the result and function-call count after each. */
    public static void main (String[] args){
        // Alternative objective (not used below):
        // 0.6 - bessj0 of the squared distance from (1, 0.6, 0.7).
        ScalarFunction bessel = new ScalarFunction(){
            public double eval (Vec x){
                return 0.6-bessj0 (sqr(x.get(0)-1)+sqr(x.get(1)-0.6)+sqr(x.get(2)-0.7));
            } // eval
            public double sqr (double x){return x*x;}
            public double bessj0 (double x){
                double ax,z;
                double xx,y,ans,ans1,ans2;

                if ((ax=Math.abs(x)) < 8.0) {
                        y=x*x;
                        ans1=57568490574.0+y*(-13362590354.0+y*(651619640.7
                                +y*(-11214424.18+y*(77392.33017+y*(-184.9052456)))));
                        ans2=57568490411.0+y*(1029532985.0+y*(9494680.718
                                +y*(59272.64853+y*(267.8532712+y*1.0))));
                        ans=ans1/ans2;
                } else {
                        z=8.0/ax;
                        y=z*z;
                        xx=ax-0.785398164;
                        ans1=1.0+y*(-0.1098628627e-2+y*(0.2734510407e-4
                                +y*(-0.2073370639e-5+y*0.2093887211e-6)));
                        ans2 = -0.1562499995e-1+y*(0.1430488765e-3
                                +y*(-0.6911147651e-5+y*(0.7621095161e-6
                                -y*0.934945152e-7)));
                        ans=Math.sqrt(0.636619772/ax)*(Math.cos(xx)*ans1-z*Math.sin(xx)*ans2);
                }
                return ans;
            }
        }; // new ScalarFunction
        // Objective actually minimized: over the first three coordinates,
        // sum x*x for negative x and x otherwise; its minimum is 0 at the origin.
        ScalarFunction f = new ScalarFunction(){
            public double eval (Vec x){
                double result = 0;
                result += (x.get(0) < 0) ? x.get(0)*x.get(0) : x.get(0) ;
                result += (x.get(1) < 0) ? x.get(1)*x.get(1) : x.get(1) ;
                result += (x.get(2) < 0) ? x.get(2)*x.get(2) : x.get(2) ;
                return result;
            } // eval
        };
        VecMinimizer min = new NelderMeadMinimizer (f);
        min.setEpsilon (1e-3);
        // four coordinates, although f only reads the first three
        Vec x = new Vec_array(4);
        x.set (0, -400);
        x.set (1, 1000);
        x.set (2, 50000);
        System.out.println("Starting: " + x);
        for (int run = 0; run < 3; ++run){
            double m = min.minimize (x);
            System.out.println("Ending: " + x + "; min = " + m);
            System.out.println("Function calls:" + min.numFuncEvals());
        } // for run
    } // main

    /** Returns the string forms of all simplex vertices concatenated in
     *  order, for debugging. */
    protected String simplexString(){
        StringBuilder result = new StringBuilder();
        for (int point = 0; point < _numPoints; ++point){
            result.append(_simplex[point].toString());
        }
        return result.toString();
    }
} // NelderMeadMinimizer
