"""Implements core function nearest_neighbours used for AMD and PDD
calculations.
"""

import collections
from typing import Tuple, Iterable
from itertools import product

import numba
import numpy as np
import numpy.typing as npt
from scipy.spatial import KDTree


def nearest_neighbours(
        motif: npt.NDArray,
        cell: npt.NDArray,
        x: npt.NDArray,
        k: int
) -> Tuple[npt.NDArray[np.float64],
           npt.NDArray[np.float64],
           npt.NDArray[np.int64]]:
    """
    Given a periodic set represented by (motif, cell) and an integer k,
    find the k nearest neighbours in the periodic set to points in x.

    The search queries k + 1 neighbours and discards the first one. This
    assumes every point of x is itself a point of the periodic set, so
    that its nearest neighbour is itself at distance 0.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Orthogonal (Cartesian) coords of the motif, shape (no points,
        dims).
    cell : :class:`numpy.ndarray`
        Orthogonal (Cartesian) coords of the unit cell, shape (dims,
        dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours, shape (no points,
        dims). For invariants of crystals this is the asymmetric unit.
    k : int
        Number of nearest neighbours to find for each point in x.

    Returns
    -------
    pdd : numpy.ndarray
        Array shape (x.shape[0], k) of distances from points in x to
        their k nearest neighbours in the periodic set, in order.
        E.g. pdd[m][n] is the distance from x[m] to its n-th nearest
        neighbour in the periodic set.
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the nearest neighbour search, shape (no points, dims).
    inds : numpy.ndarray
        Integer array shape (x.shape[0], k) containing the indices of
        nearest neighbours in cloud. E.g. the n-th nearest neighbour to
        x[m] is cloud[inds[m][n]].
    """

    cloud_generator = generate_concentric_cloud(motif, cell)

    # Take layers until the cloud holds at least k + 1 points (k
    # neighbours plus the self match), then one more layer as margin.
    n_points = 0
    layers = []
    while n_points <= k:
        layer = next(cloud_generator)
        n_points += layer.shape[0]
        layers.append(layer)
    layers.append(next(cloud_generator))
    cloud = np.concatenate(layers)

    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    pdd_, inds = tree.query(x, k=k+1, workers=-1)
    # All-zero sentinel so the loop below runs at least once.
    pdd = np.zeros_like(pdd_, dtype=np.float64)

    # Keep adding layers until one extra layer no longer changes any
    # of the k + 1 nearest distances.
    while not np.allclose(pdd, pdd_, atol=1e-10, rtol=0):
        pdd = pdd_
        cloud = np.vstack((cloud, next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        pdd_, inds = tree.query(x, k=k+1, workers=-1)

    # Drop column 0: the zero-distance match of each point with itself.
    return pdd_[:, 1:], cloud, inds[:, 1:]


def nearest_neighbours_minval(
        motif: npt.NDArray,
        cell: npt.NDArray,
        min_val: float
) -> npt.NDArray[np.float64]:
    """Distances from each motif point to its nearest neighbours in the
    periodic set (motif, cell), cut off once a threshold is reached.

    This works like nearest_neighbours with x = motif. Instead of taking
    a fixed number of neighbours k, it returns just enough columns that
    every value in the last column is at least ``min_val``.

    Parameters
    ----------
    motif : numpy.ndarray
        Cartesian coords of the motif, shape (no points, dims).
    cell : numpy.ndarray
        Cartesian coords of the unit cell, shape (dims, dims).
    min_val : float
        Distance that every row's last returned value must reach.

    Returns
    -------
    numpy.ndarray
        Array with one row per motif point, holding sorted neighbour
        distances. The leading zero column (each motif point's distance
        to itself) is dropped.
    """

    layers = generate_concentric_cloud(motif, cell)
    # Seed the cloud with the three innermost layers.
    cloud = np.concatenate([next(layers) for _ in range(3)])

    def all_distances(points):
        # Sorted distances from every motif point to every cloud point.
        tree = KDTree(points, compact_nodes=False, balanced_tree=False)
        dists, _ = tree.query(motif, k=points.shape[0], workers=-1)
        return dists

    current = all_distances(cloud)
    previous = np.zeros_like(current)

    # Grow the cloud a layer at a time. Stop when the previous result
    # already reaches min_val in its last column, and the columns up to
    # the first one reaching min_val in every row are unchanged by the
    # newest layer.
    while True:
        if np.all(previous[:, -1] >= min_val):
            reached = np.all(previous >= min_val, axis=0)
            stop = np.argwhere(reached)[0][0] + 1
            if np.array_equal(previous[:, :stop], current[:, :stop]):
                return previous[:, 1:stop]
        previous = current
        cloud = np.vstack((cloud, next(layers)))
        current = all_distances(cloud)


def generate_concentric_cloud(
        motif: npt.NDArray,
        cell: npt.NDArray
) -> Iterable[npt.NDArray[np.float64]]:
    """
    Generates batches of points from a periodic set given by (motif,
    cell) which get successively further away from the origin.

    Each yield gives all points (that have not already been yielded)
    which lie in a unit cell whose corner lattice point was generated by
    ``generate_integer_lattice(cell.shape[0])``. Within a batch, points
    are grouped by translation: each run of ``len(motif)`` consecutive
    rows is the whole motif shifted by one lattice vector.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian representation of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    -------
    :class:`numpy.ndarray`
        float64 arrays of shape (len(motif) * no translations, dims) of
        points from the periodic set.
    """

    motif = np.asarray(motif)
    dims = cell.shape[0]

    for int_lattice in generate_integer_lattice(dims):
        # Integer coefficients -> Cartesian translations (the rows of
        # `cell` act as the lattice vectors).
        lattice = int_lattice @ cell
        # Broadcast (n_translations, 1, dims) + (1, m, dims): entry [t, i]
        # is motif[i] shifted by lattice[t]. Flattening the first two axes
        # keeps the translation-major order of the original per-translation
        # copy loop, without a Python-level loop.
        layer = lattice[:, np.newaxis, :] + motif[np.newaxis, :, :]
        yield layer.reshape(-1, dims).astype(np.float64, copy=False)


def generate_integer_lattice(dims: int) -> Iterable[npt.NDArray[np.float64]]:
    """Yield integer lattice points in growing shells about the origin.

    The r-th yield (r = 0, 1, 2, ...) contains every integer point of
    ``dims``-dimensional space within Euclidean distance r of the origin
    that was not part of an earlier yield.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    -------
    :class:`numpy.ndarray`
        Arrays of shape (no points, dims) of integer-valued points.
    """

    if dims == 1:
        # In one dimension, shell r (r >= 1) is exactly the pair -r, r.
        yield np.array([[0]])
        radius = 1
        while True:
            yield np.array([[-radius], [radius]])
            radius += 1

    # Maps each tuple of the first dims - 1 (non-negative) coordinates
    # to the smallest last coordinate that has not been emitted yet.
    next_last = collections.defaultdict(int)
    radius = 0
    while True:
        radius_sq = radius ** 2
        orthant_points = []
        grew = True
        # Sweep the leading coordinates repeatedly, extending each
        # column by one point per pass, until a pass adds nothing.
        while grew:
            grew = False
            for head in product(range(radius + 1), repeat=dims - 1):
                last = next_last[head]
                if _dist(head, last) <= radius_sq:
                    orthant_points.append((*head, last))
                    next_last[head] = last + 1
                    grew = True

        # Mirror the non-negative orthant into all other orthants.
        yield _reflect_positive_lattice(np.array(orthant_points))
        radius += 1


@numba.njit()
def _dist(xy: Tuple[float, ...], z: float) -> float:
    """Return the squared Euclidean norm of the point (*xy, z).

    No square root is taken. ``xy`` may hold any number of leading
    coordinates, and ``z`` is the final coordinate.
    """
    s = z ** 2
    for val in xy:
        s += val ** 2
    return s


@numba.njit()
def _reflect_positive_lattice(
        positive_int_lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Reflect a set of points in the +ve quadrant in all axes. Does not
    duplicate points lying on the axes themselves.

    For every non-empty subset of the axes, the batch built by
    _reflect_batch is appended: the points with no zero coordinate on
    those axes, with those coordinates negated. Subsets are visited by
    size (1, 2, ..., dims) and, within one size, in lexicographic order
    of their axis indices.

    Parameters
    ----------
    positive_int_lattice : numpy.ndarray
        Points of shape (no points, dims). They are expected to have
        non-negative coordinates (as built in generate_integer_lattice);
        otherwise the reflections may duplicate input points.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (no points, dims). It contains the input
        points first, followed by every reflected batch.
    """

    dims = positive_int_lattice.shape[-1]
    batches = []
    # Start with the unreflected points; extending with a 2-D array
    # appends its rows one by one.
    batches.extend(positive_int_lattice)

    for n_reflections in range(1, dims + 1):

        # First combination of n_reflections axes: 0, 1, ..., n - 1.
        indices = np.arange(n_reflections)
        batches.extend(_reflect_batch(positive_int_lattice, indices))

        # Step `indices` through the remaining combinations in
        # lexicographic order. This is a hand-rolled
        # itertools.combinations, presumably because itertools is not
        # available in numba's nopython mode.
        while True:
            # Find the rightmost position that is not yet at its maximum
            # value (i + dims - n_reflections).
            i = n_reflections - 1
            for _ in range(n_reflections):
                if indices[i] != i + dims - n_reflections:
                    break
                i -= 1
            # Loop finished without `break`: every position is at its
            # maximum, so all combinations of this size are done.
            else:
                break
            # Bump that position and reset everything to its right to
            # consecutive values.
            indices[i] += 1
            for j in range(i+1, n_reflections):
                indices[j] = indices[j-1] + 1

            batches.extend(_reflect_batch(positive_int_lattice, indices))

    # Pack the list of rows into a single float64 array.
    int_lattice = np.empty(shape=(len(batches), dims), dtype=np.float64)
    for i in range(len(batches)):
        int_lattice[i] = batches[i]

    return int_lattice


@numba.njit()
def _reflect_batch(
        positive_int_lattice: npt.NDArray,
        indices: npt.NDArray
) -> npt.NDArray:
    """Takes a collection of points in any dimension and the indices of
    axes to reflect in, and returns a batch of reflected points.

    Points with a zero coordinate on any of the reflected axes are
    excluded. Negating such a point on those axes would reproduce either
    the point itself or a point obtained by reflecting in a smaller set
    of axes.

    Parameters
    ----------
    positive_int_lattice : numpy.ndarray
        Points of shape (no points, dims).
    indices : numpy.ndarray
        1-D integer array of the axes whose coordinates are negated.

    Returns
    -------
    numpy.ndarray
        A new array holding the kept rows, with the coordinates on the
        axes in ``indices`` negated. The input array is not modified.
    """

    # NOTE: despite its name, this mask is True for rows with NO zero on
    # any of the selected axes, i.e. the rows that the reflection moves.
    where_on_axes = (positive_int_lattice[:, indices] == 0).sum(axis=-1) == 0
    # Boolean-mask indexing returns a copy, so the in-place negation
    # below does not touch the caller's array.
    batch = positive_int_lattice[where_on_axes]
    batch[:, indices] *= -1
    return batch
