# Copyright (C) 2012, 2014, 2015, 2016 David Maxwell and Constantine Khroulev
#
# This file is part of PISM.
#
# PISM is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation; either version 3 of the License, or (at your option) any later
# version.
#
# PISM is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License
# along with PISM; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

"""Inverse SSA solvers using the TAO library."""

import PISM
from PISM.util import Bunch
from PISM.logging import logError
from PISM.invert.ssa import InvSSASolver

import sys
import traceback


class InvSSASolver_Tikhonov(InvSSASolver):

    """Inverse SSA solver based on Tikhonov iteration using TAO."""

    # Dictionary converting PISM algorithm names to the corresponding
    # TAO algorithms used to implement the Tikhonov minimization.
    # PETSc versions older than 3.5 use "tao_"-prefixed TAO type names.  When
    # the module is imported from Sphinx the PETSc version is not queried and
    # the unprefixed names are used.
    if (not PISM.imported_from_sphinx) and PISM.PETSc.Sys.getVersion() < (3, 5, 0):
        tao_types = {'tikhonov_lmvm': 'tao_lmvm',
                     'tikhonov_cg': 'tao_cg',
                     'tikhonov_lcl': 'tao_lcl',
                     'tikhonov_blmvm': 'tao_blmvm'}
    else:
        tao_types = {'tikhonov_lmvm': 'lmvm',
                     'tikhonov_cg': 'cg',
                     'tikhonov_lcl': 'lcl',
                     'tikhonov_blmvm': 'blmvm'}

    def __init__(self, ssarun, method):
        """
        :param ssarun: The :class:`PISM.invert.ssa.SSAForwardRun` defining the forward problem.
        :param method: String describing the actual algorithm to use. Must be a key in :attr:`tao_types`.
        :raises ValueError: if ``method`` is not a key in :attr:`tao_types`.
        """
        InvSSASolver.__init__(self, ssarun, method)
        # Python-side listeners; each is wrapped in a C++ listener adaptor
        # when solveInverse() builds the inverse problem.
        self.listeners = []
        # Both are created by solveInverse().
        self.solver = None
        self.ip = None
        if method not in self.tao_types:
            raise ValueError("Unknown TAO Tikhonov inversion method: %s" % method)

    def addIterationListener(self, listener):
        """Add a listener to be called after each iteration.  See :ref:`Listeners`."""
        self.listeners.append(listener)

    def addDesignUpdateListener(self, listener):
        """Add a listener to be called after each time the design variable is changed."""
        # NOTE(review): design-update listeners share the iteration listener
        # list, so they are wrapped and invoked exactly like iteration
        # listeners (once per TAO iteration, same call signature).
        self.listeners.append(listener)

    def solveForward(self, zeta, out=None):
        r"""Given a parameterized design variable value :math:`\zeta`, solve the SSA.
        See :cpp:class:`IP_TaucParam` for a discussion of parameterizations.

        :param zeta: :cpp:class:`IceModelVec` containing :math:`\zeta`.
        :param out: optional :cpp:class:`IceModelVec` for storage of the computation result.
        :returns: An :cpp:class:`IceModelVec` containing the computation result.
        :raises PISM.AlgorithmFailureException: if linearizing the SSA at ``zeta`` fails.
        """
        ssa = self.ssarun.ssa

        reason = ssa.linearize_at(zeta)
        if reason.failed():
            raise PISM.AlgorithmFailureException(reason)

        if out is None:
            # No storage supplied: hand back the SSA's own solution vector.
            return ssa.solution()

        out.copy_from(ssa.solution())
        return out

    def solveInverse(self, zeta0, u_obs, zeta_inv):
        r"""Executes the inversion algorithm.

        :param zeta0: The best `a-priori` guess for the value of the parameterized design variable :math:`\zeta`.
        :param u_obs: :cpp:class:`IceModelVec2V` of observed surface velocities.
        :param zeta_inv: :cpp:class:`IceModelVec` holding the starting value of :math:`\zeta`
                         for minimization of the Tikhonov functional.
        :returns: A :cpp:class:`TerminationReason`.
        :raises RuntimeError: if the forward run's design variable is neither ``'tauc'`` nor ``'hardav'``.
        """
        eta = self.config.get_double("inverse.tikhonov.penalty_weight")

        # Pick the C++ problem/solver classes and the matching python listener
        # adaptor.  Kept as an if/elif chain so only the classes actually
        # needed are looked up.
        design_var = self.ssarun.designVariable()
        if design_var == 'tauc':
            if self.method == 'tikhonov_lcl':
                problemClass = PISM.IP_SSATaucTaoTikhonovProblemLCL
                solverClass = PISM.IP_SSATaucTaoTikhonovProblemLCLSolver
                listenerClass = TaucLCLIterationListenerAdaptor
            else:
                problemClass = PISM.IP_SSATaucTaoTikhonovProblem
                solverClass = PISM.IP_SSATaucTaoTikhonovSolver
                listenerClass = TaucIterationListenerAdaptor
        elif design_var == 'hardav':
            if self.method == 'tikhonov_lcl':
                problemClass = PISM.IP_SSAHardavTaoTikhonovProblemLCL
                solverClass = PISM.IP_SSAHardavTaoTikhonovSolverLCL
                listenerClass = HardavLCLIterationListenerAdaptor
            else:
                problemClass = PISM.IP_SSAHardavTaoTikhonovProblem
                solverClass = PISM.IP_SSAHardavTaoTikhonovSolver
                listenerClass = HardavIterationListenerAdaptor
        else:
            raise RuntimeError("Unsupported design variable '%s' for InvSSASolver_Tikhonov. Expected 'tauc' or 'hardav'" % design_var)

        tao_type = self.tao_types[self.method]
        (stateFunctional, designFunctional) = PISM.invert.ssa.createTikhonovFunctionals(self.ssarun)

        self.ip = problemClass(self.ssarun.ssa, zeta0, u_obs, eta, stateFunctional, designFunctional)
        self.solver = solverClass(self.ssarun.grid.com, tao_type, self.ip)

        max_it = PISM.optionsInt("-inv_max_it", "maximum iteration count", default=1000)
        self.solver.setMaximumIterations(max_it)

        # Wrap every python listener in a C++-compatible adaptor.  The local
        # list keeps the adaptors referenced while solve() runs.
        pl = [listenerClass(self, l) for l in self.listeners]

        for l in pl:
            self.ip.addListener(l)

        self.ip.setInitialGuess(zeta_inv)

        vecs = self.ssarun.modeldata.vecs
        if vecs.has('zeta_fixed_mask'):
            # NOTE(review): set_tauc_fixed_locations is called for both design
            # variables ('tauc' and 'hardav') -- confirm the hardav forward
            # problem provides/intends this method.
            self.ssarun.ssa.set_tauc_fixed_locations(vecs.zeta_fixed_mask)

        return self.solver.solve()

    def inverseSolution(self):
        """Returns a tuple ``(zeta,u)`` of :cpp:class:`IceModelVec`'s corresponding to the values
        of the design and state variables at the end of inversion.

        Must be called after :meth:`solveInverse`; until then ``self.ip`` is ``None``.
        """
        zeta = self.ip.designSolution()
        u = self.ip.stateSolution()
        return (zeta, u)


class TaucLCLIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemLCLListener):

    """Adaptor converting calls to a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemLCLListener`
    on to a standard python-based listener.  Used internally by
    :class:`InvSSASolver_Tikhonov`.  I.e. don't make one of these for yourself."""

    def __init__(self, owner, listener):
        """:param owner: The :class:`InvSSASolver_Tikhonov` that constructed us.
           :param listener: The python-based listener, called as ``listener(owner, it, data)``.
        """
        PISM.IP_SSATaucTaoTikhonovProblemLCLListener.__init__(self)
        self.owner = owner
        self.listener = listener

    def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, constraints):
        """Called during IP_SSATaucTaoTikhonovProblemLCL iterations.  Gathers together the long list of arguments
        into a dictionary and passes it along in standard form to the python listener.

        The arguments are repackaged into a :class:`PISM.util.Bunch` with keys
        ``tikhonov_penalty`` (``eta``), ``JDesign`` (``objVal``), ``JState`` (``penaltyVal``),
        ``zeta`` (``d``), ``zeta_step`` (``diff_d``), ``grad_JDesign`` (``grad_d``),
        ``u``, ``residual`` (``diff_u``), ``grad_JState`` (``grad_u``) and ``constraints``.
        ``problem`` is not forwarded.

        Any exception raised by the listener is logged, its traceback is
        printed to stdout, and it is then re-raised.
        """

        data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
                     zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
                     u=u, residual=diff_u, grad_JState=grad_u,
                     constraints=constraints)
        try:
            self.listener(self.owner, it, data)
        except Exception:
            logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
            traceback.print_exc(file=sys.stdout)
            raise


class TaucIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemListener):

    """Bridge from the C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemListener`
    callback interface to a plain python listener.  Instances are created
    internally by :class:`InvSSASolver_Tikhonov`; client code should not
    construct them directly."""

    def __init__(self, owner, listener):
        """:param owner: The :class:`InvSSASolver_Tikhonov` that created this adaptor.
           :param listener: Python callable invoked as ``listener(owner, it, data)``.
        """
        PISM.IP_SSATaucTaoTikhonovProblemListener.__init__(self)
        self.owner, self.listener = owner, listener

    def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
        """Forward one IP_SSATaucTaoTikhonovProblem iteration to the python listener.

        The positional callback arguments (all except ``problem`` and ``it``)
        are bundled into a :class:`PISM.util.Bunch` which is handed to the
        listener together with the owner and the iteration number.  A failing
        listener has its error logged and traceback printed to stdout before
        the exception propagates.
        """
        fields = dict(tikhonov_penalty=eta,
                      JDesign=objVal,
                      JState=penaltyVal,
                      zeta=d,
                      zeta_step=diff_d,
                      grad_JDesign=grad_d,
                      u=u,
                      residual=diff_u,
                      grad_JState=grad_u,
                      grad_JTikhonov=grad)
        data = Bunch(**fields)
        try:
            self.listener(self.owner, it, data)
        except Exception:
            logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
            traceback.print_exc(file=sys.stdout)
            raise


class HardavIterationListenerAdaptor(PISM.IP_SSAHardavTaoTikhonovProblemListener):

    """Adaptor converting calls to a C++ :cpp:class:`IP_SSAHardavTaoTikhonovProblemListener`
    on to a standard python-based listener.  Used internally by
    :class:`InvSSASolver_Tikhonov`.  I.e. don't make one of these for yourself."""

    def __init__(self, owner, listener):
        """:param owner: The :class:`InvSSASolver_Tikhonov` that constructed us.
           :param listener: The python-based listener, called as ``listener(owner, it, data)``.
        """
        PISM.IP_SSAHardavTaoTikhonovProblemListener.__init__(self)
        self.owner = owner
        self.listener = listener

    def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
        """Called during IP_SSAHardavTaoTikhonovProblem iterations.  Gathers together the long list of arguments
        into a dictionary and passes it along in a standard form to the python listener.

        The arguments are repackaged into a :class:`PISM.util.Bunch` with keys
        ``tikhonov_penalty`` (``eta``), ``JDesign`` (``objVal``), ``JState`` (``penaltyVal``),
        ``zeta`` (``d``), ``zeta_step`` (``diff_d``), ``grad_JDesign`` (``grad_d``),
        ``u``, ``residual`` (``diff_u``), ``grad_JState`` (``grad_u``) and
        ``grad_JTikhonov`` (``grad``).  ``problem`` is not forwarded.

        Any exception raised by the listener is logged, its traceback is
        printed to stdout, and it is then re-raised.
        """
        data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
                     zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
                     u=u, residual=diff_u, grad_JState=grad_u, grad_JTikhonov=grad)
        try:
            self.listener(self.owner, it, data)
        except Exception:
            logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
            traceback.print_exc(file=sys.stdout)
            raise
