/*
 * Mechanism.java
 *
 * Created on June 10, 2004, 4:44 PM
 *
 *  Copyright 2004 Daniel Wachsstock
 *  The contents of this file are subject to the Sun Public License
 *  Version 1.0 (the License); you may not use this file except in
 *  compliance with the License. A copy of the License is available at
 *  http://www.sun.com/ or http://www.geocities.com/tenua4java/license.html
 */

package tenua.simulator;
import tenua.symbol.Expression;
import tenua.symbol.Symbol;
import tenua.symbol.SymbolTable;
import tenua.symbol.Value;
import tenua.symbol.Variable;
import tenua.symbol.VariableMemento;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;

/** Simulates a set of chemical reactions.
 *
 *  <p>The mechanism's variables are defined in a {@link SymbolTable} and their
 *  values are held in {@link VariableMemento}s. The <em>initial</em> memento
 *  holds the values a simulation starts from. Each call to
 *  {@link #goSimulate()} integrates a fresh copy of it, the <em>current</em>
 *  memento. Names not defined in the symbol table are read from, and written
 *  to, the attached BeanShell interpreter.
 *
 * @author  Daniel Wachsstock
 */
public class Mechanism extends AbstractDataGenerator {
    private final Hashtable _variableNames; // key (e.g. SPECIES) -> List of names
    private final bsh.Interpreter _interpreter; // runs the user's script
    private final Expression _calculations; // evaluated for every output data point
    private final Expression _derivativeCalculations; // evaluated by Dxdt before _rates is read
    private final SymbolTable _st;
    private final List _rateNames;
    private final List _speciesNames;
    private final List _outputNames; // pristine copy of the output names; see start()
    private VariableMemento _initialMemento, _currentMemento;
    private boolean _running; // true if in the middle of a simulation
    private final nr.ode.ODE.Factory _odeFactory = new nr.ode.ODE.Factory();
    private nr.ode.ODE _odeSolver; // the algorithm for solving ODE's
    private String _minimizerName = "lmb";
    private nr.Vec _species; // the chemical species concentrations
    private nr.Vec _rates; // the rates of production of each species
    private int _numRuns = 0; // the number of times goSimulate has been called

    /** epsilon used when the script provides no usable <code>minimizerEpsilon</code> */
    private static final double DEFAULT_MINIMIZER_EPSILON = 0.01;

    // keys into the variableNames table passed to the constructor
    public static final String TIME_CONSTANTS =
      util.Resources.getString("variables.timeConstants");
    public static final String SPECIES =
      util.Resources.getString("variables.species");
    public static final String RATE_CONSTANTS =
      util.Resources.getString("variables.rateConstants");
    public static final String CHEMICAL_RATES =
      util.Resources.getString("variables.chemicalRates");
    public static final String PARAMETERS =
      util.Resources.getString("variables.parameters");
    public static final String OUTPUTS =
      util.Resources.getString("variables.outputs");

    /** Creates a new instance of Mechanism.
     *  @param variableNames maps each of the key constants above (SPECIES,
     *    CHEMICAL_RATES, OUTPUTS, ...) to a List of variable names
     *  @param calculations evaluated to produce each output data point
     *  @param derivativeCalculations evaluated to compute the species' rates
     *  @param st the symbol table defining the mechanism's variables
     *  @param interpreter the script interpreter
     */
    public Mechanism (Hashtable variableNames, Expression calculations,
      Expression derivativeCalculations, SymbolTable st,
      bsh.Interpreter interpreter) {
        super((List)variableNames.get(OUTPUTS));
        _variableNames = variableNames;
        _calculations = calculations;
        _derivativeCalculations = derivativeCalculations;
        _st = st;
        _interpreter = interpreter;
        _initialMemento = _st.getMemento();
        _currentMemento = null; // no simulation has run yet
        _running = false;
        // look up the species and chemical rate names
        _speciesNames = (List) variableNames.get(SPECIES);
        _rateNames = (List) variableNames.get(CHEMICAL_RATES);
        // keep a backup of the output names, since renameOutput() changes them
        _outputNames = new java.util.Vector(_names);
    } // constructor

    /** Generates data by evaluating <code>$script()</code> in the
     *  interpreter, which generally calls {@link #goSimulate()}.
     *  @param memento the VariableMemento to use as the initial values.
     *    If null, the current initial memento is kept.
     */
    public void start(Object memento){
        if (memento != null) _initialMemento = (VariableMemento) memento;
        // we reset the output names before each run; this avoids surprises if
        // renameOutput() was called in the previous run.
        _names.clear();
        _names.addAll(_outputNames);
        eval ("$script()"); // run the script
    } // start

    /** Runs one simulation. This is normally called from the script.
     *  <p>Integrates a fresh copy of the initial memento from the value of
     *  <code>startTime</code> to <code>endTime</code>. It fires one data point
     *  at <code>startTime</code> and one after every solver step.
     *  A <code>timeStep</code> of 0 lets the solver choose its step sizes.
     *  Otherwise each step is exactly <code>timeStep</code> long.
     *  <code>epsilon</code> is passed to the ODE solver.
     *  <p>The mechanism always leaves "running" mode when this method exits,
     *  even on failure. <code>fireDone()</code> is fired only on success,
     *  including when <code>startTime == endTime</code>.
     *  @throws InterruptedException if the thread is interrupted between steps
     *  @throws nr.DidNotConvergeException if a step fails to advance the time
     */
    public void goSimulate() throws InterruptedException,
      NoSuchMethodException, InstantiationException, IllegalAccessException{

        ++_numRuns;
        fireStartingUp();
        // each run works on its own copy, so the initial values stay untouched
        _currentMemento = new VariableMemento(_initialMemento);
        _species = _st.subset(_speciesNames, _currentMemento);
        nr.Vec tempSpecies = new nr.Vec_array(_species); // the solver's working copy
        _rates = _st.subset (_rateNames, _currentMemento);
        try{
            _odeSolver = _odeFactory.instance ( new Dxdt());
        }catch (java.lang.reflect.InvocationTargetException ex){
            // if we can't get the solver, rethrow as a runtime exception
            throw new RuntimeException (ex.getCause());
        }

        _running = true;
        try{
            final double startTime = ((Number)get("startTime")).doubleValue();
            final double endTime = ((Number)get("endTime")).doubleValue();

            // fire the zero-time point data
            put ("time", startTime);
            fireNewData(_calculations.eval(_currentMemento));

            if (startTime != endTime){ // otherwise only one point is wanted
                // timeStep[0] is the ideal step size for the algorithm;
                // timeStep[1] is the maximum step allowed going into solve,
                // and the actual step taken after
                double[] timeStep = new double[2];
                timeStep[0] = endTime - startTime; // initially try going the whole way
                timeStep[1] = ((Number)get("timeStep")).doubleValue();
                // timeStep == 0 means use varying time steps
                boolean varySteps = (timeStep[1] == 0d);

                _odeSolver.setEpsilon(((Number)get("epsilon")).doubleValue());

                // the condition here has to allow for startTime > endTime
                for (double time = startTime;
                  startTime < endTime ? time < endTime : time > endTime;){
                    Thread.yield();
                    if (Thread.interrupted()){
                        throw new InterruptedException(); // re-raise
                    } // if
                    if (varySteps){
                        timeStep[1] = endTime - time;
                        if (Math.abs(timeStep[0]) > Math.abs(timeStep[1]))
                            timeStep[0] = timeStep[1]; // don't overshoot
                        _odeSolver.solve(tempSpecies, timeStep);
                    }else{
                        _odeSolver.solve(tempSpecies, timeStep[1]);
                    }
                    double oldTime = time;
                    time += timeStep[1];
                    // a step too small to change the time would loop forever
                    if (oldTime == time) throw new nr.DidNotConvergeException();
                    put ("time", time);
                    _species.set (tempSpecies);
                    // fireNewData will make copies for each listener
                    fireNewData(_calculations.eval(_currentMemento));
                } // for
            } // if
        }finally{
            // Always leave running mode, even when interrupted or not
            // converging. Otherwise get()/put()/getMemento() would keep
            // using this run's memento.
            _running = false;
        } // try
        fireDone();
    } // goSimulate

    /** Runs the simulation while varying a given list of variables.
     *  <p>The minimizer chosen by {@link #setMinimizer(String)} varies the
     *  named variables in the initial memento. On return they hold the values
     *  it found. What is minimized depends on the minimizer:
     *  <ul>
     *  <li>"nms" minimizes the number returned by the script call
     *  <code>$fit()</code>.</li>
     *  <li>"lmfd" and "lmb" run {@link #goSimulate()} for each trial point.
     *  They fit <code>dataseries($f2,$f1).getAllYValues(1)</code> to
     *  <code>dataseries($f2).getAllYValues(0)</code>.</li>
     *  </ul>
     *  <p>The minimizer's epsilon is the value of <code>minimizerEpsilon</code>,
     *  or 0.01 if that cannot be used.
     *  @param names the array of variable names to vary
     *  @throws IllegalArgumentException if the minimizer name is not defined
     */
    public void goVarying (String[] names) {
        nr.Vec vars = _st.subset (names, _initialMemento);
        nr.minimizer.VecMinimizer minimizer = getMinimizer (vars);
        try{
            double eps = ((Number)get("minimizerEpsilon")).doubleValue();
            minimizer.setEpsilon(eps);
        }catch (Exception ex) {
            // minimizerEpsilon is missing or not a number, or setEpsilon
            // rejected it: fall back to the default rather than abort the fit
            minimizer.setEpsilon(DEFAULT_MINIMIZER_EPSILON);
        }
        minimizer.minimize(vars);
    } // goVarying

    /** Builds the minimizer named by <code>_minimizerName</code>.
     *  @param vars the variables to vary. The objective function writes each
     *    trial point into it with <code>set</code>.
     *  @return the minimizer
     *  @throws IllegalArgumentException if the name is not "nms", "lmfd" or "lmb"
     */
    private nr.minimizer.VecMinimizer getMinimizer (final nr.Vec vars){
        if ("nms".equals(_minimizerName)){
            nr.ScalarFunction func = new nr.ScalarFunction(){
                public double eval (nr.Vec v) {
                    vars.set (v);
                    Object result = Mechanism.this.eval ("$fit();");
                    return ((Number) result).doubleValue();
                } // eval
            }; // new ScalarFunction
            return new nr.minimizer.NelderMeadMinimizer (func);
        }else if ("lmfd".equals(_minimizerName)){
            return new nr.minimizer.LevenbergMarquardt
              (new nr.Vec_wrapper(fitTarget()), simulatedSeries(vars));
        }else if ("lmb".equals(_minimizerName)){
            return new nr.minimizer.LevenbergMarquardtBroyden
              (new nr.Vec_wrapper(fitTarget()), simulatedSeries(vars));
        } // if
        // TODO: change this to a resource
        throw new IllegalArgumentException
          (_minimizerName+ " is not a defined minimizer name");
    } // getMinimizer

    /** Returns the target data for the "lmfd" and "lmb" minimizers. */
    private double[] fitTarget(){
        return (double[]) eval ("dataseries($f2).getAllYValues(0);");
    } // fitTarget

    /** Returns the model function for the "lmfd" and "lmb" minimizers.
     *  For each trial point <code>x</code> it stores <code>x</code> into
     *  <code>vars</code> and reruns the simulation. It then returns the
     *  script's <code>dataseries($f2,$f1).getAllYValues(1)</code>, copied into
     *  <code>y</code> when <code>y</code> is not null.
     *  @param vars the variables being varied
     */
    private nr.VecFunction simulatedSeries (final nr.Vec vars){
        return new nr.VecFunction(){
            public nr.Vec eval (nr.Vec x, nr.Vec y){
                vars.set (x);
                // qualified: the unqualified name would resolve to this
                // VecFunction's own eval method
                Mechanism.this.eval ("mechanism.goSimulate();");
                double[] result = (double[]) Mechanism.this.eval
                  ("dataseries($f2,$f1).getAllYValues(1);");
                if (y == null) return new nr.Vec_wrapper(result);
                y.set(result);
                return y;
            } // eval
        }; // new VecFunction
    } // simulatedSeries

    /** Runs {@link #goVarying(String[])} <code>n</code> times and collects
     *  statistics on the fitted values.
     *  <p>Before run <code>i</code>, the script function
     *  <code>$beforeFit(i)</code> is called. This lets the script, for
     *  example, substitute simulated or resampled data.
     *  <p>On return, the following are set in the interpreter:
     *  <ul>
     *  <li><code>parameters[i][j]</code>: the value of the <code>j</code>th
     *  variable after the <code>i</code>th run.</li>
     *  <li><code>mean[j]</code>: the mean of the <code>j</code>th variable
     *  over all runs.</li>
     *  <li><code>stddev[j]</code>: the sample standard deviation.</li>
     *  <li><code>stddev1[j]</code>: the standard deviation assuming that
     *  <code>parameters[0]</code> is the true value of the variables.</li>
     *  </ul>
     *  The last would be the case if the first run fit the actual data and the
     *  other runs fit simulated or selected data.
     *  @param n the number of runs
     *  @param names the array of variable names to vary
     *  @throws bsh.EvalError if the interpreter fails
     */
    public void goVarying (int n, String[] names) throws bsh.EvalError{
        double[][] parameters = new double[n][];
        // the same variables goVarying(names) fits; read back after each run
        nr.Vec vars = _st.subset (names, _initialMemento);
        for (int i=0; i<n; ++i){
            _interpreter.eval ("$beforeFit("+i+")");
            goVarying (names);
            parameters[i] = vars.asArray();
        } // for
        getStats(parameters);
    } // goVarying

    /** Computes per-variable statistics over the fit results.
     *  The results are published to the interpreter as <code>parameters</code>,
     *  <code>mean</code>, <code>stddev</code> and <code>stddev1</code>.
     *  <ul>
     *  <li><code>stddev</code> is the sample standard deviation, computed with
     *  the corrected two-pass algorithm (Numerical Recipes eq. 14.1.8).</li>
     *  <li><code>stddev1</code> takes <code>parameters[0]</code> as the known
     *  true value. No mean is estimated, and run 0 deviates from itself by
     *  zero, so the summed squared deviations are divided by the n-1 other
     *  runs.</li>
     *  </ul>
     *  With a single run both standard deviations are NaN (0/0).
     *  With no runs all arrays are empty.
     *  @param parameters parameters[i][j] is variable j after run i
     *  @throws bsh.EvalError if the interpreter rejects an assignment
     */
    private void getStats (double[][] parameters) throws bsh.EvalError {
        final int n = parameters.length;
        final int count = (n == 0) ? 0 : parameters[0].length;
        double[] mean = new double[count];
        double[] stddev = new double[count];
        double[] stddev1 = new double[count];
        // the arrays are shared with the interpreter, so filling them below
        // also updates the script's view
        _interpreter.set ("parameters", parameters);
        _interpreter.set ("mean", mean);
        _interpreter.set ("stddev", stddev);
        _interpreter.set ("stddev1", stddev1);
        for (int p=0; p<count; ++p){
            double sum = 0;
            for (int i=0; i<n; ++i) sum += parameters[i][p];
            mean[p] = sum/n;
            double sumOfSquares = 0;
            double sumOfDiff = 0;    // ideally zero; corrects roundoff in mean
            double sumOfSquares1 = 0;
            for (int i=0; i<n; ++i){
                double diff = parameters[i][p]-mean[p];
                sumOfSquares += diff*diff;
                sumOfDiff += diff;
                double diff1 = parameters[i][p]-parameters[0][p];
                sumOfSquares1 += diff1*diff1;
            } // for i
            // Numerical Recipes equation 14.1.8
            stddev[p] = Math.sqrt ((sumOfSquares-sumOfDiff*sumOfDiff/n)/(n-1));
            // known true value: no degree of freedom spent on estimating a mean
            stddev1[p] = Math.sqrt (sumOfSquares1/(n-1));
        } // for p
    } // getStats

    /** Renames an output datum for the current run.
     *  The original names are restored by the next {@link #start(Object)}.
     *  @param i the index of the output to change. It must be a valid index
     *    into the output names; it is passed straight to <code>set</code>.
     *  @param newName the name to change to
     */
    public void renameOutput(int i, String newName){
       _names.set(i,newName);
    } // renameOutput

    /** Returns the current state of the simulation.
     *  @return the current run's memento while a simulation is running;
     *    otherwise the initial memento
     */
    public Object getMemento() {
        return _running ? _currentMemento : _initialMemento;
    } // get Memento

    /** Returns the name of the current ODE solver.
     *  @return the name
     */
    public String getSolver() { return _odeFactory.getODE(); }

    /** Sets the ODE solver to use.
     *  The aliases "fast" (Modified Midpoint), "normal" (Bulirsch-Stoer) and
     *  "stiff" (Bader-Deuflhard) are guaranteed to be defined. Any other name
     *  is passed to the factory unchanged.
     *  @param name the name of the ODE solver
     *  @throws IllegalArgumentException if the name is undefined
     */
    public void setSolver (String name){
        if ("normal".equals(name)){
            name = "Bulirsch-Stoer";
        }else if ("fast".equals(name)){
            name = "Modified Midpoint";
        }else if ("stiff".equals(name)){
            name = "Bader-Deuflhard";
        }
        _odeFactory.setODE(name);
    } // setSolver

    /** Returns the name of the current minimization algorithm.
     *  @return the name
     */
    public String getMinimizer() { return _minimizerName; }

    /** Sets the minimization algorithm ("nms", "lmfd" or "lmb").
     *  No error checking is done here. {@link #goVarying(String[])} throws
     *  IllegalArgumentException later if the name is invalid.
     *  @param name the name of the minimizer
     */
    public void setMinimizer (String name){ _minimizerName = name; }

    /** Creates a {@link util.DoubleBean} for a variable.
     *  The name is not checked here. The bean's getValue/setValue delegate to
     *  {@link #get(String)} and {@link #put(String, double)}, and so throw
     *  IllegalArgumentException if the name is not found. getValue throws
     *  ClassCastException if the value is not a number.
     *  @param name the name of the variable
     *  @return the DoubleBean that will get/set the value
     */
    public util.DoubleBean doubleBean (final String name){
        return new util.DoubleBean(){
            public double getValue() {
                return ((Number)Mechanism.this.get(name)).doubleValue();
            }
            public void setValue (double d) {Mechanism.this.put(name, d);}
        };
    } // doubleBean

    /** Returns the value of a variable in this mechanism or its script
     *  interpreter.
     *  @param name the name of the variable to search for. Interpreter names
     *    are tried as given, then with the prefix <code>script.</code>.
     *  @return the value. Mechanism variables are returned as a Double: the
     *    value in the current run if we are running, otherwise the initial
     *    value.
     *  @throws IllegalArgumentException if name is not found
     */
    public Object get (String name) throws IllegalArgumentException {
        if (_st.get(name) != null){
            // a mechanism variable
            double value = _running ? getLatest (name) : getInitial(name);
            return new Double(value);
        } // if
        // an interpreter value
        Object result;
        try{
            result = _interpreter.get(name);
            if (result == null) result = _interpreter.get("script."+name);
        }catch (Exception ex){
            throw unknownName (name, ex);
        } // try
        if (result == null) throw new IllegalArgumentException(name);
        return result;
    } // get

    /** Returns the initial value of a variable in this mechanism.
     *  @param name the name of the variable to search for
     *  @return the value in the initial memento
     *  @throws IllegalArgumentException if name is not a Value in the symbol table
     */
    public double getInitial (String name) throws IllegalArgumentException {
        return valueIn (name, _initialMemento);
    } // getInitial

    /** Returns the most recently determined value of a variable in this mechanism.
     *  NOTE: the current memento is null until {@link #goSimulate()} has run.
     *  @param name the name of the variable to search for
     *  @return the value in the current run's memento
     *  @throws IllegalArgumentException if name is not a Value in the symbol table
     */
    public double getLatest (String name) throws IllegalArgumentException {
        return valueIn (name, _currentMemento);
    } // getLatest

    /** Evaluates the Value <code>name</code> against <code>memento</code>. */
    private double valueIn (String name, VariableMemento memento){
        Symbol s = _st.get(name);
        if (!(s instanceof Value)) throw new IllegalArgumentException(name);
        nr.Vec result = new Expression(s).eval (memento);
        return result.get(0); // a Value better push one value onto the stack
    } // valueIn

    /** Assigns a value to a variable.
     *  For mechanism variables, this uses the current run's memento if we are
     *  running, otherwise the initial memento. Any name not defined in the
     *  mechanism is assigned in the script interpreter.
     *  @param name the name of the variable to search for
     *  @param d the value to assign
     *  @throws IllegalArgumentException if the assignment fails
     */
    public void put (String name, double d){
        put (name, (Object) new Double(d));
    } // put

    /** Assigns a value to a variable.
     *  For mechanism variables, this uses the current run's memento if we are
     *  running, otherwise the initial memento. Any name not defined in the
     *  mechanism is assigned in the script interpreter.
     *  @param name the name of the variable to search for
     *  @param o the value to assign
     *  @throws IllegalArgumentException if the assignment fails, or a
     *    non-number is assigned to a mechanism variable
     */
    public void put (String name, Object o){
        if (_st.get(name) != null){
            // a mechanism variable: only numbers can be stored
            if (! (o instanceof Number)) throw new IllegalArgumentException (name);
            double d = ((Number) o).doubleValue();
            if (_running) putLatest(name, d); else putInitial(name, d);
        }else{
            // an interpreter value
            try{
                _interpreter.set(name, o);
            }catch (Exception ex){
                throw unknownName (name, ex);
            } // try
        } // if
    } // put

    /** Assigns the initial value of a variable, in the initial memento.
     *  @param name the name of the variable to search for
     *  @param d the value to assign
     *  @throws IllegalArgumentException if name is not a Variable in the symbol table
     */
    public void putInitial (String name, double d){
        assignIn (name, d, _initialMemento);
    } // putInitial

    /** Assigns a value to a variable in the current run's memento.
     *  @param name the name of the variable to search for
     *  @param d the value to assign
     *  @throws IllegalArgumentException if name is not a Variable in the symbol table
     */
    public void putLatest (String name, double d){
        assignIn (name, d, _currentMemento);
    } // putLatest

    /** Assigns <code>d</code> to the Variable <code>name</code> in <code>memento</code>. */
    private void assignIn (String name, double d, VariableMemento memento){
        Symbol s = _st.get(name);
        if (!(s instanceof Variable)) throw new IllegalArgumentException(name);
        ((Variable) s).assign (d, memento);
    } // assignIn

    /** Wraps a failed interpreter access to <code>name</code>, keeping the cause. */
    private static IllegalArgumentException unknownName (String name, Throwable cause){
        IllegalArgumentException ex = new IllegalArgumentException (name);
        ex.initCause (cause);
        return ex;
    } // unknownName

    /** Resets all the variables' initial values to their default values.
     */
    public void resetVariables(){
        final String[] names = _st.variableList();
        for (int i = 0; i < names.length; i++){
            Symbol s = _st.get(names[i]);
            if (s instanceof Variable) ((Variable)s).reset(_initialMemento);
        } // for
    } // resetVariables

    /** Returns a List of the names of all the variables defined of a given type.
     *  @param listName the type of variable to return, one of the key
     *    constants such as {@link #SPECIES}
     *  @return the List of Strings. This is never null: if the type is not
     *    defined, a new empty List is returned. A defined list is the
     *    mechanism's own list, not a copy, so callers must not modify it.
     */
    public List variableNames(String listName){
        List result = (List) _variableNames.get(listName);
        if (result == null) result = new java.util.ArrayList();
        return result;
    } // variableNames

    /** Executes a string in the mechanism's interpreter.
     *  <p>On success, the initial memento is saved with
     *  <code>saveToPreferences</code> and the result is returned.
     *  <p>On failure, the error is reported through
     *  <code>util.ErrorDialog.errorDialog</code>. A ThreadDeath is reported
     *  silently. Then a ThreadDeath is thrown to stop the calling thread, so
     *  this method never returns normally after an error. Exceptions thrown by
     *  script code arrive as a bsh.TargetError and are unwrapped first, so the
     *  real cause is reported.
     *  @param s the String to evaluate
     *  @return the result of the evaluation
     */
    public Object eval (String s){
        try{
            try{
                Object result = _interpreter.eval(s); // run the script
                _st.saveToPreferences(_initialMemento); // save the current defaults
                return result;
            }catch (bsh.TargetError ex1){
                throw ex1.getTarget(); // need to catch target errors and rethrow correctly
            } // try (inner)
        }catch (bsh.ParseException ex2){
            util.ErrorDialog.errorDialog("InterpreterError", ex2);
        }catch (nr.DidNotConvergeException ex3){
            util.ErrorDialog.errorDialog("DidNotConverge", ex3);
        }catch (InterruptedException intex){
            util.ErrorDialog.errorDialog("Interrupted", intex);
        }catch (ThreadDeath td){
            /* do nothing--no messages */
        }catch (Throwable ex){
            util.ErrorDialog.errorDialog("SimulationError", ex);
        } // try
        throw new ThreadDeath(); // make sure we stop the thread
    } // eval

    /** Gets the number of times the simulation has been run
     *  with {@link #goSimulate} since creation.
     *  @return the number of times the simulation has been run
     */
    public int getCount() { return _numRuns; }

    /** The ODE right-hand side handed to the solver.
     *  Loads the trial species vector into the current memento, evaluates
     *  the derivative calculations, and returns the resulting rates.
     */
    class Dxdt extends nr.VecFunction{
        public nr.Vec eval(nr.Vec x, nr.Vec y) {
            if (y == null) y = new nr.Vec_array (x.size());
            _species.set(x);
            _derivativeCalculations.eval(_currentMemento);
            y.set(_rates);
            return y;
        }

        public double jacobianCost() { return 2*_species.size(); }
    } // Dxdt

} // Mechanism
