import os
import sys
import time
import datetime

from sqlalchemy import select, create_engine, MetaData, Table, and_
from sqlalchemy.sql import func
from sqlalchemy.exceptions import SQLAlchemyError
import numpy as np
from ..util import is_data_column
ma = np.ma

from ..pyfortran import w60_variable_description as w60_var_desc

#-------------------------------------------------------------------------------
def convert_to_masked_array(x, value=None):
    """Convert an ordinary NumPy array to a masked array.

    Elements of `x` equal to `value` are masked. Because NaN never compares
    equal to anything, a NaN `value` masks the NaN elements of `x` instead.
    When `value` is None nothing is masked, but a masked array (with no
    masked elements) is still returned so callers always get a masked array
    back rather than None.
    """
    if value is None:
        return np.ma.masked_array(x)
    try:
        value_is_nan = bool(np.isnan(value))
    except (TypeError, ValueError):
        # Non-numeric or non-scalar value: fall back to plain equality.
        value_is_nan = False
    if value_is_nan:
        # np.equal(x, nan) is all-False, so test for NaN explicitly.
        return np.ma.masked_where(np.isnan(x), x)
    return np.ma.masked_where(np.equal(x, value), x)

#-------------------------------------------------------------------------------
class PyContainer(object):
    """Bare attribute bag.

    Instances carry no behaviour of their own; callers attach whatever
    attributes they need (for example ``days`` and ``values``) directly.
    """

#-------------------------------------------------------------------------------
class TSPlotError(Exception):
    """Root of the exceptions raised by the PyWofost time-series tools."""

#-------------------------------------------------------------------------------
class TSPlotDBError(TSPlotError):
    """The database/table cannot be used (no connection or no data columns)."""

#-------------------------------------------------------------------------------
class TSPlotNoVarError(TSPlotError):
    """Presumably: a requested variable is not present (not raised here)."""

#-------------------------------------------------------------------------------
class TSPlotInvalidVarError(TSPlotError):
    """Presumably: a requested variable name is invalid (not raised here)."""

#-------------------------------------------------------------------------------
class TSPlotNoDataError(TSPlotError):
    """The selection matched no rows, or too few days to plot."""

#-------------------------------------------------------------------------------
class PyWofostOutputProcessor(object):
    """Class for processing results in the pywofost_output table.

    On construction, the rows matching one (grid_no, year, crop_no,
    simulation_mode) selection are read and reorganised per output variable
    into PyContainer objects with attributes ``days`` and ``values`` and, for
    ensembles, ``mean`` and ``stdev``. Those objects are easy to plot.
    """

    def __init__(self,  dsn=None, grid_no=None, year=None, crop_no=None,
                 simulation_mode=None, tbl=None):
        """Constructor for PyWofostOutputProcessor.

        Keywords:
        * dsn - SQLAlchemy data source name; when given, the
          'pywofost_output' table is reflected from it and `tbl` is ignored.
        * grid_no - grid id of spatial entity
        * year - year as specified in crop_calendar
        * crop_no - crop id of crop type
        * simulation_mode - wlp|pp for selecting water-limited|potential
          mode; it is lower-cased before querying.
        * tbl - an already reflected SQLAlchemy Table, used when dsn is None.

        Raises TSPlotDBError when neither a dsn nor a Table is provided, or
        when the table has no plottable data columns. Raises
        TSPlotNoDataError when the selection matches no rows or only a
        single day.
        """
        # Make DB connection and load output table.
        if dsn is not None:
            engine = create_engine(dsn)
            metadata = MetaData(engine)
            tbl_pw_output = Table('pywofost_output', metadata, autoload=True)
        elif isinstance(tbl, Table):
            tbl_pw_output = tbl
        else:
            raise TSPlotDBError("Cannot connect to database.")

        # Retrieve data characteristics: wofost data columns, ensemble size
        # and the range of days present for this selection.
        self.data_columns = self._get_data_columns(tbl_pw_output)
        self.ensemble_size = self._get_ensemble_size(tbl_pw_output, grid_no,
            year, crop_no, simulation_mode)
        self.min_day, self.max_day = self._get_minmax_day(tbl_pw_output,
            grid_no, year, crop_no, simulation_mode)
        # Retrieve data from pywofost_output table and process it in a form
        # which is easy for plotting.
        self.pywofost_output = self._process_pywofost_output(tbl_pw_output,
            grid_no, year, crop_no, simulation_mode)

    #---------------------------------------------------------------------------
    @staticmethod
    def _selection_criteria(tbl, grid_no, year, crop_no, simulation_mode):
        """Return the WHERE criteria (a list) selecting one grid/crop/year/mode.

        simulation_mode is lower-cased when given; None is passed through
        unchanged instead of crashing on ``None.lower()``.
        """
        if simulation_mode is not None:
            simulation_mode = simulation_mode.lower()
        return [tbl.c.grid_no == grid_no,
                tbl.c.crop_no == crop_no,
                tbl.c.year == year,
                tbl.c.simulation_mode == simulation_mode]

    #---------------------------------------------------------------------------
    @staticmethod
    def _selection_description(grid_no, year, crop_no, simulation_mode):
        """Human-readable description of a selection, for error messages.

        Uses %s so that None (the constructor defaults) can be formatted too.
        """
        return ("crop_no %s, grid_no %s, year %s and simulation mode %s"
                % (crop_no, grid_no, year, simulation_mode))

    #---------------------------------------------------------------------------
    @staticmethod
    def _as_date(value):
        """Return `value` as a date.

        Values that are already datetime.date instances (this includes
        datetime.datetime) are returned unchanged. Anything else is parsed
        as a 'YYYY-MM-DD' string, presumably for DB backends that return
        date columns as text.
        """
        if isinstance(value, datetime.date):
            return value
        return datetime.datetime.strptime(value, '%Y-%m-%d').date()

    #---------------------------------------------------------------------------
    def _get_data_columns(self, tbl):
        """Return the columns of `tbl` that contain pywofost_output data.

        Raises TSPlotDBError when no column qualifies.
        """
        data_columns = [column for column in tbl.columns
                        if is_data_column(column)]
        if not data_columns:
            raise TSPlotDBError("DB contains no columns that contain PyWofost "
                                "output that can be plotted.")
        return data_columns

    #---------------------------------------------------------------------------
    def _get_ensemble_size(self, tbl, grid_no, year,
                            crop_no, simulation_mode):
        """Derive the size of the ensemble (1 for a single run).

        Computed as max(ensemble_id) + 1. This assumes ensemble ids start at
        0 and are contiguous (TODO confirm with the writer of the table).
        Raises TSPlotNoDataError when the selection matches no rows.
        """
        criteria = self._selection_criteria(tbl, grid_no, year, crop_no,
                                            simulation_mode)
        s = select([func.max(tbl.c.ensemble_id)], and_(*criteria))
        max_id = s.execute().fetchone()[0]
        if max_id is None:
            raise TSPlotNoDataError("No data found for " +
                self._selection_description(grid_no, year, crop_no,
                                            simulation_mode))
        return max_id + 1

    #---------------------------------------------------------------------------
    def _get_minmax_day(self, tbl, grid_no, year,
                         crop_no, simulation_mode):
        """Return a (min_day, max_day) tuple of dates for the selection.

        Raises TSPlotNoDataError when there are no rows, or when the
        selection covers only a single day.
        """
        criteria = self._selection_criteria(tbl, grid_no, year, crop_no,
                                            simulation_mode)
        s = select([func.min(tbl.c.day), func.max(tbl.c.day)], and_(*criteria))
        row = s.execute().fetchone()
        description = self._selection_description(grid_no, year, crop_no,
                                                  simulation_mode)
        if row is None or row[0] is None or row[1] is None:
            raise TSPlotNoDataError("No data found for " + description)

        minday = self._as_date(row[0])
        maxday = self._as_date(row[1])
        if minday == maxday:
            raise TSPlotNoDataError("DB contains only one day for " +
                                    description)
        return (minday, maxday)

    #---------------------------------------------------------------------------
    def _process_pywofost_output(self, tbl, grid_no, year, crop_no,
                                  simulation_mode):
        """Process data in the pywofost_output table for plotting purposes.

        Returns a dict that maps the full variable name (via
        w60_var_desc.get_full_name) to a PyContainer with these attributes:
        * days   - list of dates from self.min_day to self.max_day
        * values - float32 array with shape (nr_days,) for a single run, or
                   (nr_days, ensemble_size) for an ensemble; NaN marks
                   missing values
        * mean, stdev - only when self.ensemble_size > 1: per-day statistics
                   over the non-NaN members. NumPy's std normalises by N,
                   not N-1 (see the NumPy docs for ndarray.var()).
        """
        ensemble_size = self.ensemble_size
        data_columns = self.data_columns
        nr_days = (self.max_day - self.min_day).days + 1

        # (day, column, ensemble member) cube. It is initialised with NaN so
        # that days/members missing from the DB can be masked later on.
        ta = np.empty((nr_days, len(data_columns), ensemble_size),
                      dtype=np.float32)
        ta.fill(np.nan)

        base_criteria = self._selection_criteria(tbl, grid_no, year, crop_no,
                                                 simulation_mode)

        # Fetch one ensemble member at a time, printing a progress dot each.
        t1 = time.time()
        sys.stdout.write("Processing ensemble members ")
        for ensemble_id in range(ensemble_size):
            criteria = base_criteria + [tbl.c.ensemble_id == ensemble_id]
            s = select([tbl], and_(*criteria), order_by=[tbl.c.day])
            rows = s.execute().fetchall()

            for row in rows:
                # Index along the first axis is the offset from the first
                # day. Convert the same way as min/max day so that text
                # dates also work.
                day_pos = (self._as_date(row[tbl.c.day]) - self.min_day).days
                for j, column in enumerate(data_columns):
                    value = row[column]
                    # SQL NULL becomes NaN (missing) in the cube.
                    ta[day_pos, j, ensemble_id] = (np.nan if value is None
                                                   else value)
            sys.stdout.write(".")
            sys.stdout.flush()
        sys.stdout.write(" elapsed: %6.2f seconds.\n" % (time.time() - t1))

        # Build one PyContainer per data column, sharing the list of days.
        days = self._generate_dayrange(self.min_day, self.max_day)
        pywofost_output = {}
        for j, column in enumerate(data_columns):
            t = PyContainer()
            t.days = days
            if ensemble_size > 1:
                x = ta[:, j, :]
                # Mask the NaN gaps so they do not poison mean/stdev.
                xma = np.ma.masked_array(x, mask=np.isnan(x))
                t.values = x
                t.mean = xma.mean(axis=1)
                t.stdev = xma.std(axis=1)
            else:
                t.values = ta[:, j, 0]

            varname = w60_var_desc.get_full_name(column.name)
            pywofost_output[varname] = t

        return pywofost_output

    #---------------------------------------------------------------------------
    def _generate_dayrange(self, min_day, max_day):
        "Generate list of days from min_day to max_day (both inclusive)"
        nr_days = (max_day - min_day).days + 1
        return [(min_day + datetime.timedelta(days=i)) for i in range(nr_days)]

    #---------------------------------------------------------------------------
    def get_data_column_names(self):
        """Return, as a list, the variable names present in the PyWofost DB.

        The names are the keys of self.pywofost_output, i.e. the full names
        produced by w60_var_desc.get_full_name.
        """
        return list(self.pywofost_output.keys())

    #---------------------------------------------------------------------------
    def get_variable_data(self, varname):
        """Return the PyContainer holding the time series for `varname`.

        `varname` is first mapped through w60_var_desc.get_full_name. A
        KeyError is raised when the variable is not present.
        """
        fullvarname = w60_var_desc.get_full_name(varname)
        return self.pywofost_output[fullvarname]
