import numpy as np

from . import common_args
from ..util import read_param_file, scale_samples, compute_groups_matrix


def sample(problem, N, seed=None):
    """Generate model inputs using Latin hypercube sampling (LHS).

    Returns a NumPy array containing the model inputs generated by Latin
    hypercube sampling.  The resulting array contains N rows and D columns,
    where D is the number of parameters (``problem["num_vars"]``).

    The unit interval is split into N equal-width strata.  For every column
    one uniform value is drawn from each stratum, and the strata are then
    randomly permuted independently per column.  If ``problem`` defines
    ``"groups"``, all parameters belonging to the same group receive the
    same column of values.  The unit-interval samples are finally passed to
    ``scale_samples(result, problem)`` before being returned.

    Parameters
    ----------
    problem : dict
        The problem definition.  ``"num_vars"`` is required; ``"groups"``
        is optional.
    N : int
        The number of samples to generate
    seed : int, optional
        If not None, passed to ``numpy.random.seed`` before sampling.  Note
        that this reseeds NumPy's *global* random state.  ``0`` is a valid
        seed and is honoured.

    Returns
    -------
    numpy.ndarray
        The scaled sample matrix (N rows, one column per parameter).

    References
    ----------
    1. McKay, M.D., Beckman, R.J., Conover, W.J., 1979.
           A comparison of three methods for selecting values of input
           variables in the analysis of output from a computer code.
           Technometrics 21, 239-245.
           https://doi.org/10.2307/1268522

    2. Iman, R.L., Helton, J.C., Campbell, J.E., 1981.
           An Approach to Sensitivity Analysis of Computer Models:
           Part I—Introduction, Input Variable Selection and
           Preliminary Variable Assessment.
           Journal of Quality Technology 13, 174-183.
           https://doi.org/10.1080/00224065.1981.11978748

    """
    num_samples = N

    # Compare against None explicitly: a plain truthiness test would
    # silently ignore seed=0 and leave the run unseeded.
    if seed is not None:
        np.random.seed(seed)

    groups = problem.get("groups")
    if groups:
        num_groups = len(set(groups))
        G, _ = compute_groups_matrix(groups)
    else:
        num_groups = problem["num_vars"]

    result = np.empty([num_samples, problem["num_vars"]])
    d = 1.0 / num_samples

    # Row i holds independent uniform draws from stratum [i*d, (i+1)*d),
    # one per column.  The draw order is kept identical to earlier versions
    # so a given (nonzero) seed reproduces the same design.
    temp = np.array(
        [
            np.random.uniform(low=i * d, high=(i + 1) * d, size=num_groups)
            for i in range(num_samples)
        ]
    )

    for group in range(num_groups):
        # Permute this column's strata in place so strata are paired at
        # random across columns; this is what makes the design an LHS.
        np.random.shuffle(temp[:, group])

        if groups:
            # Every parameter in this group takes the group's column.  The
            # member lookup is invariant across samples, so do it once.
            members = np.where(G[:, group] == 1)[0]
            result[:, members] = temp[:, [group]]
        else:
            result[:, group] = temp[:, group]

    result = scale_samples(result, problem)

    return result


# This sampler defines no method-specific command-line options, so there is
# no extra parser hook to register.
cli_parse = None


def cli_action(args):
    """Run sampling method from the command line and save the design.

    Reads the problem definition from ``args.paramfile``, draws
    ``args.samples`` rows with :func:`sample` (seeded by ``args.seed``),
    and writes them to ``args.output`` in scientific notation with
    ``args.precision`` digits, separated by ``args.delimiter``.

    Parameters
    ----------
    args : argparse namespace
    """
    problem = read_param_file(args.paramfile)
    design = sample(problem, args.samples, seed=args.seed)

    destination = args.output
    separator = args.delimiter
    number_format = "%.{}e".format(args.precision)

    np.savetxt(destination, design, delimiter=separator, fmt=number_format)


if __name__ == "__main__":
    # Executed only when this module is run directly; hands the (absent)
    # parser hook and the action function to the shared CLI runner.
    common_args.run_cli(cli_parse, cli_action)
