from __future__ import annotations

import warnings
from functools import partial
from typing import Any, Callable, Sequence

import vapoursynth as vs
from vsexprtools.util import mod4
from vsutil import (
    depth, disallow_variable_format, disallow_variable_resolution, fallback, get_depth, get_neutral_value,
    get_peak_value, scale_value, split
)

from .mask import adg_mask
from .types import Grainer

core = vs.core


# Lower-case grainer names accepted by the ``grainer`` argument, mapped to their enum members.
_grainer_str_map = dict(
    addgrain=Grainer.AddGrain,
    addnoise=Grainer.AddNoise,
)

# A grainer: receives the (neutral) clip to add grain to and returns the grained clip.
GrainerFunc = Callable[[vs.VideoNode], vs.VideoNode]

# Factory building a GrainerFunc from (luma_strength, chroma_strength, seed, static).
GrainerFuncGenerator = Callable[[float, float, int, bool], GrainerFunc]


@disallow_variable_format
@disallow_variable_resolution
def adaptive_grain(
    clip: vs.VideoNode, strength: float | list[float] = 0.25, size: float = 1, sharp: int = 50, static: bool = False,
    luma_scaling: float = 12, grainer: Grainer | GrainerFuncGenerator | str | None = Grainer.AddGrain,
    fade_edges: bool = True, tv_range: bool = True,
    lo: int | Sequence[int] | None = None, hi: int | Sequence[int] | None = None,
    protect_neutral: bool = True, seed: int = -1, show_mask: bool = False, temporal_average: int = 0, **kwargs: Any
) -> vs.VideoNode:
    """
    A modified kagefunc.adaptive_grain using a GrainFactory3 mod (sizedgrain) as a grainer,
    which implements grain size and sharpness, fading values near illegal range and protecting neutral grays.

    To access the graining without using the adaptive_grain mask, see ``sizedgrain``.

    :param clip:                Input clip.
    :param strength:            Grainer strength. Use a list to specify [luma, chroma] graining.
                                Default chroma grain is luma / 2.
    :param size:                Grain size multiplier.
                                Passed as x/ysize if plugin supports it. Else:
                                Grain generated with > 1 value is generated by upscaling it from a lower resolution.
                                Grain generated with < 1 value is generated by downscaling it from a higher resolution.
    :param sharp:               Grain sharpness. No effect if size = 1.
                                Determines b and c values in bicubic resampler used for grain size.
    :param static:              Static or dynamic grain.
    :param luma_scaling:        Adaptive mask's sensitivity, calculated on luma for all planes.
                                Lower values will grain bright areas more, higher values will grain them less.
    :param grainer:             Specify another grainer than grain.Add.
                                It can be a member of the ``Grainer`` enum or the name of it,
                                or a function that takes (luma_str, chroma_str, seed, static) and returns a grainer.
    :param fade_edges:          Keeps grain from exceeding legal range.
                                With this, values which go towards the neutral point, but would generate
                                illegal values if they pointed in the other direction are also limited.
                                This is better at maintaining average values and prevents flickering pixels on OLEDs.
    :param tv_range:            TV or PC legal range.
    :param lo:                  Overwrite legal range's minimums. A single value is scaled from 8-bit to clip depth;
                                a sequence of [luma, chroma] values is used as-is.
    :param hi:                  Overwrite legal range's maximums. A single value is scaled from 8-bit to clip depth;
                                a sequence of [luma, chroma] values is used as-is.
    :param protect_neutral:     Disable chroma grain on neutral chroma (YUV only, applied when fade_edges is on).
    :param seed:                Grain seed for the grainer.
    :param show_mask:           Show the adaptive mask built on luma instead of the grained clip.
    :param temporal_average:    Percentage (0-100) with which dynamic grain is blended with its 3-frame temporal
                                average, for temporal softening and grain consistency.
    :param kwargs:              Kwargs passed to the graining function.

    :returns: Masked grained clip (or the adaptive mask if ``show_mask`` is set).
    """

    mask = adg_mask(clip, luma_scaling)

    # NOTE(review): applying a temporal median (radius ``temporal_average``) to the mask was
    # considered here but is intentionally not enabled.

    # Bring the mask to the clip's bit depth before using it in MaskedMerge.
    vdepth = get_depth(clip)
    if get_depth(mask) != vdepth:
        mask = depth(mask, vdepth)

    if show_mask:
        return mask

    # Forward by keyword so the call stays correct if sizedgrain's parameter order changes.
    grained = sizedgrain(
        clip, strength=strength, size=size, sharp=sharp, static=static, grainer=grainer,
        fade_edges=fade_edges, tv_range=tv_range, lo=lo, hi=hi, protect_neutral=protect_neutral,
        seed=seed, temporal_average=temporal_average, **kwargs
    )

    # Where the mask is 0 the source is kept; where it is at peak the grained clip is used.
    return clip.std.MaskedMerge(grained, mask)


@disallow_variable_format
@disallow_variable_resolution
def sizedgrain(
    clip: vs.VideoNode,
    strength: float | list[float] = 0.25, size: float = 1, sharp: int = 50,
    static: bool = False, grainer: Grainer | GrainerFuncGenerator | str | None = Grainer.AddGrain,
    fade_edges: bool = True, tv_range: bool = True,
    lo: int | Sequence[int] | None = None, hi: int | Sequence[int] | None = None,
    protect_neutral: bool = True, seed: int = -1, temporal_average: int = 0, **kwargs: Any
) -> vs.VideoNode:
    """
    A grainer that includes GrainFactory3's grain size and sharpness, fading values near illegal range and protecting
    neutral grays.

    :param clip:                Input clip.
    :param strength:            Grainer strength. Use a list to specify [luma, chroma] graining.
                                If a single value (or a one-element list) is given, chroma grain is luma / 2.
    :param size:                Grain size multiplier.
                                Passed as x/ysize if plugin supports it (``Grainer.AddNoise``). Else:
                                Grain generated with > 1 value is generated by upscaling it from a lower resolution.
                                Grain generated with < 1 value is generated by downscaling it from a higher resolution.
    :param sharp:               Grain sharpness. No effect if size = 1.
                                Determines b and c values in bicubic resampler used for grain size.
    :param static:              Static or dynamic grain.
    :param grainer:             Specify another grainer than grain.Add.
                                It can be a member of the ``Grainer`` enum or its name (case-insensitive),
                                or a function that takes (luma_str, chroma_str, seed, static) and returns a grainer.
                                The returned grainer is called with a neutral clip and ``kwargs``.
    :param fade_edges:          Keeps grain from exceeding legal range.
                                With this, values which go towards the neutral point, but would generate
                                illegal values if they pointed in the other direction are also limited.
                                This is better at maintaining average values and prevents flickering pixels on OLEDs.
    :param tv_range:            TV or PC legal range.
    :param lo:                  Overwrite legal range's minimums. A single value is scaled from 8-bit to clip depth;
                                a sequence of [luma, chroma] (or per-plane) values is used as-is, without scaling.
    :param hi:                  Overwrite legal range's maximums. A single value is scaled from 8-bit to clip depth;
                                a sequence of [luma, chroma] (or per-plane) values is used as-is, without scaling.
    :param protect_neutral:     Disable chroma grain on neutral chroma (YUV only, applied when fade_edges is on).
    :param seed:                Grain seed for the grainer.
    :param temporal_average:    Percentage (0-100, larger values are clamped) with which dynamic grain is blended
                                with its 3-frame temporal average, for temporal softening and grain consistency.
    :param kwargs:              Kwargs passed to the graining function.

    :returns: Grained clip.

    :raises ValueError:          If no or more than 2 strength values, or an invalid grainer, are given.
    :raises NotImplementedError: If the given ``Grainer`` member is not handled here.
    """
    assert clip.format

    sx, sy = clip.width, clip.height
    vdepth = get_depth(clip)

    def scale_val8x(value: int, chroma: bool = False) -> float:
        return scale_value(value, 8, vdepth, scale_offsets=not tv_range, chroma=chroma)

    # Per-plane neutral values: luma first, then chroma for the remaining planes.
    neutral = [
        get_neutral_value(clip), get_neutral_value(clip, True), get_neutral_value(clip, True)
    ][:clip.format.num_planes]

    # Bicubic b/c from sharpness: sharp=50 gives b=0, c=0.5; b decreases as sharp increases.
    b = sharp / -50 + 1
    c = (1 - b) / 2

    # Normalize strength into separate luma and chroma values (chroma defaults to luma / 2).
    if not isinstance(strength, Sequence):
        luma_str, chroma_str = strength, .5 * strength
    elif not strength:
        raise ValueError('sizedgrain: At least one strength value is required!')
    elif len(strength) == 1:
        luma_str, chroma_str = strength[0], .5 * strength[0]
    elif len(strength) == 2:
        luma_str, chroma_str = strength
    else:
        raise ValueError('sizedgrain: Only 2 strength values are supported!')

    grainer = fallback(grainer, Grainer.AddGrain)  # type: ignore

    # Resolve grainer names; unknown strings fall through to the "invalid grainer" error below.
    if isinstance(grainer, str):
        grainer = _grainer_str_map.get(grainer.lower(), grainer)

    # Only noise.Add takes a grain size directly; otherwise the grain is resized below.
    supports_size = grainer == Grainer.AddNoise

    grainer_func: GrainerFunc

    if isinstance(grainer, Grainer):
        if grainer not in {Grainer.AddGrain, Grainer.AddNoise}:
            raise NotImplementedError
        plugin = getattr(core, 'grain' if grainer is Grainer.AddGrain else 'noise')
        grainer_func = partial(plugin.Add, var=luma_str, uvar=chroma_str, constant=static, seed=seed)
        if supports_size:
            grainer_func = partial(grainer_func, xsize=size, ysize=size)
    elif callable(grainer):
        grainer_func = grainer(luma_str, chroma_str, seed, static)
    else:
        raise ValueError('sizedgrain: Invalid grainer specified!')

    # Without native size support, generate grain at a scaled (mod4) resolution instead.
    if not supports_size and size != 1:
        sx, sy = (mod4(x / size) for x in (sx, sy))

    blank = clip.std.BlankClip(sx, sy, color=neutral)

    grain = grainer_func(blank, **kwargs)

    if not supports_size and size != 1 and (sx != clip.width or sy != clip.height):
        if size > 1.5:
            # Large grain is upscaled in two steps, via a resolution halfway to the clip's.
            sxa, sya = mod4((clip.width + sx) / 2), mod4((clip.height + sy) / 2)
            grain = grain.resize.Bicubic(sxa, sya, filter_param_a=b, filter_param_b=c)

        grain = grain.resize.Bicubic(clip.width, clip.height, filter_param_a=b, filter_param_b=c)

    if not static and temporal_average > 0:
        # Blend dynamic grain with its equally weighted 3-frame average; Merge weight must be <= 1.
        grain = grain.std.Merge(
            grain.std.AverageFrames(weights=[1] * 3), weight=min(temporal_average, 100) / 100
        )

    if fade_edges:
        if lo is None:
            lovals = [scale_val8x(16), scale_val8x(16, True)]
        elif not isinstance(lo, Sequence):
            lovals = [scale_val8x(lo), scale_val8x(lo, True)]
        else:
            lovals = list(lo)

        if hi is None:
            hivals = [scale_val8x(235), scale_val8x(240, True)]
        elif not isinstance(hi, Sequence):
            hivals = [scale_val8x(hi), scale_val8x(hi, True)]
        else:
            hivals = list(hi)

        # Keep the source pixel when adding the grain's magnitude in either direction would leave
        # [low, high]; otherwise apply the grain. The expression is relative to each plane's neutral
        # value, so one form serves luma and chroma for integer and float clips alike.
        # One expression per [luma, chroma]; std.Expr reuses the last one for any remaining plane.
        limit_expr = 'x y {mid} - abs - {low} < x y {mid} - abs + {high} > or x y {mid} - x + ?'

        grained = core.std.Expr([clip, grain], [
            limit_expr.format(mid=mid, low=low, high=high)
            for mid, low, high in zip(neutral, lovals, hivals)
        ])

        if protect_neutral and chroma_str > 0 and clip.format.color_family == vs.YUV:
            # Upsample chroma to full resolution so U and V can be combined into a single-plane mask.
            neutral_mask = clip.resize.Bicubic(format=clip.format.replace(subsampling_h=0, subsampling_w=0).id)

            # disable grain if neutral chroma: mask is peak where both U and V equal the neutral value
            neutral_mask = core.std.Expr(
                split(neutral_mask), f'y {neutral[1]} = z {neutral[1]} = and {get_peak_value(clip)} 0 ?'
            )

            grained = grained.std.MaskedMerge(clip, neutral_mask, planes=[1, 2])
    else:
        if lo is not None or hi is not None:
            warnings.warn("sizedgrain: setting lo and hi won't do anything when fade_edges=False")

        if clip.format.sample_type == vs.INTEGER:
            grained = clip.std.MergeDiff(grain)
        else:
            # Float: add the grain's deviation from neutral to the source.
            grained = core.std.Expr([clip, grain], [f'y {mid} - x +' for mid in neutral])

    return grained


# Backwards-compatible aliases for the public entry points.
adaptivegrainmod = adptvgrnMod = adaptive_grain
sizedgrn = sizedgrain
