"""Module implementing the PyWofost Ensemble Kalman filter.

Implements class PyWofostEnKF

Ensemble Kalman filter implementation is derived from:
Evensen, 2003. The Ensemble Kalman Filter: theoretical formulation and practical
implementation. Ocean dynamics 53: 343-367.
"""

import sys, os
import datetime
import logging

import numpy as np
import sqlalchemy
from sqlalchemy import Table

from .pywofost_ensemble import PyWofostEnsemble

class PyWofostEnKF(PyWofostEnsemble):
    """Defines an Ensemble Kalman Filter around a PyWofost ensemble.

    Inherits from PyWofostEnsemble.

    The filter alternates between growing the ensemble up to the next
    observation date and updating the observed state variables of every
    ensemble member with the Kalman analysis equation. Equation numbers in
    the comments refer to Evensen (2003).

    Public methods:
    * grow_with_assimilation()
    * EnKF_results_to_output_device()
    """

    def __init__(self, sitedata, timerdata, soildata, cropdata,
                 meteodata, data_for_assimilation, rv, ensemble_size,
                 mode="wlp", **kwargs):
        """Class constructor.

        Positional arguments:
        ****data              -- site, timer, crop, meteo and soildata for
                                 running WOFOST.
        data_for_assimilation -- data that will be assimilated during the run.
        rv                    -- A random variable of type numpy.random.RandomState
        ensemble_size         -- Size of the wofost ensemble.

        Keyword arguments:
        mode                  -- Simulation mode "PP|WLP" for potential/water-limited
                                 mode.
        E****data             -- dictionaries that can be used for providing
                                 ensemble members with different start values for
                                 parameters. The E****data variables consist of
                                 a tuple of ensemble_size containing dictionaries
                                 of same structure as the ****data dictionaries.
                                 These are forwarded unchanged to
                                 PyWofostEnsemble through **kwargs.

        Note:
        data_for_assimilation is specific for this class and should consist of
        a list with data that is to be assimilated (i.e. observations) during
        the model run, sorted by date in ascending order. For example:
        [(datetime.date(1999, 6, 10), {"profile_moisture_content": (0.214, 0.04),
                                       "leaf_area_index": (4.25, 0.35)}),
         (datetime.date(1999, 6, 15), {"profile_moisture_content": (0.178, 0.04),
                                       "leaf_area_index": (4.80, 0.35)})]

        The tuple that is provided for each date/variable defines the
        observation value and variance.
        """
        PyWofostEnsemble.__init__(self, sitedata, timerdata, soildata, cropdata,
                                  meteodata, ensemble_size, mode=mode, **kwargs)

        # Create new logger for specific EnKF messages
        self.EnKF_logger = logging.getLogger("PyWofost.EnKF")

        self.data_for_assimilation = data_for_assimilation
        self.rv = rv
        # Needs the logger and the observations set above.
        self.assimilation_intervals = self._find_assimilation_intervals()
        # One dictionary per performed assimilation step, filled by
        # grow_with_assimilation().
        self.enkf_updates = []

    #--------------------------------------------------------------------------
    def _find_assimilation_intervals(self):
        """Returns the number of days between subsequent observation dates.

        The first interval is measured from the current ensemble date (the
        start of the simulation when called from the constructor), each
        following interval from the previous observation date. The method
        assumes that self.data_for_assimilation is sorted by date in
        ascending order; this is not verified here.

        Returns a list of integers with one entry per observation date.
        """

        previous_date = self._get_ensemble_date()
        assimilation_intervals = []
        # Each entry is a (date, observations) pair; only the date is needed.
        for day, _ in self.data_for_assimilation:
            assimilation_intervals.append((day - previous_date).days)
            previous_date = day
        self.EnKF_logger.debug("Assimilation interval time steps determined.")
        return assimilation_intervals

    #--------------------------------------------------------------------------
    def grow_with_assimilation(self):
        """Runs the ensemble including assimilation using the EnKF.

        For every assimilation interval the ensemble is grown up to the
        observation date, after which the observed state variables of all
        members are updated. A summary of each update is appended to
        self.enkf_updates. When all observations have been assimilated and
        the simulation has not finished, the ensemble is grown further.

        Raises RuntimeError when the model date after growing differs from
        the date of the observation that should be assimilated.
        """

        terminated = 0
        s = self.ensemble_size
        logger = self.EnKF_logger
        first_member = self.ensemble[0]
        grid_no = first_member.grid_no
        crop_no = first_member.crop_no
        year = first_member.year

        # Averaging matrix for the ensemble mean (eq 45). It only depends on
        # the ensemble size, so build it once instead of at every step.
        one_N = np.matrix(np.ones((s, s), dtype=np.float64) / s)

        # Run WOFOST in intervals up to each next assimilation step.
        for step, interval in enumerate(self.assimilation_intervals):

            # Grow ensemble up to next assimilation step
            terminated = self.ensemble_grow(days=interval)

            # Break from loop if WOFOST simulations have finished
            if terminated == 1:
                break

            # Get assimilation data at current timestep
            current_day = self._get_ensemble_date()
            day, assim_observations = self.data_for_assimilation[step]

            # Make sure model date and date of assimilation variable are the
            # same. This is an explicit test rather than an assert statement
            # because asserts are removed when Python runs with -O.
            if day != current_day:
                logstr = "Model date (%s) and date of observation to be " + \
                         "assimilated (%s) are not equal!"
                raise RuntimeError(logstr % (current_day, day))
            logger.info("Assertion of model date (%s) succeeded.", current_day)

            # Get variable codes for assimilation at current time step. A
            # list fixes one ordering for the rows of A and D and for the
            # state update, and gives a plain list repr in the stored output.
            variable_codes = list(assim_observations.keys())
            logger.debug("Variable codes retrieved: %s", variable_codes)

            # Get matrix holding ensemble states for observed variables (eq 44)
            A = self._get_ensemble_states(variables=variable_codes)

            # Calculate ensemble average (eq 45)
            A_means = A * one_N

            # Calculate ensemble covariance (eq 46, 47)
            A_pert = A - A_means
            P_e = (A_pert * (A_pert.T)) / (s - 1)

            # Calculate perturbed observations and their covariance (eq 48-51)
            D, R_e = self._perturb_observations(variable_codes,
                                                values=assim_observations)

            # Calculate Kalman gain (K), note that H is defined as an identity
            # matrix because the observations are equal to the model state and
            # thus there is no measurement operator. (eq 52)
            H = self._get_measurement_operator(len(variable_codes))
            K1 = P_e * (H.T)
            K2 = (H * P_e) * H.T
            K = K1 * ((K2 + R_e).I)

            # Calculate state updates
            Aa = A + K * (D - (H * A))
            self._update_ensemble_states(variable_codes, Aa)

            # Store EnKF updates for later analyses. "A" is the ensemble mean
            # per variable (every column of A_means is identical) and "D" the
            # mean of the perturbed observations over the ensemble members.
            t = {"grid_no": grid_no, "crop_no": crop_no,
                 "year": year, "day": current_day,
                 "variable": repr(variable_codes),
                 "A": repr((np.array(A_means)[:, 0])),
                 "D": repr((np.array(D).mean(axis=1))),
                 "P_e": repr(P_e), "R_e": repr(R_e),
                 "K": repr(K)}
            self.enkf_updates.append(t)

        # If no intervals are left, but the WOFOST model is not yet
        # finished, then proceed until the end of the growing season.
        if terminated == 0:
            self.ensemble_grow(days=300)

    #--------------------------------------------------------------------------
    def _update_ensemble_states(self, variable_codes, state_matrix):
        """Updates the states of each ensemble member.

        variable_codes - list of WOFOST variable_codes that are to be updated.
        state_matrix - matrix of [len(variable_codes), ensemble_size]
        """

        logger = self.EnKF_logger
        logger.info("Updating ensemble states.")
        msg = "Update member %i - %s to %f"
        # Convert the matrix to nested lists once; row i holds the new values
        # of variable i for all ensemble members.
        state_rows = state_matrix.tolist()
        for i, variable in enumerate(variable_codes):
            state_vector = state_rows[i]
            for j, member in enumerate(self):
                value = state_vector[j]
                member.set_variable(var=variable, value=value)
                logger.debug(msg, j, variable, value)

    #--------------------------------------------------------------------------
    def _perturb_observations(self, variable_codes, values):
        """Perturbs observations based on obs=[(mean, variance), (mean, variance)]
        see equations 48 -- 51 in Evensen (2003)

        variable_codes - variables for which perturbed observations are drawn.
        values - dictionary providing (mean, variance) for each variable code.

        Returns a tuple (D, R_e): D is a matrix with one row per variable and
        one column per ensemble member holding the perturbed observations,
        R_e is the sample covariance matrix of those rows.
        """

        s = self.ensemble_size
        perturbed_obs = []
        for variable in variable_codes:
            obs_mean, obs_variance = values[variable]
            # The observation error is given as a variance, while the normal
            # distribution is parameterised by the standard deviation.
            sample = self.rv.normal(obs_mean, np.sqrt(obs_variance), s)
            perturbed_obs.append(sample)
        return (np.matrix(perturbed_obs), np.matrix(np.cov(perturbed_obs)))

    #--------------------------------------------------------------------------
    def _get_measurement_operator(self, s):
        """Defines the measurement operator for the Kalman gain calculations.
        parameter s is the size of the square matrix. Currently this function
        returns an identity matrix."""
        return np.matrix(np.identity(s))

    #--------------------------------------------------------------------------
    def _get_ensemble_states(self, variables):
        """Returns matrix with ensemble states for the defined variables.

        The matrix has one row per variable and one column per ensemble
        member, with members in iteration order of the ensemble.
        """

        logger = self.EnKF_logger
        logger.debug("Retrieving ensemble states.")
        ensemble_state = []
        for member in self:
            member_state = []
            for var in variables:
                logger.debug("Getting state variable %s from member.", var)
                member_state.append(member.get_variable(var))
            ensemble_state.append(member_state)
        # Rows were collected per member; transpose to get rows per variable.
        return (np.matrix(ensemble_state)).T

    #--------------------------------------------------------------------------
    def EnKF_results_to_output_device(self, database=None):
        """Sends EnKF results to an output device.

        EnKF results include variances, kalman gain and ensemble_mean.

        Currently only a database is implemented as an output device.

        database - sqlalchemy MetaData object through which the table
                   "enkf_output" can be loaded.

        Raises RuntimeError when no output device is given, when the output
        device is of an unsupported type, or when writing to the database
        fails.
        """

        logger = self.EnKF_logger
        if database is None:
            msg = "No output device specified in call to " + \
                  "EnKF_results_to_output_device()."
            raise RuntimeError(msg)

        # An unsupported device used to be ignored silently, which loses the
        # results without any notice to the caller.
        if not isinstance(database, sqlalchemy.schema.MetaData):
            msg = "Unsupported output device of type %s in call to " + \
                  "EnKF_results_to_output_device()."
            raise RuntimeError(msg % (type(database).__name__,))

        # No assimilation step was stored, so there is nothing to write.
        if not self.enkf_updates:
            logger.info("No EnKF results available for writing to DB.")
            return

        try:
            table_enkf_output = Table("enkf_output", database, autoload=True)
            table_enkf_output.insert().execute(self.enkf_updates)
        except Exception as exc:
            # Keep the traceback in the log and the cause in the message;
            # interpreter exits and keyboard interrupts are not intercepted.
            logger.exception("Error writing EnKF output to database.")
            msg = "Error writing EnKF output to database: %s"
            raise RuntimeError(msg % (exc,))
        logger.info("Successfully written EnKF results to DB.")
