from getpass import getuser
from pathlib import Path
from typing import Dict, List, Union

from pydantic import DirectoryPath, Field, root_validator
from typing_extensions import Literal

from ._basemodel import BaseModel
from ._description_helpers import formatter as f
from ._dimensions import IDSOperationDim
from ._imas import ImasBaseModel
from ._variable import VariableModel
from .data_location import DataLocation
from .matrix_samplers import (CartesianProduct, HaltonSampler, LHSSampler,
                              SobolSampler)
from .workdir import WorkDirectoryModel


class VariableConfigModel(BaseModel):
    """List-like model holding the variable definitions.

    The root value is a list of `VariableModel` entries. The defaults
    defined here all refer to the `core_profiles` IDS.
    """
    __root__: List[VariableModel] = Field(default=[
        VariableModel(
            name='rho_tor_norm',
            ids='core_profiles',
            path='profiles_1d/$time/grid/rho_tor_norm',
            dims=['x'],
        ),
        VariableModel(
            name='t_i_average',
            ids='core_profiles',
            path='profiles_1d/$time/t_i_average',
            dims=['x'],
        ),
        VariableModel(
            name='zeff',
            ids='core_profiles',
            path='profiles_1d/$time/zeff',
            dims=['x'],
        ),
        VariableModel(
            name='time',
            ids='core_profiles',
            path='time',
            dims=['time'],
        ),
    ])

    def __iter__(self):
        """Yield the variable definitions one by one."""
        for variable in self.__root__:
            yield variable

    def __getitem__(self, index: int):
        """Return the variable definition stored at position `index`."""
        variables = self.__root__
        return variables[index]

    def to_variable_dict(self) -> dict:
        """Return the variables as a dict keyed by variable name."""
        lookup = {}
        for variable in self:
            # Later entries overwrite earlier ones that share a name.
            lookup[variable.name] = variable
        return lookup


class CreateConfigModel(BaseModel):
    """The options of the `create` subcommand are stored in the `create` key in
    the config.

    The fields describe which operations span the sampling space
    (`dimensions`), how that space is sampled (`sampler`), which template
    run is used as starting point (`template`, `template_data`) and where
    the resulting data are stored (`data`).
    """
    # `Union[X]` with a single member collapses to `X`, so this annotation is
    # equivalent to the previous `List[Union[IDSOperationDim]]`.
    dimensions: List[IDSOperationDim] = Field(
        [
            IDSOperationDim(variable='t_i_average'),
            IDSOperationDim(variable='zeff'),
        ],
        description=f("""
        The `dimensions` specifies the dimensions of the matrix to sample
        from. Each dimension is a compound set of operations to apply.
        From this, a matrix of all possible combinations is generated.
        Essentially, it generates the
        [Cartesian product](https://en.wikipedia.org/wiki/Cartesian_product)
        of all operations. By specifying a different `sampler`, a subset of
        this hypercube can be efficiently sampled.
        """))

    # The concrete sampler model is selected by its `method` field
    # (discriminated union). The default is a latin-hypercube sampler, and
    # the description below states the same.
    sampler: Union[LHSSampler, HaltonSampler, SobolSampler,
                   CartesianProduct] = Field(default=LHSSampler(),
                                             discriminator='method',
                                             description=f("""
        For efficient UQ, it may not be necessary to sample the entire matrix
        or hypercube. By default, latin-hypercube sampling is used. The full
        cartesian product of all operations can also be selected. For
        efficient sampling of the space, the following `method` choices are
        available:
        [`latin-hypercube`](https://en.wikipedia.org/wiki/Latin_hypercube_sampling),
        [`sobol`](https://en.wikipedia.org/wiki/Sobol_sequence),
        [`halton`](https://en.wikipedia.org/wiki/Halton_sequence).
        Where `n_samples` gives the number of samples to extract.
        """))

    # The default path is built from the current user name when this module
    # is imported.
    template: DirectoryPath = Field(
        f'/pfs/work/{getuser()}/jetto/runs/duqtools_template',
        description=f("""
        Template directory to modify. Duqtools copies and updates the settings
        required for the specified system from this directory. This can be a
        directory with a finished run, or one just stored by JAMS (but not yet
        started). By default, duqtools extracts the input IMAS database entry
        from the settings file (e.g. jetto.in) to find the data to modify for
        the UQ runs.
        """))

    template_data: ImasBaseModel = Field(None,
                                         description=f("""
        Specify the location of the template data to modify. This overrides the
        location of the data specified in settings file in the template
        directory.
        """))

    data: DataLocation = Field(DataLocation(),
                               description=f("""
        Where to store the in/output IDS data.
        The data key specifies the machine or imas
        database name where to store the data (`imasdb`). duqtools will write the input
        data files for UQ start with the run number given by `run_in_start_at`.
        The data generated by the UQ runs (e.g. from jetto) will be stored
        starting by the run number given by `run_out_start_at`.
        """))


class SubmitConfigModel(BaseModel):
    """Settings for the `submit` subcommand, read from the `submit` key of
    the config.

    They describe the commands used to start the UQ runs.
    """

    submit_script_name: str = Field(
        default='.llcmd',
        description='Name of the submission script.',
    )
    submit_command: str = Field(
        default='sbatch',
        description='Submission command.',
    )


class StatusConfigModel(BaseModel):
    """Settings for the `status` subcommand, read from the `status` key of
    the config.

    The defaults only have to be adapted when the modeling software
    changes.
    """

    # File names that are looked up for a run.
    status_file: str = Field(
        default='jetto.status',
        description='Name of the status file.',
    )
    in_file: str = Field(
        default='jetto.in',
        description=f("""
            Name of the modelling input file, will be used to check
            if the subprocess has started.
            """),
    )

    out_file: str = Field(
        default='jetto.out',
        description=f("""
            Name of the modelling output file, will be used to
            check if the software is running.
            """),
    )

    # Messages searched for in `status_file`.
    msg_completed: str = Field(
        default='Status : Completed successfully',
        description=f("""
            Parse `status_file` for this message to check for
            completion.
            """),
    )

    msg_failed: str = Field(
        default='Status : Failed',
        description=f("""
            Parse `status_file` for this message to check for
            failures.
            """),
    )

    msg_running: str = Field(
        default='Status : Running',
        description=f("""
            Parse `status_file` for this message to check for
            running status.
            """),
    )


class MergeStep(BaseModel):
    """These parameters describe which paths should be merged.

    Three sets of variables need to be defined:
    - time_variable: this points to the data for the time coordinate
    - grid_variable: this points to the data for the grid variable
    - data_variables: these point to the data to be merged

    Each variable is given either by name (a string) or directly as a
    `VariableModel`.

    Note that all variables must be from the same IDS.

    The grid and data variables must share a common dimension. The grid variable
    will be used to rebase all data variables to a common grid.

    The time variable will be used to rebase the grid variable and the data variables
    to a common time coordinate. To denote the time index, use `/$time/` in both
    the grid and data variables.

    Rebasing involves interpolation.

    Note that multiple merge steps can be specified, for example for different
    IDS.
    """
    data_variables: List[Union[str, VariableModel]] = Field(
        ['t_i_average', 'zeff'],
        description=f("""
            This is a list of data variables to be merged. This means
            that the mean and error for these data over all runs are calculated
            and written back to the output data location.
            The paths should contain `/$time/` for the time component.
            """))
    grid_variable: Union[str, VariableModel] = Field(
        'rho_tor_norm',
        description=f("""
            This variable points to the data for the grid coordinate. It must share a common
            placeholder dimension with the data variables.
            It will be used to rebase all data variables to same (radial) grid before merging
            using interpolation.
            The path should contain '/$time/' to denote the time component.
            """))
    time_variable: Union[str, VariableModel] = Field(
        'time',
        description=f("""
            This variable determines the time coordinate to merge on. This ensures
            that the data from all runs are on the same time coordinates before
            merging.
            """))


class MergeConfigModel(BaseModel):
    """The options of the `merge` subcommand are stored under the `merge` key
    in the config.

    These keys define the location of the IMAS data, which variables
    to merge, and where to store the output.

    Before merging, the data are rebased on the grid variable and the
    time variable given by each `MergeStep` in `plan`.
    """
    data: Path = Field('runs.yaml',
                       description=f("""
            Data file with IMAS handles, such as `data.csv` or `runs.yaml`
            """))
    # The default template entry is built from the current user name when
    # this module is imported.
    template: ImasBaseModel = Field(
        {
            'user': getuser(),
            'db': 'jet',
            'shot': 94785,
            'run': 1
        },
        description=f("""
            This IMAS DB entry will be used as the template.
            It is copied to the output location.
            """))
    # No `user` key is given here, unlike in the `template` default.
    output: ImasBaseModel = Field(
        {
            'db': 'jet',
            'shot': 94785,
            'run': 9999
        },
        description='Merged data will be written to this IMAS DB entry.')
    plan: List[MergeStep] = Field([MergeStep()],
                                  description='List of merging operations.')


class ConfigModel(BaseModel):
    """The options for the CLI are defined by this model.

    Each subcommand (`submit`, `create`, `status`, `merge`) has its own
    section. Once the fields are validated, variable names used in the
    `create` dimensions and in the `merge` plan are replaced by the
    matching definitions from `variables` (see `update_variables`).
    """
    plot: dict = Field(None,
                       deprecated=True,
                       description='Options are specified via CLI.',
                       exclude=True)

    submit: SubmitConfigModel = Field(
        SubmitConfigModel(),
        description='Configuration for the submit subcommand')
    create: CreateConfigModel = Field(
        CreateConfigModel(),
        description='Configuration for the create subcommand')
    status: StatusConfigModel = Field(
        StatusConfigModel(),
        description='Configuration for the status subcommand')
    merge: MergeConfigModel = Field(
        MergeConfigModel(),
        description='Configuration for the merge subcommand')

    variables: VariableConfigModel = Field(
        VariableConfigModel(),
        description='Define variables for use in the subcommands.')

    workspace: WorkDirectoryModel = WorkDirectoryModel()
    system: Literal['jetto',
                    'dummy'] = Field('jetto',
                                     description='backend system to use')

    @root_validator(pre=False, skip_on_failure=True)
    def update_variables(cls, values):
        """Grab variable names from different steps and replace them with the
        definitions from the `variables` attribute.

        The sub-models in `values['create']` and `values['merge']` are
        updated in place and `values` is returned. A `KeyError` is raised
        when a referenced variable name is not defined in `variables`.
        """
        var_dict = values['variables'].to_variable_dict()

        def validate_variable(var: Union[str, VariableModel],
                              dct: Dict[str, VariableModel]) -> VariableModel:
            """Return `var` unchanged if it already is a `VariableModel`,
            otherwise look its name up in `dct`."""
            if isinstance(var, VariableModel):
                return var

            try:
                return dct[var]
            except KeyError as err:
                # Chain explicitly so the failed lookup stays attached as the
                # direct cause of the error.
                # NOTE(review): pydantic v1 presumably only converts
                # ValueError/TypeError/AssertionError into a ValidationError,
                # so this KeyError would reach the caller as-is. The type is
                # kept for backward compatibility; confirm.
                raise KeyError(
                    f'Variable: `{var}` has not been defined.') from err

        for dimension in values['create'].dimensions:
            dimension.variable = validate_variable(dimension.variable,
                                                   var_dict)

        for step in values['merge'].plan:
            step.grid_variable = validate_variable(step.grid_variable,
                                                   var_dict)
            step.time_variable = validate_variable(step.time_variable,
                                                   var_dict)
            step.data_variables = [
                validate_variable(variable, var_dict)
                for variable in step.data_variables
            ]

        return values
