# Copyright (C) 2012, 2014, 2015, 2016 David Maxwell
#
# This file is part of PISM.
#
# PISM is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation; either version 3 of the License, or (at your option) any later
# version.
#
# PISM is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License
# along with PISM; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

"""Tikhonov Gauss-Newton inverse method code."""

import PISM
from PISM.invert.ssa import InvSSASolver

class InvSSASolver_TikhonovGN(InvSSASolver):
    """Tikhonov Gauss-Newton inverse solver.

    Wraps ``PISM.IP_SSATaucTikhonovGNSolver``.  Requires the command-line
    option ``-inv_target_misfit`` (m/year), the desired root misfit of the
    inversion.
    """

    def __init__(self, ssarun, method):
        """Read and validate the target misfit option.

        :param ssarun: SSA run object; this class uses ``ssarun.ssa``,
            ``ssarun.grid`` and ``ssarun.modeldata``.
        :param method: inversion method, forwarded to ``InvSSASolver``.
        :raises RuntimeError: if ``-inv_target_misfit`` is not given.
        """
        self.solver = None
        InvSSASolver.__init__(self, ssarun, method)

        # The target misfit is a real number, so read it with optionsReal;
        # optionsFlag is for boolean switches (cf. -inv_test_adjoint below).
        self.target_misfit = PISM.optionsReal("-inv_target_misfit",
                                              "m/year; desired root misfit for inversions",
                                              default=None)

        if self.target_misfit is None:
            raise RuntimeError("Missing required option -inv_target_misfit")

        # The misfit is kept in the units given on the command line.  It is
        # divided by the configured velocity scale in solveInverse(); doing
        # it here as well would scale it twice.

        self.listeners = []

    def solveForward(self, zeta, out=None):
        """Solve the forward SSA problem at design variable ``zeta``.

        :param zeta: design variable at which the SSA is linearized.
        :param out: optional vector that receives a copy of the solution.
        :returns: ``out`` when given; otherwise ``ssa.solution()`` itself
            (not a copy).
        :raises PISM.AlgorithmFailureException: if linearization fails.
        """
        ssa = self.ssarun.ssa

        reason = ssa.linearize_at(zeta)
        if reason.failed():
            raise PISM.AlgorithmFailureException(reason.description())
        if out is not None:
            out.copy_from(ssa.solution())
        else:
            out = ssa.solution()
        return out

    def solveInverse(self, zeta0, u_obs, zeta_inv):
        """Run the Tikhonov Gauss-Newton inversion.

        :param zeta0: design vector passed to ``IP_SSATaucTikhonovGNSolver``
            (presumably the a-priori/reference design -- confirm).
        :param u_obs: observed velocities.
        :param zeta_inv: initial guess for the design variable.
        :returns: whatever ``self.solver.solve()`` returns (presumably a
            convergence reason -- confirm).

        If ``-inv_test_adjoint`` is set, a symmetry test of the Gauss-Newton
        operator is run and the process exits instead of solving.
        """
        eta = self.config.get_double("inverse.tikhonov.penalty_weight")

        design_functional, state_functional = \
            PISM.invert.ssa.createTikhonovFunctionals(self.ssarun)
        self.solver = PISM.IP_SSATaucTikhonovGNSolver(self.ssarun.ssa, zeta0, u_obs, eta,
                                                      design_functional, state_functional)

        # Nondimensionalize the target misfit by the configured velocity
        # scale (assumed to be in the option's units, m/year -- confirm).
        vel_scale = self.ssarun.grid.ctx().config().get_double("inverse.ssa.velocity_scale")
        self.solver.setTargetMisfit(self.target_misfit / vel_scale)

        # TODO: self.listeners is not forwarded to the Gauss-Newton solver;
        # each listener would need an adaptor before self.solver.addListener().
        self.solver.setInitialGuess(zeta_inv)

        if PISM.optionsFlag("-inv_test_adjoint", ""):
            self._testAdjoint()  # does not return

        vecs = self.ssarun.modeldata.vecs
        if vecs.has('zeta_fixed_mask'):
            self.ssarun.ssa.set_tauc_fixed_locations(vecs.zeta_fixed_mask)

        return self.solver.solve()

    def _testAdjoint(self):
        """Log <GN d1, d2> and <GN d2, d1> for random d1, d2, then exit.

        The two inner products agree when ``apply_GN`` is symmetric.
        Terminates the process with status 0.
        """
        import sys

        self.solver.init()
        grid = self.ssarun.grid
        d1 = PISM.vec.randVectorS(grid, 1)
        d2 = PISM.vec.randVectorS(grid, 1)
        y1 = PISM.IceModelVec2S()
        y1.create(grid, '', PISM.WITHOUT_GHOSTS)
        y2 = PISM.IceModelVec2S()
        y2.create(grid, '', PISM.WITHOUT_GHOSTS)
        self.solver.apply_GN(d1, y1)
        self.solver.apply_GN(d2, y2)
        ip1 = y1.get_vec().dot(d2.get_vec())
        ip2 = y2.get_vec().dot(d1.get_vec())
        PISM.logging.logMessage("ip1 %.10g ip2 %.10g\n" % (ip1, ip2))
        sys.exit(0)

    def inverseSolution(self):
        """Return ``(zeta, u)``: the design and state solutions of the inversion.

        :raises RuntimeError: if ``solveInverse()`` has not been called yet.
        """
        if self.solver is None:
            raise RuntimeError("solveInverse() must be called before inverseSolution()")
        zeta = self.solver.designSolution()
        u = self.solver.stateSolution()
        return (zeta, u)
