from __future__ import annotations

import dataclasses
import math
import warnings
from gettext import gettext as _
from typing import NamedTuple, Optional

from libresvip.core.warning_types import ParamsWarning
from libresvip.model.base import ParamCurve, Points, SongTempo
from libresvip.model.point import Point
from libresvip.utils import find_last_index, hz2midi, midi2hz

from .constants import (
    MIN_DATA_LENGTH,
    TEMP_VALUE_AS_NULL,
    TIME_UNIT_AS_TICKS_PER_BPM,
)


class CeVIOPitchEvent(NamedTuple):
    """One integer-positioned CeVIO pitch entry.

    ``index`` and ``repeat`` may be ``None``. ``append_ending_points`` reads a
    missing ``index`` as "starts where the previous event ended" and a missing
    ``repeat`` as a length of 1.
    """

    # Start position; None = directly after the previous event.
    index: Optional[int]
    # Number of positions the value spans; None = 1.
    repeat: Optional[int]
    # Natural log of the frequency in Hz (read back via hz2midi(math.e**value));
    # TEMP_VALUE_AS_NULL is used internally as a "no pitch" gap marker.
    value: float


class CeVIOPitchEventFloat(NamedTuple):
    """Floating-point counterpart of ``CeVIOPitchEvent`` used while converting.

    Unlike ``CeVIOPitchEvent``, ``value`` may be ``None`` here, meaning the
    span carries no pitch.
    """

    # Start position (time units or ticks, depending on the conversion stage).
    index: Optional[float]
    # Length in the same unit as ``index``.
    repeat: Optional[float]
    # Natural log of the frequency in Hz, or None for "no pitch".
    value: Optional[float]

    @classmethod
    def from_event(cls, event: CeVIOPitchEvent) -> CeVIOPitchEventFloat:
        """Widen an integer event to floats; ``None`` fields stay ``None``."""

        def widen(number: Optional[int]) -> Optional[float]:
            return None if number is None else float(number)

        return cls(index=widen(event.index), repeat=widen(event.repeat), value=event.value)


@dataclasses.dataclass
class CeVIOTrackPitchData:
    """Pitch events of one CeVIO track plus the context needed to convert them."""

    # Events in CeVIO positions; index/repeat may be omitted (see CeVIOPitchEvent).
    events: list[CeVIOPitchEvent]
    # Tempo list used by normalize_to_tick for the position -> tick conversion.
    tempos: list[SongTempo]
    # Tick offset passed to expand() and applied to converted positions.
    tick_prefix: int

    @property
    def length(self) -> int:
        """Return the end position of the last event plus ``MIN_DATA_LENGTH``.

        The end is measured from the last event that carries an explicit
        ``index``. Every event from there on is laid end to end. A missing
        ``repeat`` counts as 1, the same rule ``append_ending_points`` uses, so
        zero-length events add nothing.

        NOTE(review): assumes at least one event has an ``index``; what
        ``find_last_index`` returns when none does is not visible here.
        """
        last_has_index = find_last_index(self.events, lambda event: event.index is not None)
        length = self.events[last_has_index].index + sum(
            # ``is not None`` rather than ``or``: repeat == 0 must count as 0.
            event.repeat if event.repeat is not None else 1
            for event in self.events[last_has_index:]
        )
        return length + MIN_DATA_LENGTH


def pitch_from_cevio_track(data: CeVIOTrackPitchData) -> Optional[ParamCurve]:
    """Convert a track's CeVIO pitch events into a pitch curve.

    The events go through three steps:

    - they are padded with gap markers (``append_ending_points``);
    - they are mapped onto ticks (``normalize_to_tick``);
    - they are cleaned up (``shape_events``).

    Each resulting event contributes a point at its start tick, shifted back by
    ``data.tick_prefix``. The stored value (natural log of Hz) becomes a ``y``
    of MIDI key * 100, and events without a value use ``-100``. A point is only
    emitted when ``y`` changes or the event does not start where the previous
    one ended. ``-100`` points are emitted twice at the same ``x``.

    Events whose value overflows during conversion are skipped with a
    ``ParamsWarning``.

    Returns:
        The curve framed by ``Point.start_point()`` and ``Point.end_point()``,
        or ``None`` if no point was emitted between them.
    """
    curve_points = [Point.start_point()]
    previous_y = -100
    expected_start = None

    for tick_event in shape_events(normalize_to_tick(append_ending_points(data))):
        start = tick_event.index - data.tick_prefix
        try:
            if tick_event.value is None:
                y = -100
            else:
                y = round(hz2midi(math.e**tick_event.value) * 100)
            is_discontinuous = expected_start != start
            if y != previous_y or is_discontinuous:
                x = round(start)
                curve_points.append(Point(x=x, y=y))
                if y == -100:
                    curve_points.append(Point(x=x, y=y))
                previous_y = y
        except OverflowError:
            warnings.warn(_("Pitch value is out of bounds"), ParamsWarning)
        expected_start = start + tick_event.repeat

    curve_points.append(Point.end_point())
    if len(curve_points) <= 2:
        return None
    return ParamCurve(points=Points(root=curve_points))


def append_ending_points(data: CeVIOTrackPitchData) -> CeVIOTrackPitchData:
    """Make every event explicit and mark where pitch data stops.

    Each event gets a concrete ``index`` (a missing one continues from the end
    of the previous event) and a concrete ``repeat`` (a missing one means 1).
    Wherever an event starts after the previous one ended, and after the very
    last event, a ``TEMP_VALUE_AS_NULL`` marker with ``repeat=None`` is
    inserted at the end position.
    """

    def padded_events():
        end = None
        for source in data.events:
            start = end if source.index is None else source.index
            span = 1 if source.repeat is None else source.repeat
            if end is not None and end < start:
                yield CeVIOPitchEvent(end, None, TEMP_VALUE_AS_NULL)
            yield CeVIOPitchEvent(start, span, source.value)
            end = start + span
        if end is not None:
            yield CeVIOPitchEvent(end, None, TEMP_VALUE_AS_NULL)

    return CeVIOTrackPitchData(list(padded_events()), data.tempos, data.tick_prefix)


def normalize_to_tick(data: CeVIOTrackPitchData) -> list[CeVIOPitchEventFloat]:
    """Convert CeVIO events from time units into tick positions and lengths.

    Each event's start is mapped through the (time-unit, tick, bpm) anchors
    built by ``expand``. Within one tempo segment a time unit spans
    ``TIME_UNIT_AS_TICKS_PER_BPM * bpm`` ticks. An event's length is converted
    piece by piece across every tempo change it crosses. Events without an
    ``index`` start where the previous event ended, and a missing ``repeat``
    counts as 1.

    Afterwards ``data.tick_prefix`` is added to every index, and values equal
    to ``TEMP_VALUE_AS_NULL`` become ``None``.

    NOTE(review): ``expand`` already anchors time unit 0 at ``tick_prefix``,
    so indices produced here include the prefix twice. ``pitch_from_cevio_track``
    subtracts it once. Confirm the remaining offset is intended.
    """
    tempos = expand(data.tempos, data.tick_prefix)
    events = [CeVIOPitchEventFloat.from_event(event) for event in data.events]
    events_normalized: list[CeVIOPitchEventFloat] = []
    # Tempo segment of the current event. It only ever advances, so events
    # are assumed to arrive in ascending position order.
    current_tempo_index = 0
    # End of the previous event, in time units and in ticks.
    next_pos = 0.0
    next_tick_pos = 0.0
    for event in events:
        pos = event.index if event.index is not None else next_pos
        tick_pos = next_tick_pos if event.index is None else None
        if event.index is not None:
            # Move to the last tempo segment starting at or before the event.
            while (
                current_tempo_index + 1 < len(tempos)
                and tempos[current_tempo_index + 1][0] <= event.index
            ):
                current_tempo_index += 1
            ticks_in_time_unit = TIME_UNIT_AS_TICKS_PER_BPM * tempos[current_tempo_index][2]
            # Linear interpolation from the segment's (time-unit, tick) anchor.
            tick_pos = (
                tempos[current_tempo_index][1]
                + (event.index - tempos[current_tempo_index][0]) * ticks_in_time_unit
            )
        repeat = event.repeat if event.repeat is not None else 1.0
        remaining_repeat = repeat
        repeat_in_ticks = 0.0
        # For each tempo change before the event's end, add the ticks of the
        # segment up to that change and consume the matching time units.
        while (
            current_tempo_index + 1 < len(tempos)
            and tempos[current_tempo_index + 1][0] < pos + repeat
        ):
            repeat_in_ticks += tempos[current_tempo_index + 1][1] - max(
                tempos[current_tempo_index][1], tick_pos
            )
            remaining_repeat -= tempos[current_tempo_index + 1][0] - max(
                tempos[current_tempo_index][0], pos
            )
            current_tempo_index += 1
        # Whatever is left lies in the final segment, converted at its own tempo.
        repeat_in_ticks += (
            remaining_repeat * TIME_UNIT_AS_TICKS_PER_BPM * tempos[current_tempo_index][2]
        )
        next_pos = pos + repeat
        next_tick_pos = tick_pos + repeat_in_ticks
        events_normalized.append(CeVIOPitchEventFloat(tick_pos, repeat_in_ticks, event.value))
    # Apply the tick prefix and turn gap markers into "no pitch" (None).
    return [
        CeVIOPitchEventFloat(
            tick.index + data.tick_prefix,
            tick.repeat,
            tick.value if tick.value != TEMP_VALUE_AS_NULL else None,
        )
        for tick in events_normalized
    ]


def shape_events(
    events_with_full_params: list[CeVIOPitchEventFloat],
) -> list[CeVIOPitchEventFloat]:
    """Drop zero/unknown-length events and collapse events sharing a start.

    Events whose ``repeat`` is ``None`` or not positive are discarded. When
    consecutive kept events have the same ``index``, only the later one
    survives.
    """
    shaped: list[CeVIOPitchEventFloat] = []
    for candidate in events_with_full_params:
        if candidate.repeat is None or not (candidate.repeat > 0):
            continue
        if shaped and shaped[-1].index == candidate.index:
            shaped[-1] = candidate
        else:
            shaped.append(candidate)
    return shaped


def expand(tempos: list[SongTempo], tick_prefix: int) -> list[tuple[int, float, float]]:
    """Build one ``(time_unit_pos, tick_pos, bpm)`` anchor per tempo.

    The first tempo is anchored at time unit 0 and tick ``tick_prefix``; its
    own ``position`` is not read. Each later tempo keeps its ``position`` as
    the tick. Its time-unit position is derived from the previous anchor,
    where one time unit spans ``TIME_UNIT_AS_TICKS_PER_BPM * bpm`` ticks of the
    previous tempo.
    """
    if not tempos:
        return []
    anchors: list[tuple[int, float, float]] = [(0, tick_prefix, tempos[0].bpm)]
    for tempo in tempos[1:]:
        prev_pos, prev_tick, prev_bpm = anchors[-1]
        elapsed_units = (tempo.position - prev_tick) / (TIME_UNIT_AS_TICKS_PER_BPM * prev_bpm)
        anchors.append((prev_pos + elapsed_units, tempo.position, tempo.bpm))
    return anchors


def generate_for_cevio(
    pitch: ParamCurve, tempos: list[SongTempo], tick_prefix: int
) -> Optional[CeVIOTrackPitchData]:
    """Convert a pitch curve (x in ticks) into CeVIO pitch data.

    Every point whose ``y`` is not ``-100`` (no pitch) becomes an event that
    lasts until the next point. The last point lasts 1 tick, and no event is
    shorter than 1 tick. ``y`` is read as MIDI key * 100 and stored as the
    natural log of its frequency in Hz.

    Whether each event touches or overlaps the next one is recorded in tick
    space *before* ``denormalize_from_tick`` rounds positions. That record is
    passed to ``restore_connection``. The converted events are then merged
    and compacted by dropping indices and repeats that can be implied.

    Args:
        pitch: pitch curve whose point ``x`` values are ticks.
        tempos: song tempo list used for the tick -> time-unit conversion.
        tick_prefix: tick offset forwarded to ``denormalize_from_tick``.

    Returns:
        The CeVIO pitch data (with an empty tempo list), or ``None`` if the
        curve has no pitched points.
    """
    points = pitch.points.root
    events_with_full_params: list[CeVIOPitchEventFloat] = []
    for this_point, next_point in zip(points, [*points[1:], None]):
        if this_point.y == -100:
            continue
        index = this_point.x
        # ``is not None``: a following point at x == 0 is still a valid end tick.
        repeat = next_point.x - index if next_point is not None else 1
        repeat = max(repeat, 1)
        value = math.log(midi2hz(this_point.y / 100))
        events_with_full_params.append(
            CeVIOPitchEventFloat(float(index), float(repeat), float(value))
        )
    are_events_connected_to_next = [
        next_event is not None and this_event.index + this_event.repeat >= next_event.index
        for this_event, next_event in zip(
            events_with_full_params, [*events_with_full_params[1:], None]
        )
    ]
    events = denormalize_from_tick(events_with_full_params, tempos, tick_prefix)
    events = restore_connection(events, are_events_connected_to_next)
    events = merge_events_if_possible(events)
    events = remove_redundant_index(events)
    events = remove_redundant_repeat(events)
    if not events:
        return None
    # The track length is derived on demand by CeVIOTrackPitchData.length,
    # so nothing is precomputed here.
    return CeVIOTrackPitchData(events, [], tick_prefix)


def denormalize_from_tick(
    events_with_full_params: list[CeVIOPitchEventFloat],
    tempos_in_ticks: list[SongTempo],
    tick_prefix: int,
) -> list[CeVIOPitchEvent]:
    """Convert tick-based events back into integer CeVIO time units.

    This is the reverse direction of ``normalize_to_tick``:

    - tempo positions and event indices are shifted by ``tick_prefix``;
    - every start tick is mapped through the (time-unit, tick, bpm) anchors
      from ``expand``;
    - every length is converted segment by segment across the tempo changes
      it spans.

    Start and length are rounded to integers independently of each other.

    Args:
        events_with_full_params: events whose ``index``/``repeat`` are ticks.
        tempos_in_ticks: song tempos whose ``position`` is in ticks.
        tick_prefix: tick offset added to both tempos and event indices.

    Returns:
        One ``CeVIOPitchEvent`` per input event, in the same order.
    """
    tempos = expand(
        [
            tempo.model_copy(update={"position": tempo.position + tick_prefix})
            for tempo in tempos_in_ticks
        ],
        tick_prefix,
    )
    # Shift event starts into the same tick frame as the shifted tempos.
    events_with_full_params = [
        event if event.index is None else event._replace(index=event.index + tick_prefix)
        for event in events_with_full_params
    ]
    events = []
    # Only ever advances, so events are assumed to arrive in ascending tick order.
    current_tempo_index = 0
    for event_double in events_with_full_params:
        # NOTE(review): if the first event had index None, tick_pos would be
        # unbound below, and ``pos`` reads event_double.index directly anyway.
        # generate_for_cevio only passes events that have an index.
        if event_double.index is not None:
            tick_pos = event_double.index
        # Advance to the tempo segment that contains tick_pos.
        while (
            current_tempo_index + 1 < len(tempos) and tempos[current_tempo_index + 1][1] < tick_pos
        ):
            current_tempo_index += 1
        ticks_per_time_unit = tempos[current_tempo_index][2] * TIME_UNIT_AS_TICKS_PER_BPM
        # Linear interpolation from the segment's (time-unit, tick) anchor.
        pos = (
            tempos[current_tempo_index][0]
            + (event_double.index - tempos[current_tempo_index][1]) / ticks_per_time_unit
        )
        repeat_in_ticks = event_double.repeat
        remaining_repeat_in_ticks = repeat_in_ticks
        repeat = 0.0
        # For each tempo change inside the event, add the time units up to that
        # change and consume the corresponding ticks.
        while (current_tempo_index + 1 < len(tempos)) and (
            tempos[current_tempo_index + 1][1] < tick_pos + repeat_in_ticks
        ):
            repeat += tempos[current_tempo_index + 1][0] - max(tempos[current_tempo_index][0], pos)
            remaining_repeat_in_ticks -= tempos[current_tempo_index + 1][1] - max(
                tempos[current_tempo_index][1], tick_pos
            )
            current_tempo_index += 1
        # Ticks left in the final segment are converted at that segment's tempo.
        repeat += remaining_repeat_in_ticks / (
            TIME_UNIT_AS_TICKS_PER_BPM * tempos[current_tempo_index][2]
        )
        events.append(CeVIOPitchEvent(round(pos), int(round(repeat)), event_double.value))
    return events


def restore_connection(
    events: list[CeVIOPitchEvent], are_events_connected_to_next: list[bool]
) -> list[CeVIOPitchEvent]:
    """Append a zero-length event after every event not connected to the next.

    ``are_events_connected_to_next`` is read pairwise with ``events`` (zip
    semantics, so the shorter list bounds the output). For each unconnected
    event, a ``CeVIOPitchEvent`` with the same value and ``repeat=0`` is added
    at the position where that event ends.
    """
    restored: list[CeVIOPitchEvent] = []
    for current, connected in zip(events, are_events_connected_to_next):
        restored.append(current)
        if connected:
            continue
        end_index = current.index + current.repeat
        restored.append(CeVIOPitchEvent(end_index, 0, current.value))
    return restored


def merge_events_if_possible(events: list[CeVIOPitchEvent]) -> list[CeVIOPitchEvent]:
    """Coalesce runs of back-to-back events that carry the same value.

    An event is merged into the previously emitted one when the previous event
    ends exactly where it starts (``index + repeat == next.index``) and both
    have the same ``value``. The merged event keeps the first ``index`` and the
    summed ``repeat``, so runs of any length collapse into a single event.

    Zero-length events (``repeat == 0``, as emitted by ``restore_connection``
    after an unconnected event) are never merged in either direction, so they
    are kept as standalone events.

    The input list is not modified.
    """
    merged: list[CeVIOPitchEvent] = []
    for event in events:
        last = merged[-1] if merged else None
        if (
            last is not None
            and last.repeat > 0
            and event.repeat > 0
            and last.value == event.value
            and last.index + last.repeat == event.index
        ):
            # Fold into the last *output* event: each input event is emitted at
            # most once, and chains longer than two collapse fully.
            merged[-1] = last._replace(repeat=last.repeat + event.repeat)
        else:
            merged.append(event)
    return merged


def remove_redundant_index(events: list[CeVIOPitchEvent]) -> list[CeVIOPitchEvent]:
    """Drop ``index`` where it equals the end of the previous input event.

    A missing index is read back as "starts where the previous event ended"
    (see ``append_ending_points``), so such indices carry no information. The
    comparison always uses the original previous event, not its rewritten form.
    """
    trimmed: list[CeVIOPitchEvent] = []
    previous: Optional[CeVIOPitchEvent] = None
    for current in events:
        implied = (
            previous is not None
            and previous.index is not None
            and previous.repeat is not None
            and previous.index + previous.repeat == current.index
        )
        trimmed.append(CeVIOPitchEvent(None, current.repeat, current.value) if implied else current)
        previous = current
    return trimmed


def remove_redundant_repeat(events: list[CeVIOPitchEvent]) -> list[CeVIOPitchEvent]:
    """Blank out ``repeat`` wherever it equals 1, the implied default length."""

    def drop_default(event: CeVIOPitchEvent) -> CeVIOPitchEvent:
        if event.repeat == 1:
            return event._replace(repeat=None)
        return event

    return list(map(drop_default, events))
