import typing
from typing import Union, List, Dict, Optional
from pathlib import Path

import eccodes
import xarray as xr
from tqdm import tqdm

from ._level import _fix_level
from ._util import _check_message
from ._xarray import create_data_array_from_message, get_level_coordinate_name
from reki._util import _load_first_variable


def load_field_from_file(
        file_path: Union[str, Path],
        parameter: Union[str, Dict] = None,
        level_type: Union[str, Dict] = None,
        level: Union[int, float, List, Dict] = None,
        level_dim: Optional[str] = None,
        show_progress: bool = False,
        **kwargs
) -> Optional[xr.DataArray]:
    """
    Load **one** field from local GRIB2 file using eccodes-python.
    Or load multi levels into one field.

    Parameters
    ----------
    file_path: str or Path
    parameter: str or typing.Dict
        parameter name.
        Use GRIB key `shortName` or a dict of filter conditions such as:
            {
                "discipline": 0,
                "parameterCategory": 2,
                "parameterNumber": 225,
            }

    level_type: str or typing.Dict or None
        level type.

        - Use "pl", "ml" or "sfc". They will be converted into dict.
        - Use GRIB key `typeOfLevel`, such as
            - "isobaricInhPa"
            - "isobaricInPa"
            - "surface"
            - "heightAboveGround"
            - ...
          See https://apps.ecmwf.int/codes/grib/format/edition-independent/3/ for more values.
        - If `typeOfLevel` is not available, use dict to specify filter conditions.
          For example, to get one field from GRAPES GFS modelvar GRIB2 file, use:
            {
                "typeOfFirstFixedSurface": 131
            }

    level: int or float or typing.List or typing.Dict or None
        level value(s).

        - If use a scalar, level will be a non-dimension coordinate.
        - If you want to extract multi levels, use a list and level will be a dimension (level, lat, lon).
        - If use a dict, message will be filtered by dict keys. Support custom calculate keys:
            - ``first_level``
            - ``second_level``
        - If use `"all"`, all levels of level_type will be packed in the result field.
        - If use `None`, only the first field will be returned.

    level_dim: str or None
        name of level dimension.
        If none, function will generate a name for level dim.
        If `level_type="pl"`, some values can be used:

            - `None` or `pl` or `isobaricInhPa`: level_dim is a float number with unit hPa.
            - `isobaricInPa`: level_dim is a float number with unit Pa.

    show_progress: bool
        show progress bar.

    **kwargs
        extra filter conditions, passed through unchanged to the message checker.

    Returns
    -------
    DataArray or None:
        DataArray if found, or None if not.

    Raises
    ------
    ValueError
        When several messages are matched and the name of the level dimension
        cannot be derived: ``level_dim`` is neither ``None`` nor a str, or
        ``level_dim`` is ``None`` and ``level_type`` is neither a str nor a dict.

    Examples
    --------
    Load 850hPa temperature field from a GRIB2 file generated by GRAPES GFS.

    >>> load_field_from_file(
    ...     file_path="/sstorage1/COMMONDATA/OPER/NWPC/GRAPES_GFS_GMF/Prod-grib/2020031721/ORIG/gmf.gra.2020031800105.grb2",
    ...     parameter="t",
    ...     level_type="isobaricInhPa",
    ...     level=850,
    ... )
    <xarray.DataArray 't' (latitude: 720, longitude: 1440)>
    array([[249.19234375, 249.16234375, 249.16234375, ..., 249.15234375,
            249.19234375, 249.14234375],
           [249.45234375, 249.45234375, 249.42234375, ..., 249.45234375,
            249.44234375, 249.44234375],
           [249.69234375, 249.68234375, 249.68234375, ..., 249.70234375,
            249.67234375, 249.68234375],
           ...,
           [235.33234375, 235.45234375, 235.62234375, ..., 235.47234375,
            235.63234375, 235.48234375],
           [235.78234375, 235.91234375, 235.64234375, ..., 235.80234375,
            235.72234375, 235.82234375],
           [235.66234375, 235.86234375, 235.82234375, ..., 235.85234375,
            235.68234375, 235.70234375]])
    Coordinates:
        time           datetime64[ns] 2020-03-18
        step           timedelta64[ns] 4 days 09:00:00
        valid_time     datetime64[ns] 2020-03-22T09:00:00
        isobaricInhPa  int64 850
      * latitude       (latitude) float64 89.88 89.62 89.38 ... -89.38 -89.62 -89.88
      * longitude      (longitude) float64 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
    Attributes:
        GRIB_edition:                    2
        GRIB_centre:                     babj
        GRIB_subCentre:                  0
        GRIB_tablesVersion:              4
        GRIB_localTablesVersion:         1
        GRIB_dataType:                   fc
        GRIB_dataDate:                   20200318
        GRIB_dataTime:                   0
        GRIB_validityDate:               20200322
        GRIB_validityTime:               900
        GRIB_step:                       105
        GRIB_stepType:                   instant
        GRIB_stepUnits:                  1
        GRIB_stepRange:                  105
        GRIB_endStep:                    105
        GRIB_name:                       Temperature
        GRIB_shortName:                  t
        GRIB_cfName:                     air_temperature
        GRIB_discipline:                 0
        GRIB_parameterCategory:          0
        GRIB_parameterNumber:            0
        GRIB_gridType:                   regular_ll
        GRIB_gridDefinitionDescription:  Latitude/longitude
        GRIB_typeOfFirstFixedSurface:    pl
        GRIB_typeOfLevel:                isobaricInhPa
        GRIB_level:                      850
        GRIB_numberOfPoints:             1036800
        GRIB_missingValue:               9999
        GRIB_units:                      K
        long_name:                       Temperature
        units:                           K

    """
    fixed_level_type, fixed_level_dim = _fix_level(level_type, level_dim)

    # Handles of every matched message. They are owned by this function and
    # released in the ``finally`` below, whatever happens in between.
    messages = []
    try:
        _collect_matching_messages(
            messages,
            file_path,
            parameter,
            fixed_level_type,
            level,
            show_progress,
            **kwargs
        )

        if not messages:
            return None

        if len(messages) == 1:
            return create_data_array_from_message(messages[0], level_dim_name=fixed_level_dim)

        pbar = tqdm(total=len(messages), desc="Decoding") if show_progress else None
        try:
            xarray_messages = []
            for message in messages:
                xarray_messages.append(
                    create_data_array_from_message(message, level_dim_name=fixed_level_dim)
                )
                if pbar is not None:
                    pbar.update(1)
        finally:
            if pbar is not None:
                pbar.close()
    finally:
        for message in messages:
            eccodes.codes_release(message)

    level_dim_name = _resolve_concat_dim_name(level_type, level_dim, xarray_messages[0])

    if show_progress:
        print("Packing...")

    return xr.concat(xarray_messages, level_dim_name)


def _collect_matching_messages(
        messages: List,
        file_path: Union[str, Path],
        parameter,
        fixed_level_type,
        level,
        show_progress: bool,
        **kwargs
) -> None:
    """
    Scan the GRIB file and append handles of matching messages to ``messages``.

    Handles are appended in place (instead of being returned) so that the caller
    can still release the ones collected so far if scanning fails half-way.
    Messages which do not match are released immediately.
    """
    # A list of levels or "all" may match several messages, so the whole file is scanned.
    # Otherwise scanning stops at the first match.
    collect_all = isinstance(level, list) or level == "all"

    total_count = None
    if show_progress:
        # Count messages first so that the progress bar has a total.
        with open(file_path, "rb") as f:
            total_count = eccodes.codes_count_in_file(f)

    with open(file_path, "rb") as f:
        pbar = tqdm(total=total_count, desc="Filtering") if show_progress else None
        try:
            while True:
                message_id = eccodes.codes_grib_new_from_file(f)
                if message_id is None:
                    break
                if pbar is not None:
                    pbar.update(1)

                keep = False
                try:
                    keep = _check_message(message_id, parameter, fixed_level_type, level, **kwargs)
                finally:
                    # Also runs when the check raises, so the handle is never leaked.
                    if not keep:
                        eccodes.codes_release(message_id)
                if not keep:
                    continue

                messages.append(message_id)
                if not collect_all:
                    break
        finally:
            if pbar is not None:
                pbar.close()


def _resolve_concat_dim_name(level_type, level_dim, first_field) -> str:
    """
    Return the name of the dimension along which multi-level fields are concatenated.
    """
    if level_dim is None:
        if isinstance(level_type, str):
            return level_type
        if isinstance(level_type, dict):
            return get_level_coordinate_name(first_field)
        raise ValueError(f"level_type is not supported: {level_type}")
    if isinstance(level_dim, str):
        return level_dim
    raise ValueError(f"level_dim is not supported: {level_dim}")


def load_field_from_files(
        file_list: List,
        parameter: Union[str, Dict],
        level_type: Union[str, Dict],
        level: Optional[Union[int, float, List, Dict]],
        level_dim: Optional[str] = None,
        show_progress: bool = False,
        **kwargs
) -> Optional[xr.DataArray]:
    """
    Load one field from multiple files.

    Each file is loaded with ``load_field_from_file``. The resulting fields are
    expanded with ``time`` and ``step`` dimensions and combined by coordinates
    into one field whose leading dimensions are ``(time, step)``.

    Parameters
    ----------
    file_list: typing.List
        file list.
    parameter: str or typing.Dict
        see ``load_field_from_file``
    level_type: str or typing.Dict
        see ``load_field_from_file``
    level: int or float or typing.List or None
        see ``load_field_from_file``
    level_dim: str or None
        level dimension name.
    show_progress: bool
        see ``load_field_from_file``
    **kwargs
        passed through to ``load_field_from_file``.

    Returns
    -------
    xr.DataArray or None:
        xr.DataArray if found, or None if not.
        Files in which the field is not found are skipped.
        None is returned when ``file_list`` is empty or no file contains the field.

    """
    field_list = []
    for file_path in file_list:
        print(file_path)
        field = load_field_from_file(
            file_path,
            parameter=parameter,
            level_type=level_type,
            level=level,
            level_dim=level_dim,
            show_progress=show_progress,
            **kwargs
        )
        # ``load_field_from_file`` returns None when nothing matches;
        # such files contribute nothing to the combined field.
        if field is None:
            continue
        field_list.append(field)

    if not field_list:
        return None

    data_set = xr.combine_by_coords(
        [f.expand_dims(["time", "step"]).to_dataset() for f in field_list]
    )
    data = _load_first_variable(data_set)
    data = data.transpose("time", "step", ...)
    return data
