# Copyright (C) 2011-2015 David Maxwell and Constantine Khroulev
#
# This file is part of PISM.
#
# PISM is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation; either version 3 of the License, or (at your option) any later
# version.
#
# PISM is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License
# along with PISM; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

"""Functions and objects relating to interaction with :cpp:class:`IceModelVec`'s from python."""

import PISM


class Access(object):

    """
    Python context manager to simplify `IceModelVec` access and ghost communication.

    In PISM C++ code, one uses :cpp:member:`IceModelVec::begin_access`/:cpp:member:`IceModelVec::end_access`
    access pairs to delimit a code block allowing access to the contents of an :cpp:class:`IceModelVec`.
    If the contents of a ghosted vector were changed in the block, :cpp:member:`IceModelVec::update_ghosts` needs to
    be called to synchronize the ghosts.  Forgetting either an :cpp:member:`end_access` or an :cpp:member:`update_ghosts`
    leads to bugs.

    Python context managers are used in conjunction with ``with`` statements to execute code at the start
    and end of a code block.  A :class:`PISM.vec.Access` context manager is used to pair up
    :cpp:member:`begin_access`/:cpp:member:`end_access` and to call :cpp:member:`update_ghosts` if needed.
    Assuming that ``v1`` and ``v2`` are vectors::

      grid = v1.get_grid()
      with PISM.vec.Access(comm=v2, nocomm=v1):
        for (i, j) in grid.points():
          v2[i, j] = v1[i, j]**3

    :cpp:member:`begin_access` is called for both ``v1`` and ``v2`` when the
    :class:`Access` object is constructed (i.e. when the ``with`` statement is
    evaluated).  If it fails for one of the vectors, the vectors that were
    already accessed are released again before the error propagates.
    On exit from the ``with`` block, :cpp:member:`end_access` is called for both ``v1`` and ``v2``,
    and :cpp:member:`update_ghosts` is called (after :cpp:member:`end_access`) for just ``v2``."""

    @staticmethod
    def _as_list(vecs):
        """Normalize ``None``, a single vector, or a list/tuple of vectors.

        :returns: ``None`` if ``vecs`` is ``None``, ``vecs`` itself if it is a
                  list or tuple, otherwise a one-element list ``[vecs]``."""
        if vecs is None:
            return None
        if isinstance(vecs, (list, tuple)):
            return vecs
        return [vecs]

    def __init__(self, nocomm=None, comm=None):
        """
        Call ``begin_access()`` on every given vector (``nocomm`` vectors first).

        :param nocomm: a vector or list of vectors to access such that
                       ghost communication *will not* occur when access is done.

        :param comm:   a vector or list of vectors to access such that
                       ghost communication *will* occur when access is done.

        If ``begin_access()`` raises for some vector, ``end_access()`` is called
        on the vectors already accessed (in reverse order) and the exception
        is re-raised.
        """
        self.nocomm = self._as_list(nocomm)
        self.comm = self._as_list(comm)

        accessed = []
        try:
            for group in (self.nocomm, self.comm):
                for v in group or ():
                    v.begin_access()
                    accessed.append(v)
        except BaseException:
            # Undo partial work so a failed construction does not leave
            # vectors stuck in "access" mode; then propagate the error.
            for v in reversed(accessed):
                v.end_access()
            self.nocomm = None
            self.comm = None
            raise

    def __enter__(self):
        # Access was already begun in __init__; returning self allows
        # ``with Access(...) as acc:``.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """End access to all vectors; update ghosts of the ``comm`` vectors.

        Resets the stored vector lists to ``None`` so that a second call is a
        no-op.  Returns ``None``, so exceptions raised in the ``with`` block
        propagate."""
        if self.nocomm is not None:
            for v in self.nocomm:
                v.end_access()
            self.nocomm = None

        if self.comm is not None:
            for v in self.comm:
                # Ghost communication happens only after access has ended.
                v.end_access()
                v.update_ghosts()
            self.comm = None


class ToProcZero(object):

    """Utility class for managing communication of :cpp:class:`IceModelVec`\'s to processor 0
    and converting them to numpy vectors (e.g. for plotting or otherwise viewing).  Typical use
    is to construct a :class:`ToProcZero` once to setup a communicator for a particular :cpp:class:`IceGrid`
    and :cpp:class:`IceModelVec` type, and then repeatedly use  :meth:`ToProcZero.communicate` as
    needed."""

    def __init__(self, grid, dof=1, dim=2):
        """

        :param grid: the :cpp:class:`IceGrid` to be shared by all :cpp:class:`IceModelVec`\'s
        :param dof:  the number of degrees of freedom for the :cpp:class:`IceModelVec`\'s
                     (e.g. 1 for scalar valued Vecs, 2 for horizontal velocity Vecs)
        :param dim:  Use 2 for :cpp:class:`IceModelVec2` types and 3 for :cpp:class:`IceModelVec3`
        :raises NotImplementedError: if ``dim`` is not 2."""
        self.grid = grid
        self.dof = dof
        self.dim = dim

        # Set every PETSc handle to None up front so that __del__ is safe even
        # if construction fails part-way.
        self.tmp_U = None
        self.tmp_U_natural = None
        self.scatter = None
        self.U0 = None

        if dim != 2:
            raise NotImplementedError()

        self.da = grid.get_dm(dof, 0)

        dm = self.da.get()
        self.tmp_U = dm.createGlobalVector()
        self.tmp_U_natural = dm.createNaturalVector()
        # Scatter.toZero creates both the scatter and the rank-0 target
        # vector; this object owns (and destroys) both.
        self.scatter, self.U0 = PISM.PETSc.Scatter.toZero(self.tmp_U_natural)

    def __del__(self):
        # Release every PETSc object created in __init__.  getattr() guards
        # against partially-initialized instances.
        for name in ('scatter', 'U0', 'tmp_U_natural', 'tmp_U'):
            obj = getattr(self, name, None)
            if obj is not None:
                obj.destroy()

    def communicate(self, u):
        """Communicates an :cpp:class:`IceModelVec` to processor zero.

        This is a collective operation: every rank must call it.

        :param u: the :cpp:class:`IceModelVec` to communicate
        :returns: On processor 0, a numpy array with contents communicated from `u`,
                  with shape equal to the DM's global sizes when ``dof == 1`` and
                  ``(dof,) + sizes`` otherwise.
                  Otherwise returns ``None``."""
        dm = self.da.get()
        comm = dm.getComm()
        rank = comm.getRank()

        # Copy u into the global vector, convert it to natural ordering, and
        # gather everything onto rank 0 (INSERT, FORWARD).
        u.copy_to_vec(self.da, self.tmp_U)
        dm.globalToNatural(self.tmp_U, self.tmp_U_natural)
        self.scatter.scatter(self.tmp_U_natural, self.U0, False, PISM.PETSc.Scatter.Mode.FORWARD)

        rv = None
        if rank == 0:
            sizes = tuple(dm.sizes)
            # Components are interleaved and fastest-varying in the natural
            # ordering, hence a leading dof axis and Fortran order.
            if self.dof == 1:
                shape = sizes
            else:
                shape = (self.dof,) + sizes
            rv = self.U0[...].reshape(shape, order='f').copy()

        comm.barrier()

        return rv


def randVectorS(grid, scale, stencil_width=None):
    """Return a new :cpp:class:`IceModelVec2S` filled with normally distributed random values.

      :param grid:  The :cpp:class:`IceGrid` on which the vector is created.
      :param scale: Standard deviation of the normal distribution.
      :param stencil_width: Ghost stencil width for the vector; pass ``None`` for
                            an unghosted vector.

    Values are assigned one grid point at a time from python, so this is slow.
    """
    ghosted = stencil_width is not None

    vec = PISM.IceModelVec2S()
    if ghosted:
        vec.create(grid, 'rand vec', PISM.WITH_GHOSTS, stencil_width)
    else:
        vec.create(grid, 'rand vec', PISM.WITHOUT_GHOSTS)

    import numpy as np

    # One sample per locally owned grid point, indexed relative to (xs, ys).
    samples = np.random.normal(scale=scale, size=(grid.xm(), grid.ym()))
    x0, y0 = grid.xs(), grid.ys()
    with Access(nocomm=vec):
        for (i, j) in grid.points():
            vec[i, j] = samples[i - x0, j - y0]

    # Ghosts are synchronized after access has ended.
    if ghosted:
        vec.update_ghosts()
    return vec


def randVectorV(grid, scale, stencil_width=None):
    """Create an :cpp:class:`IceModelVec2V` of normally distributed random entries.

      :param grid:  The :cpp:class:`IceGrid` to use for creating the vector.
      :param scale: Standard deviation of normal distribution (used for both
                    the ``u`` and ``v`` components).
      :param stencil_width: Ghost stencil width for the vector. Use ``None`` to indicate
                            an unghosted vector.

    This function is not efficiently implemented.
    """

    rv = PISM.IceModelVec2V()
    if stencil_width is None:
        rv.create(grid, 'rand vec', PISM.WITHOUT_GHOSTS)
    else:
        rv.create(grid, 'rand vec', PISM.WITH_GHOSTS, stencil_width)

    shape = (grid.xm(), grid.ym())
    import numpy as np
    # Draw u samples before v samples (same order as before), so a given
    # numpy random state produces the same vector.
    r_u = np.random.normal(scale=scale, size=shape)
    r_v = np.random.normal(scale=scale, size=shape)
    xs, ys = grid.xs(), grid.ys()
    with Access(nocomm=rv):
        for (i, j) in grid.points():
            rv[i, j].u = r_u[i - xs, j - ys]
            rv[i, j].v = r_v[i - xs, j - ys]
    # Synchronize ghosts only after end_access(), i.e. outside the ``with``
    # block, matching randVectorS and Access(comm=...).
    if stencil_width is not None:
        rv.update_ghosts()
    return rv
