"""Main module for running the PyWofost model.

Classes defined here:

* PyWofost - runtime class for running the PyWofost model
"""

import sys, os
import logging
import warnings
import datetime
import copy

import sqlalchemy
from sqlalchemy import Table
from sqlalchemy.exceptions import *
import numpy as np

from .pyfortran.py_w60lib_wrapper import py_W60lib_wrapper
from .pyfortran import w60_variable_description as w60_var_desc

#-------------------------------------------------------------------------------
class PyWofost(py_W60lib_wrapper):
    """Main class for running the PyWofost crop model.

    Subclasses from py_W60lib_wrapper in order to inherit the methods that
    wrap the FORTRAN77 code which defines the actual WOFOST model. The
    PyWofost class only contains the programmatic flow of the different
    routines (e.g. rate calculation/rate integration/timer updates, etc.).

    Public methods defined here:
    * grow -- run the model for a given number of days
    * get_variable -- return a state/rate variable from PyWofost
    * set_variable -- update one of a small set of supported state variables
    * results_to_output_device -- send simulation results to file/DB

    Public attributes defined here:
    * saved_state_variables -- list with one dictionary of WOFOST variables
                               per simulated day, appended to during the
                               model run.
    """

    def __init__(self, sitedata, timerdata, soildata, cropdata, meteo_fetcher,
                 mode='PP', ensemble_id=0, metadata=None):
        """Class constructor for PyWofost.

        The following positional arguments are needed (in this order):
        sitedata -- dictionary with SITE variables
        timerdata -- dictionary with TIMER variables (sowing/harvest/etc.),
            must contain the key "CAMPAIGNYEAR"
        soildata -- dictionary with soil parameters
        cropdata -- dictionary with crop parameters for WOFOST, must contain
            the keys "GRID_NO" and "CROP_NO"
        meteo_fetcher -- callable which returns the meteo data for a given
            day; it is called as meteo_fetcher(day, ensemble_id=...)

        The following keyword arguments are available:
        mode -- 'WLP'/'PP' (case-insensitive) for water-limited/potential
            production mode
        ensemble_id -- ID of PyWofost member. Used for ensemble simulations,
            defaults to zero for single runs.
        metadata -- SQLAlchemy metadata object. Is used for deriving the
            column names in the table 'pywofost_output'. The column names
            are used in the method get_variable() to derive the WOFOST state
            variables during the simulation. If metadata is not provided
            then a default set of variables is kept during the simulation.

        Raises RuntimeError if `mode` is not a known production level.
        """

        # Define a logger for the PyWofost init routines
        member_id = str(ensemble_id).zfill(3)
        self.logger = logging.getLogger('PyWofost.main.%s' % member_id)

        self.meteo_fetcher = meteo_fetcher
        self.simulation_mode = mode.lower()

        # Validate the production level before any of the FORTRAN COMMON
        # blocks are initialised, so a bad mode fails fast and clearly.
        if self.simulation_mode not in ('pp', 'wlp'):
            errstr = ("Unknown production level: %s, use either 'pp' or "
                      "'wlp'!") % mode
            raise RuntimeError(errstr)

        self.ensemble_id = ensemble_id
        self.grid_no = cropdata["GRID_NO"]
        self.crop_no = cropdata["CROP_NO"]
        self.year = timerdata["CAMPAIGNYEAR"]

        # Determine start/stop from crop calendar
        r = self._set_crop_calendar(timerdata)

        # Initialize timer based on crop calendar
        self.timer = self._timer_init(r.start_date_simulation,
                                      r.end_date_simulation,
                                      r.start_date_crop)

        # Initialize COMMON blocks in FORTRAN code
        self._init_COMMON_blocks(cropdata, soildata, sitedata)

        # Initialize CROPSI() FORTRAN routine
        (self.EVWMX, self.EVSMX, self.TRA, self.FR) = self._CROPSI_init()

        # Initialize ROOTD() FORTRAN routine
        self.RD = self._ROOTD_init(mode)

        # Initialize WATFD()/WATPP() FORTRAN routine based on production level
        if self.simulation_mode == 'pp':
            self.SM = self._WATPP_init()
        else:
            self.SM = self._WATFD_init()

        # Create flag to indicate if model is in idle or active state.
        # This is important because if attributes are retrieved in idle state
        # they have to be retrieved from the internal dictionaries, while in
        # active state they have to be retrieved from the w60lib COMMON blocks.
        self._runtime_state = 'idle'

        # Create empty list for saving state variables during model run
        # and derive the list of variables to save
        self.saved_state_variables = []
        self._variables_to_save = self._get_variables_for_save(metadata)

        # Start initial cycle of rate calculation
        self._calc_initial_rates()

        # Pull data from COMMON blocks out and store in self
        self._save_COMMON_blocks()

    #---------------------------------------------------------------------------
    def _run_rate_integration(self, status, dmeteo, EVWMX, EVSMX, TRA, FR,
                              RD, SM):
        """Integrate the rates of the previous day (ITASK=3).

        Calls the crop routines matching `status` (as returned by
        _check_crop_status()) followed by the water balance of the current
        production level. Returns the updated tuple
        (EVWMX, EVSMX, TRA, FR, RD, SM).
        """
        if status == 'No crop':
            # Crop not yet sown or emerged
            (EVWMX, EVSMX, TRA, RD) = self._no_crop(dmeteo)
        elif status == 'Sowing':
            # Crop sown but not yet emerged
            (EVWMX, EVSMX, TRA, FR) = self._CROPSI_integrate(
                dmeteo, SM, EVWMX, EVSMX, TRA, FR)
        elif status == 'Emergence':
            # Crop has emerged.
            (EVWMX, EVSMX, TRA, FR) = self._CROPSI_integrate(
                dmeteo, SM, EVWMX, EVSMX, TRA, FR)
            RD = self._ROOTD_integrate(FR, RD)

        # Run water balance
        if self.simulation_mode == 'pp':
            SM = self._WATPP_integrate(EVWMX, EVSMX, TRA, SM)
        elif self.simulation_mode == 'wlp':
            SM = self._WATFD_integrate(EVWMX, EVSMX, TRA, SM, RD, dmeteo)

        return (EVWMX, EVSMX, TRA, FR, RD, SM)

    #---------------------------------------------------------------------------
    def _run_rate_calculation(self, status, dmeteo, EVWMX, EVSMX, TRA, FR,
                              RD, SM):
        """Calculate the rates of change on the current day (ITASK=2).

        Calls the crop routines matching `status` (as returned by
        _check_crop_status()) followed by the water balance of the current
        production level. With status 'No crop' only the water balance is
        run. Returns the updated tuple (EVWMX, EVSMX, TRA, FR, RD, SM).
        """
        if status == 'Sowing':
            # Crop sown but not yet emerged
            (EVWMX, EVSMX, TRA, FR) = self._CROPSI_calc_rates(
                dmeteo, SM, EVWMX, EVSMX, TRA, FR)
        elif status == 'Emergence':
            # Crop has emerged.
            (EVWMX, EVSMX, TRA, FR) = self._CROPSI_calc_rates(
                dmeteo, SM, EVWMX, EVSMX, TRA, FR)
            RD = self._ROOTD_calc_rates(FR, RD)

        # Run water balance
        if self.simulation_mode == 'pp':
            SM = self._WATPP_calc_rates(EVWMX, EVSMX, TRA, SM)
        elif self.simulation_mode == 'wlp':
            SM = self._WATFD_calc_rates(EVWMX, EVSMX, TRA, SM, RD, dmeteo)

        return (EVWMX, EVSMX, TRA, FR, RD, SM)

    #---------------------------------------------------------------------------
    def _calc_initial_rates(self):
        """Calculates the initial rates of change and saves the states.

        This function calculates the rates of change based on the values of
        state values that are set during initialisation. It also retrieves the
        driving variables (meteo) for the current day, it saves the model
        states for the current day, and increases the timer with one day.

        This relates to Figure 4 in the FSE manual (page 15) at:
        http://library.wur.nl/way/catalogue/documents/fse.pdf

        It includes the first cycle starting at initialisation, driving
        variables, rate calculation, test finish condition, output and
        incrementing the counter.

        Each next cycle in the model is then carried out by the grow() function
        which progresses through the entire circle.
        """
        # Work on local copies; they are only written back to self after the
        # states have been saved and the timer updated (same as grow()).
        EVWMX, EVSMX, TRA = self.EVWMX, self.EVSMX, self.TRA
        RD, FR, SM = self.RD, self.FR, self.SM

        # Change state to 'active'
        self._runtime_state = 'active'

        self.logger.debug("%s Starting rate calculation (ITASK=2)", self.day)

        # Get meteo data for current day
        dmeteo = self.meteo_fetcher(self.day, ensemble_id=self.ensemble_id)

        status = self._check_crop_status()
        (EVWMX, EVSMX, TRA, FR, RD, SM) = self._run_rate_calculation(
            status, dmeteo, EVWMX, EVSMX, TRA, FR, RD, SM)

        # Save state variables for current day
        self._save_states()

        # Update timer
        self.day = self._timer_update()

        # Assign local variables back to self
        self.EVWMX = EVWMX
        self.EVSMX = EVSMX
        self.TRA = TRA
        self.RD = RD
        self.FR = FR
        self.SM = SM
        self.dmeteo = dmeteo

        # Change state back to 'idle'
        self._runtime_state = 'idle'

    #---------------------------------------------------------------------------
    def _check_crop_status(self):
        """Report the status of the crop: 'No crop', 'Sowing' or 'Emergence'.

        'No crop' when the timer reports that neither sowing nor emergence has
        been reached; otherwise 'Sowing' when ISTATE == 1, else 'Emergence'.
        """
        if self.timer.check_for_sowing_or_emergence() == False:
            return 'No crop'
        if self.ISTATE == 1:
            return 'Sowing'
        return 'Emergence'

    #---------------------------------------------------------------------------
    def _get_variables_for_save(self, metadata):
        """Returns the variable names of the variables that will be saved
        during model run.

        If metadata is None then a default set of variables is returned.
        Otherwise, metadata must be an SQLALchemy metadata object which is
        used to find out which columns are in the 'pywofost_output' table.
        The column names that are not run descriptors (grid_no, crop_no,
        year, day, simulation_mode, ensemble_id) are returned as variable
        names. Any other kind of metadata object triggers a warning and the
        default set is returned.

        Raises RuntimeError if a column is not a known PyWofost variable.
        """
        default = ["development_stage", "leaf_area_index",
                   "total_aboveground_biomass", "total_weight_storage_organs",
                   "total_weight_leaves", "total_weight_stems",
                   "total_weight_roots", "profile_moisture_content",
                   "rootzone_moisture_content", "transpiration",
                   "rooting_depth"]
        run_descriptors = ("grid_no", "crop_no", "year", "day",
                           "simulation_mode", "ensemble_id")

        if metadata is None:
            return default

        if not isinstance(metadata, sqlalchemy.schema.MetaData):
            msg = ("Metadata passed is not an SQLAlchemy metadata object. "
                   "Returning default list of variables.")
            warnings.warn(msg)
            self.logger.warning(msg)
            return default

        variables = []
        pw_output = Table('pywofost_output', metadata, autoload=True)
        for col in pw_output.columns:
            if col.name in run_descriptors:
                continue
            if w60_var_desc.check_variable(col.name) is True:
                variables.append(col.name)
            else:
                msg = ("Column %s in pywofost_output table "
                       "not recognized as PyWofost variable.")
                raise RuntimeError(msg % col.name)
        return variables

    #---------------------------------------------------------------------------
    def _save_states(self):
        """Appends WOFOST variables to self.saved_state_variables for this day.
        """
        # Create initial dictionary of states, this is basically the primary
        # key in table pywofost_output.
        states = {"ensemble_id": self.ensemble_id,
                  "day": self.day,
                  "grid_no": self.grid_no,
                  "crop_no": self.crop_no,
                  "year": self.year,
                  "simulation_mode": self.simulation_mode}

        for var in self._variables_to_save:
            states[var] = self.get_variable(var)
        self.saved_state_variables.append(states)

    #---------------------------------------------------------------------------
    def _check_terminate(self):
        """Return 1 if a terminate condition was reached, otherwise 0.

        A terminate condition is either a crop condition (CROP_TERMNL set by
        the CROPSI.FOR routine) or a timer condition (time_termnl, e.g. a
        fixed harvest date).
        """
        if self.CROP_TERMNL == 1 or self.time_termnl == 1:
            return 1
        return 0

    #---------------------------------------------------------------------------
    def grow(self, days=1):
        """Runs the crop model for the given number of days.

        The loop stops early when a terminate condition is reached (see
        _check_terminate()). Each simulated day integrates the rates of the
        previous day, fetches the meteo data and calculates the new rates,
        saves the states and advances the timer.

        Returns 1 if a terminate condition was reached, otherwise 0.
        """
        logger = self.logger

        # Restore COMMON blocks in Py_w60lib
        self._restore_COMMON_blocks()

        # Make some variables local which are passed between routines
        EVWMX, EVSMX, TRA = self.EVWMX, self.EVSMX, self.TRA
        RD, FR, SM = self.RD, self.FR, self.SM
        dmeteo = self.dmeteo

        # Change state to 'active'
        self._runtime_state = 'active'

        days_done = 0
        terminate = self._check_terminate()
        while terminate == 0 and days_done < days:
            days_done += 1

            # The crop status is determined once per day and used for both
            # the integration and the rate calculation step.
            status = self._check_crop_status()

            # Integration of rates of previous day, ITASK=3
            logger.debug("%s Starting rate integration (ITASK=3)", self.day)
            (EVWMX, EVSMX, TRA, FR, RD, SM) = self._run_rate_integration(
                status, dmeteo, EVWMX, EVSMX, TRA, FR, RD, SM)

            # Calculate rates on current day, ITASK=2
            logger.debug("%s Starting rate calculation (ITASK=2)", self.day)
            dmeteo = self.meteo_fetcher(self.day, ensemble_id=self.ensemble_id)
            (EVWMX, EVSMX, TRA, FR, RD, SM) = self._run_rate_calculation(
                status, dmeteo, EVWMX, EVSMX, TRA, FR, RD, SM)

            # Save state variables for current day
            self._save_states()

            # Update timer
            self.day = self._timer_update()

            # Check for end of simulation
            terminate = self._check_terminate()

        # Assign local variables back to self
        self.EVWMX = EVWMX
        self.EVSMX = EVSMX
        self.TRA = TRA
        self.RD = RD
        self.FR = FR
        self.SM = SM
        self.dmeteo = dmeteo

        # Save COMMON blocks in Py_w60lib
        self._save_COMMON_blocks()

        # Change state back to 'idle'
        self._runtime_state = 'idle'

        return terminate

    #---------------------------------------------------------------------------
    def get_variable(self, var=None):
        """Retrieve WOFOST state/rate variables.

        Invoke without arguments (in idle state) to print a list of
        variables. Single-element numpy float/int arrays are returned as
        plain Python float/int. A warning is issued when no value can be
        found, in which case None is returned.

        Raises RuntimeError when called without argument in active state.
        """
        # Three possibilities for WOFOST variable retrieval:
        # 1. attr is one of ("rootzone_moisture_content",
        #    "profile_moisture_content", "rooting_depth", "transpiration")
        #    In this case the variables can be retrieved directly from self
        # 2. PyWofost is in 'active' runtime_state. In this case the variables
        #    have to be retrieved from the w60lib COMMON blocks. First the
        #    short FORTRAN name for this variable is retrieved through
        #    get_FORTRAN_name(), then the actual value is retrieved through
        #    _find_w60lib_variable().
        # 3. PyWofost is in 'idle' runtime state. In this case the variables
        #    have to be retrieved from the internal copies of the FORTRAN
        #    COMMON blocks (i.e. self.STATES, self.RATES, self.CROOTD,
        #    self.CWATFD, self.CWATPP)

        # if var is None then print the available variables in case of
        # interactive mode, otherwise raise an error.
        if var is None:
            if self._runtime_state == 'idle':
                w60_var_desc.print_variables()
                return
            errstr = "get_variable() called with no argument!"
            raise RuntimeError(errstr)

        # Map complete attribute name to short FORTRAN name.
        short_var = w60_var_desc.get_FORTRAN_name(var.lower())

        value = None
        if short_var in ("rmc", "pmc", "rd", "tra"):
            # Case 1: stored directly on self (lower or upper case name)
            if hasattr(self, short_var):
                value = getattr(self, short_var)
            elif hasattr(self, short_var.upper()):
                value = getattr(self, short_var.upper())
        elif self._runtime_state == 'active':
            # Case 2: read from the w60lib COMMON blocks
            value = self._find_w60lib_variable(short_var)
        elif self._runtime_state == 'idle':
            # Case 3: read from the internal copies of the COMMON blocks.
            # IWB selects which water balance copy applies (0: CWATPP,
            # 1: CWATFD).
            if short_var in self.STATES:
                value = self.STATES[short_var]
            elif short_var in self.RATES:
                value = self.RATES[short_var]
            elif short_var in self.CROOTD:
                value = self.CROOTD[short_var]
            elif short_var in self.CWATPP and self.IWB == 0:
                value = self.CWATPP[short_var]
            elif short_var in self.CWATFD and self.IWB == 1:
                value = self.CWATFD[short_var]

        if value is None:
            errstr = "None value retrieved for WOFOST variable: '%s'" % var
            warnings.warn(errstr)
            self.logger.warning(errstr)
        elif isinstance(value, np.ndarray) and value.size == 1:
            # Convert 1-element float/int array (of any float/int precision)
            # into a python float/int.
            if np.issubdtype(value.dtype, np.floating):
                value = float(value.ravel()[0])
            elif np.issubdtype(value.dtype, np.integer):
                value = int(value.ravel()[0])

        return value

    #---------------------------------------------------------------------------
    def set_variable(self, var=None, value=None):
        """Update WOFOST state/rate variable.

        Keyword variables:
        var -- name of variable
        value -- value to be assigned (int/float or numpy int/float scalar)

        Invoke without arguments to see a list of supported variables.

        Raises RuntimeError for an unsupported variable, a non-numeric value,
        or when the updated states fail their consistency check.
        """
        supported_vars = ('profile_moisture_content', 'leaf_area_index')

        # Print list of supported variables if no arguments are given
        if var is None:
            print("Supported variables:")
            for variable in supported_vars:
                print("- %s" % variable)
            return

        # Check if variable is supported for updating.
        if var not in supported_vars:
            errstr = "Variable '%s' not supported for update." % var
            self.logger.error(errstr)
            raise RuntimeError(errstr)

        # Check if value is provided and is a number
        if not isinstance(value, (int, float, np.integer, np.floating)):
            errstr = "Value passed to set_variable() not int or float."
            self.logger.error(errstr)
            raise RuntimeError(errstr)

        if var == 'leaf_area_index':
            LAInew = value
            # Get current LAI, Specific Leaf Area and leaf age distribution
            # classes
            LAIold = self.get_variable(var)
            SLA = self.STATES["sla"]
            LV = self.STATES["lv"]

            if LAIold > 0.:
                # Scale the leaf biomass distribution (LV) with the relative
                # change in LAI (LAInew/LAIold).
                LV *= LAInew / LAIold
            else:
                # No leaves yet: put all leaf biomass directly into the
                # youngest leaf age class (LV[0]).
                LV[0] = LAInew / SLA[0]

            # Calculate new LAI by summing LV*SLA
            # Note that this is not entirely correct because stems and
            # pods can contribute to LAI (CROPSI.for line 566). However,
            # SSA and SPA are zero for all current crops in CGMS.
            LAIsum = np.sum(LV * SLA)

            # Check that LAIsum obtained from multiplying leaf weight (LV)
            # with Specific Leaf Area (SLA) equals the LAI update from
            # the EnKF. Explicit check (not assert) so it also runs under -O.
            if abs(LAInew - LAIsum) >= 0.001:
                msg = ("Error in updating leaf area index on ensemble "
                       "member %i: Updated integrated LAI does not equal "
                       "leaf weight (LV) times specific leaf area (SLA)")
                raise RuntimeError(msg % self.ensemble_id)

            # Update WOFOST object
            self.STATES["lv"] = LV
            self.STATES["lasum"] = LAIsum
            self.STATES["lai"] = LAInew

        elif var == 'profile_moisture_content':
            # Note that variable 'WWLOW' contains the amount of soil moisture
            # in the total soil column. However, updating should take place on
            # individual variables describing the soil moisture in the root
            # zone 'W' and the soil moisture in the subsoil 'WLOW'.
            # Note that the variable SM contains the volumetric amount of soil
            # moisture in the root zone only.
            SMnew = value
            # Get current values
            W = self.get_variable("rootzone_water_depth")
            WLOW = self.get_variable("subsoil_water_depth")
            WWLOW = self.get_variable("profile_water_depth")

            # Calculate W and WLOW as fraction of WWLOW
            fr_WLOW = WLOW / WWLOW
            fr_W = W / WWLOW

            # Update total amount of soil water in soil column
            WWLOW_new = SMnew * self.RDM

            # Update W and WLOW based on the fraction of water they contained
            WLOW_new = fr_WLOW * WWLOW_new
            W_new = fr_W * WWLOW_new

            # Make sure that update went properly (explicit check, not assert)
            if abs(WWLOW_new - (WLOW_new + W_new)) >= 0.001:
                msg = ("Error in updating profile moisture content on "
                       "ensemble member %i: Updated moisture content does not "
                       "equal value provided to set_variable().")
                raise RuntimeError(msg % self.ensemble_id)

            # Update state of the WOFOST object
            self.CWATFD["wwlow"] = WWLOW_new
            self.CWATFD["wlow"] = WLOW_new
            self.CWATFD["w"] = W_new
            self.SM = W_new / self.RD

    #---------------------------------------------------------------------------
    def results_to_output_device(self, database=None, outputfile=None,
                                 pad_results=None):
        """Write self.saved_state_variables to an output device.

        The following devices (keywords) are supported:
        database -- SQLAlchemy DB metadata object
        outputfile -- filename *or* file-like object (anything with a
          writelines() method). A filename is (over)written including a
          header; a file object is appended to without a header.

        Additional keywords:
        pad_results -- Date object used to copy the final results up to
          this date. This is useful when the results have to be inserted in a
          database and aggregated to larger spatial regions which differ in
          crop phenology. Padding appends to self.saved_state_variables.

        Raises RuntimeError when writing fails, when the device has an
        unsupported type, or when no device is given.
        """
        # Get a logger for I/O of PyWofost results
        member_id = str(self.ensemble_id).zfill(3)
        logger = logging.getLogger('PyWofost.IO.%s' % member_id)

        if isinstance(pad_results, datetime.date) and \
                self.saved_state_variables:
            final_results = copy.deepcopy(self.saved_state_variables[-1])
            date_at_final = final_results.pop("day")
            if pad_results > date_at_final:
                n_days = (pad_results - date_at_final).days
                for i in range(1, n_days + 1):
                    padded = final_results.copy()
                    padded["day"] = date_at_final + datetime.timedelta(days=i)
                    self.saved_state_variables.append(padded)
            logger.debug("Padded results up to %s", pad_results)

        if outputfile is not None:
            output_fmt = (
                "%(grid_no)10i\t%(crop_no)9i\t%(day)11s\t"
                "%(simulation_mode)9s\t%(ensemble_id)6i\t"
                "%(development_stage)6.3f\t%(leaf_area_index)6.2f\t"
                "%(total_aboveground_biomass)7.1f\t"
                "%(total_weight_storage_organs)7.1f\t"
                "%(total_weight_leaves)7.1f\t"
                "%(total_weight_stems)7.1f\t%(total_weight_roots)7.1f"
                "\t%(profile_moisture_content)7.3f\t"
                "%(rootzone_moisture_content)6.3f\t"
                "%(transpiration)7.2f\t%(rooting_depth)6.1f\n")
            outputlines = [output_fmt % s for s in self.saved_state_variables]

            # If file is a string we assume this is the filename to
            # use. Otherwise we assume that a file object is provided
            # that can be used to write the data to (without a header)
            if isinstance(outputfile, str):
                header_fmt = ("#%9s\t%9s\t%11s\t%9s\t%6s\t%6s\t%6s\t"
                              "%7s\t%7s\t%7s\t%7s\t%7s\t%7s\t%6s\t"
                              "%7s\t%6s\n")
                hdr_line1 = ('Grid No', 'Crop No', 'Day', 'Sim mode',
                             'Sim ID', 'DVS', 'LAI', 'TAGP', 'TWSO', 'TWLV',
                             'TWST', 'TWRT', 'WWLOW', 'SM', 'TRA', 'RD')
                hdr_line2 = ('', '', '', '', '', '-', '-', 'kg/ha', 'kg/ha',
                             'kg/ha', 'kg/ha', 'kg/ha', '-', '-', 'cm/day',
                             'cm')
                # 'with' guarantees the file is closed and, unlike the old
                # try/finally, cannot fail on an unbound handle when open()
                # itself raises.
                try:
                    with open(outputfile, 'w') as fp:
                        fp.writelines([header_fmt % hdr_line1,
                                       header_fmt % hdr_line2])
                        fp.writelines(outputlines)
                except IOError:
                    msg = ("Error opening or writing to output "
                           "file %s!") % outputfile
                    raise RuntimeError(msg)
                logger.debug("Succesfully written PyWofost results "
                             "to output file %s!", outputfile)
            elif hasattr(outputfile, "writelines"):
                name = getattr(outputfile, "name", repr(outputfile))
                try:
                    outputfile.writelines(outputlines)
                except IOError:
                    msg = ("Error appending PyWofost results to output"
                           " file %s!") % name
                    raise RuntimeError(msg)
                logger.debug("Succesfully appended PyWofost results to "
                             "output file %s!", name)
            else:
                msg = ("outputfile in results_to_output_device() not string "
                       "nor file object.")
                raise RuntimeError(msg)

        elif database is not None:
            if not isinstance(database, sqlalchemy.schema.MetaData):
                msg = ("database in results_to_output_device() not SQLAlchemy "
                       "metadata object.")
                raise RuntimeError(msg)
            try:
                table_pyw_output = Table('pywofost_output', database,
                                         autoload=True)
                table_pyw_output.insert().execute(self.saved_state_variables)
                logger.debug("Succesfully written PyWofost results to database")
            except SQLAlchemyError as inst:
                msg = "Error writing PyWofost results to database: %s"
                raise RuntimeError(msg % inst)
        else:
            msg = ("results_to_output_device() called without providing a "
                   "database or outputfile.")
            raise RuntimeError(msg)
