import numpy as np
import netCDF4
from pytmosph3r.model.hdf5model import HDF5Model
from pytmosph3r.grid import Grid3D
from pytmosph3r.atmosphere import InputAtmosphere
from pytmosph3r.emission import Emission
from pytmosph3r.planet import Planet
from pytmosph3r.util.util import arrays_to_zeros, aerosols_array_iterator, convert_log
import exo_k as xk

def get_longitudes(var):
    """Drop the last point along the last axis of :attr:`var`.

    In this file the last axis is the longitude axis, whose final point is a
    duplicate of the first one. Inputs with a single point (or none) along
    that axis are returned untouched.
    """
    n_last = var.shape[-1]
    return var if n_last <= 1 else var[..., :-1]

class DiagfiModel(HDF5Model):
    """Model reading from a diagfi.nc file, from a LMDZ GCM for example.
    Same parameters as :class:`~pytmosph3r.model.model.Model`, plus the ones listed here.
    See :class:`~pytmosph3r.interface.inputdata.DiagfiData` for the structure of the diagfi file.

    Args:
        t (int) : Time slice to read from diagfi input file.
        gas_units (str) : Used to indicate what units are in the input file. For example "log_vmr" indicates it is a log of a VMR. Possible units are  "vmr", "mmr", "log_vmr", "log_mmr","ln_vmr", "ln_mmr".
        aerosols_units (str) : Used to indicate what units are in the input file. For example "log_mmr" indicates it is a log of a MMR. Possible units are  "mmr", "log_mmr", "ln_mmr".
    """
    def __init__(self, t=-1, gas_units=None, aerosols_units=None, *args, **kwargs):
        self.t = int(t)
        """Time slice to read from diagfi."""
        self.gas_units = gas_units
        """Units of gases in input file."""
        self.aerosols_units = aerosols_units
        """Units of aerosols in input file."""

        HDF5Model.__init__(self, *args, **kwargs)

    def inputs(self):
        """Return the parent's list of inputs extended with the diagfi-specific ones."""
        return super().inputs() + ["t", "gas_units", "aerosols_units"]

    def read_data(self):
        r"""Read data from a diagfi.nc file generated by the LMDZ GCM.
        The netCDF file should have at least a temperature `temp` and a pressure either defined by `p` on all grid points or by `ps` the surface pressure and coefficients `ap`, `bp` and `aps`, `bps`.
        For the gas and aerosols descriptions, see :any:`library_input`.
        The `controle` variable should (preferrably) list:

        #. n_{longitudes} (last longitude is a duplicate from the first one)

        #. n_{latitudes}

        #. n_{vertical} (optional)

        #. R_{p} (planet radius) in `m`

        #. #NOT USED#

        #. g_{0} (surface gravity)

        #. :math:`\mu` (molar mass) in `g/mol`

        Note that the code reads the radius from ``controle[4]``, the surface
        gravity from ``controle[6]`` and the molar mass from ``controle[7]``
        (zero-based indices).

        The result is stored in :attr:`planet`, :attr:`input_atmosphere` and :attr:`surface`.

        Raises:
            KeyError: if the file defines neither `p` nor a usable `ps`/`ap`/`bp`/`aps`/`bps` set.
        """
        self.info("Reading model from %s"% self.filename)
        # The dataset is closed when leaving the block, whether reading succeeded or not.
        # Everything kept afterwards has been sliced (i.e. loaded in memory) beforehand.
        with netCDF4.Dataset(self.filename) as f:
            self._read_diagfi(f)

    def _read_diagfi(self, f):
        """Build :attr:`planet`, :attr:`input_atmosphere` and :attr:`surface` from the open dataset :attr:`f`."""
        # `controle` has to be bound even when the file lacks it: it is handed to read_gases() below.
        controle = None
        try:
            controle = f.variables['controle']
            # NOTE(review): int() drops any fractional part of the radius and gravity - confirm this is intended.
            Rp = int(controle[4])
            g0 = int(controle[6])
            self.planet = Planet(surface_gravity=g0, radius=Rp)
        except Exception:
            self.warning("Could not find planet information in 'controle'.")
            self.planet = None

        n_latitudes = len(f.variables["latitude"])
        n_longitudes = len(f.variables["longitude"])-1 # ignore +180
        n_longitudes = max(n_longitudes, 1)

        # Reading pressure and temperature.
        diagfi_temperature = get_longitudes(f.variables['temp'][self.t])
        pressure = None
        try:
            surface_pressure = get_longitudes(f.variables['ps'][self.t])
            ap = f.variables['ap'][:]
            bp = f.variables['bp'][:]
            p_levels = ap[:, None, None] + bp[:, None, None] * surface_pressure
            aps = f.variables['aps'][:]
            bps = f.variables['bps'][:]
            p_layers = aps[:, None, None] + bps[:, None, None] * surface_pressure
            n_vertical = len(p_layers)
            if isinstance(self.radiative_transfer, Emission):
                pressure = p_layers
            else:
                # in Transmission mode, we merge levels and layers (in the input atmosphere only)
                n_vertical = len(p_levels) + len(p_layers)
            if not p_levels.size:
                # Explicit check (not an assert) so that it survives `python -O`.
                raise ValueError("Empty pressure levels computed from 'ap'/'bp'.")
        except Exception:
            # Hybrid coefficients unusable: fall back on the full pressure field 'p'.
            if 'p' not in f.variables:
                raise KeyError("The input file %s should have at least 'p' or 'ap'/'bp' entries defining the pressure in the model."%(self.filename))
            pressure = get_longitudes(f.variables['p'][self.t])
            n_vertical = len(pressure)
            if len(pressure) > len(diagfi_temperature):
                # pressure given on levels. Get pressure points in middle of layers, to correspond to temperature
                log_p = np.log10(pressure)
                pressure = np.power(10, log_p[:-1]+ np.diff(log_p, axis=0)/2.)
                n_vertical = len(pressure)
                self.warning("Pressure has one vertical point more than temperature. Pressure will be used as levels-wise while temperature is considered as layers-wise.")
        if isinstance(self.radiative_transfer, Emission):
            # in Emission mode, we add a surface point
            try:
                pressure = np.concatenate([[get_longitudes(f.variables['ps'][self.t])], pressure])
                n_vertical = len(pressure)
            except Exception:
                self.warning("Could not read 'ps' (surface pressure) in input file.")

        # Create the grid now that we know its dimension.
        grid = Grid3D(n_vertical=n_vertical, n_latitudes=n_latitudes, n_longitudes=n_longitudes)

        self.grid_to_radians(f, grid)
        diagfi_gas = self.read_gases(f, controle)
        diagfi_aerosols = self.read_aerosols(f)

        if pressure is None:
            # Merging levels and layers (Transmission mode).
            assert not isinstance(self.radiative_transfer, Emission), "Pressure should have been set earlier in Emission mode. Report this as a bug."

            pressure = np.zeros(grid.shape)
            temperature = np.zeros(grid.shape)
            gas_mix_ratio = {gas: np.zeros(grid.shape) for gas in diagfi_gas}
            aerosols = arrays_to_zeros(diagfi_aerosols, grid.shape)

            for i in range(n_vertical):
                half = i // 2
                # Even points are levels, odd points are layers. Layer-wise data
                # (temperature, gases, aerosols) is taken from layer `half`,
                # except for the very last point, which reuses the last layer.
                simple_idx = -1 if i == n_vertical-1 else half
                pressure[i] = p_levels[half] if i%2 == 0 else p_layers[half]
                temperature[i] = diagfi_temperature[simple_idx]
                for gas, vmr in diagfi_gas.items():
                    gas_mix_ratio[gas][i] = vmr[simple_idx]
                for aerosol, key in aerosols_array_iterator(diagfi_aerosols):
                    aerosols[aerosol][key][i] = diagfi_aerosols[aerosol][key][simple_idx]
        elif isinstance(self.radiative_transfer, Emission):
            # Handle surface point (Emission mode): the surface takes the
            # temperature 'tsurf' and the composition of the first layer.
            temperature = np.concatenate([[get_longitudes(f.variables['tsurf'][self.t])], diagfi_temperature])
            gas_mix_ratio = {}
            for gas, vmr in diagfi_gas.items():
                gas_mix_ratio[gas] = np.concatenate([vmr[:1], vmr])
            aerosols = arrays_to_zeros(diagfi_aerosols, grid.shape)
            for aerosol, key in aerosols_array_iterator(diagfi_aerosols):
                aerosols[aerosol][key] = np.concatenate([diagfi_aerosols[aerosol][key][:1], diagfi_aerosols[aerosol][key]])
        else:
            aerosols = diagfi_aerosols
            gas_mix_ratio = diagfi_gas
            temperature = diagfi_temperature

        # Main output of this function: the input atmosphere.
        self.input_atmosphere = InputAtmosphere(grid=grid, pressure=pressure, temperature=temperature, gas_mix_ratio=gas_mix_ratio, aerosols=aerosols)

        # Calculate surface of each cell, and compare to diagfi 'aire' information.
        try:
            calculated_surface = self.surface
            # `[:]` loads the data so that nothing refers to the file once it is closed.
            self.surface = get_longitudes(f.variables['aire'][:])
            # np.allclose reduces to a single bool; a bare truth test on the
            # np.isclose() array is ambiguous for more than one element.
            if not np.allclose(self.surface, calculated_surface):
                raise ValueError("'aire' differs from the calculated surface.")
        except Exception:
            self.warning("Surface from diagfi either wrong or inexistent. We will recompute it...")
            self.surface = None


    def grid_to_radians(self, f, grid):
        """Convert latitudes and longitudes to radians.

        The coordinates of :attr:`f` are assumed to be in degrees and are stored
        in :attr:`grid` as `mid_latitudes` and `mid_longitudes`. This is a
        best-effort operation: a failure is reported as a warning and leaves the
        corresponding coordinates of :attr:`grid` untouched.
        """
        try:
            grid.mid_latitudes = np.radians(f.variables["latitude"])
            self.info("Latitude read from file (%s to %s)."%(f.variables["latitude"][0], f.variables["latitude"][-1]))
        except Exception:
            self.warning("Could not read latitudes from input file.")
        try:
            # The last longitude is dropped (duplicate of the first one).
            # NOTE(review): with a single longitude in the file this slice is empty - confirm the grid copes with it.
            grid.mid_longitudes = np.radians(f.variables["longitude"][:-1])
            self.info("Longitude read from file (%s to %s)."%(f.variables["longitude"][0], f.variables["longitude"][-1]))
        except Exception:
            self.warning("Could not read longitudes from input file.")

        lon_units = None
        try:
            lon_units = f.variables["longitude"].units
            in_degrees = "degrees" in lon_units
        except Exception:
            in_degrees = False
        if not in_degrees:
            self.warning("Longitude units (%s) not recognized as degrees. However, we still transform it from degrees to radians. If your data (latitude included) is in radians, please convert it to degrees."%lon_units)

    def read_gases(self, f, controle):
        """Read gases from :attr:`f` file descriptor using :attr:`gas_dict`.

        Args:
            f: Open dataset to read the gases from.
            controle: `controle` variable of the file (or None if the file has none). Its entry 7 is read as the molar mass in `g/mol` when a MMR has to be converted to a VMR.

        Returns:
            dict: One array per gas of :attr:`gas_dict` that could be read. Failures are logged as errors and do not stop the reading of the other gases.
        """
        diagfi_gas = {}
        for gas, key in self.gas_dict.items():
            try:
                try:
                    diagfi_gas[gas] = np.asarray(get_longitudes(f.variables[key][self.t]))
                except Exception:
                    self.error("Could not read timestep %s of %s. Reading timestep 0."% (self.t, gas))
                    diagfi_gas[gas] = np.asarray(get_longitudes(f.variables[key][0]))

                # Convert from log space to normal space.
                diagfi_gas[gas] = convert_log(diagfi_gas[gas], self.gas_units)

                # Convert the data to VMR if needed.
                has_file_units = 'units' in f.variables[key].__dict__
                if self.gas_units is None and not has_file_units:
                    # No units. If positive, we consider it is already a VMR, else we consider it is a log10 of a VMR.
                    if (diagfi_gas[gas] < 0).all():
                        self.warning("No gas units given. The data is negative so we will assume it is a VMR given in log10. Please add units in the input file if it is not, or set gas_units to 'vmr' or 'mmr' or 'log_vmr' or 'log_mmr'.")
                        diagfi_gas[gas] = np.power(10,diagfi_gas[gas])
                elif (self.gas_units is not None and "mmr" in self.gas_units) or (self.gas_units is None and f.variables[key].units != 'm^3/m^3'):
                    # The unit is supposed to be a MMR. We will need a molar mass for that.
                    # (`gas_units` is tested against None first: `"mmr" in None` is a TypeError.)
                    try:
                        if not hasattr(self, "input_mu"):
                            self.input_mu = float(controle[7]/1000) # conversion g/mol to kg/mol
                    except Exception:
                        # getattr: the variable may have no 'units' when `gas_units` was given by the user.
                        self.error("Could not find input mu in 'controle'. Cannot convert %s from MMR to VMR (given unit is '%s')." % (gas, getattr(f.variables[key], "units", None)))
                        continue
                    diagfi_gas[gas] *= self.input_mu / xk.Molar_mass().fetch(gas)
            except Exception:
                self.error("Could not read %s or its units from input file."%key)
        return diagfi_gas

    def read_aerosols(self, f):
        """Read aerosols from diagfi using :attr:`aerosols_dict`. Each key should be composed of the name of the aerosol + the entry it corresponds to. The entries we will try to find are: :code:`["mmr", "reff", "condensate_density", "p_min", "optical_properties"]`.
        Example of a valid dictionary: {'H2O': 'h2o_ice', 'H2O_reff': 'H2Oice_reff'}.

        Returns:
            dict: For each aerosol, a dictionary of the entries read from :attr:`f` at time slice :attr:`t`. Negative values are set to 0.
        """
        diagfi_aerosols =  {}
        if self.aerosols_dict is None:
            self.aerosols_dict = {}

        # These are the keys we read in the diagfi
        keys_to_read = ["mmr", "reff", "condensate_density", "p_min", "optical_properties"]
        for aerosol, key in self.aerosols_dict.items():
            entry = "mmr" # in case of dictionary {'H2O': 'h2o_ice'}
            for var in keys_to_read: # try to find which entry is `aerosol`
                if aerosol.endswith("_"+var):
                    aerosol = aerosol[:-(len(var)+1)]
                    entry = var
                    break
            properties = diagfi_aerosols.setdefault(aerosol, {})
            values = np.asarray(get_longitudes(f.variables[key][self.t]))

            # Convert from log space to normal space.
            if entry == "mmr":
                values = convert_log(values, self.aerosols_units)

            ng = np.where(values < 0)
            if not np.isclose(values[ng], 0).all():
                self.error("Found negative values for aerosols %s. Min is %s. Will probably crash... Please remove these negative values if this is an error."%(entry, values[ng].min()))
            values[ng] = 0 # rounding errors
            properties[entry] = values
        return diagfi_aerosols