from pathlib import Path

import numpy as np
import numpy.typing as npt
import plotly.express as px  # type: ignore
import sep  # type: ignore
from astropy.io import fits  # type: ignore
from astropy.visualization import AsinhStretch, ZScaleInterval  # type: ignore
from pydantic import BaseModel, Field

from pixelemon._plate_solve import TETRA_SOLVER, PlateSolve
from pixelemon._telescope import Telescope
from pixelemon.constants import PERCENT_TO_DECIMAL
from pixelemon.logging import pixelemon_LOG
from pixelemon.processing import MIN_BACKGROUND_MESH_COUNT, BackgroundSettings, Detections, DetectionSettings

# from PIL import Image  # type: ignore


class TelescopeImage(BaseModel):
    """A telescope exposure together with its processing, detection and plate-solving state.

    The pixels loaded from disk are kept untouched in ``_original_array``. Processing steps
    (:meth:`crop`, :meth:`remove_background`) update ``_processed_array``, which is what
    detection, plotting, writing and plate solving operate on. :meth:`reset` restores the
    processed array from the original.
    """

    # Pixels as loaded (after any aspect-ratio trim) from disk; not modified afterwards.
    _original_array: npt.NDArray[np.float32] | None = None
    # Working copy updated by crop/remove_background and read by every analysis step.
    _processed_array: npt.NDArray[np.float32] | None = None
    telescope: Telescope | None = None
    image_scale: float = Field(default=1.0, description="The image scale due to cropping")
    background_settings: BackgroundSettings = Field(default_factory=BackgroundSettings)
    detection_settings: DetectionSettings = Field(default_factory=DetectionSettings.point_source_defaults)

    @classmethod
    def from_fits_file(cls, file_path: Path, telescope: Telescope) -> "TelescopeImage":
        """Create an image from the primary HDU of a FITS file.

        The data is converted to ``float32``. If its width/height ratio is not within 1 %
        of ``telescope.aspect_ratio``, the image is trimmed symmetrically along whichever
        axis is too long so that the ratio matches.

        Args:
            file_path: Path of the FITS file to read.
            telescope: Telescope that captured the image; supplies the expected aspect ratio.

        Returns:
            A new ``TelescopeImage`` whose original and processed arrays hold the loaded data.

        Raises:
            ValueError: If the primary HDU contains no data.
        """
        with fits.open(file_path) as hdul:
            raw = getattr(hdul[0], "data")
            if raw is None:
                raise ValueError(f"No image data in primary HDU of {file_path}")
            # astype() copies, so the pixels remain valid once the file is closed.
            pixels = raw.astype(np.float32)

        height, width = pixels.shape[0], pixels.shape[1]
        expected_ratio = telescope.aspect_ratio
        actual_ratio = width / height
        if not np.isclose(expected_ratio, actual_ratio, rtol=1e-2):
            pixelemon_LOG.warning("Trimming image to match expected aspect")
            if actual_ratio > expected_ratio:
                # Too wide: keep the central columns.
                new_width = int(height * expected_ratio)
                start_x = (width - new_width) // 2
                pixels = pixels[:, start_x : start_x + new_width]
            else:
                # Too tall: keep the central rows. Trimming columns here would need a slice
                # wider than the image and a negative start index.
                new_height = int(width / expected_ratio)
                start_y = (height - new_height) // 2
                pixels = pixels[start_y : start_y + new_height, :]
            # Slicing yields a strided view; sep operates on C-contiguous buffers.
            pixels = np.ascontiguousarray(pixels)

        img = cls()
        img.telescope = telescope
        img._original_array = pixels
        img._processed_array = pixels.copy()
        pixelemon_LOG.info(f"Loaded {img._original_array.shape} image from {file_path}")
        return img

    def write_to_fits_file(self, file_path: Path):
        """Write the processed array to ``file_path`` as 8-bit FITS, overwriting any existing file.

        Values are clipped to ``[0, 255]`` before the ``uint8`` conversion. A bare cast of
        out-of-range floats (e.g. negative pixels after background subtraction) produces
        undefined, platform-dependent values.

        Raises:
            ValueError: If no image has been loaded.
        """
        if self._processed_array is None:
            raise ValueError("Image array is not loaded.")
        pixels = np.clip(self._processed_array, 0, 255).astype("uint8")
        fits.writeto(file_path, pixels, overwrite=True)
        pixelemon_LOG.info(f"Saved processed image to {file_path}")

    def crop(self, crop_percent: float):
        """Crop a centered region out of the original image.

        ``crop_percent`` percent of each dimension is removed, half from each edge. The
        crop is always taken from the original array, so it replaces any earlier crop or
        background removal. For the same reason, ``image_scale`` is set relative to the
        original rather than compounded with a previous crop.

        Args:
            crop_percent: Percentage of the width and height to remove.

        Raises:
            ValueError: If no image is loaded, or the resulting crop fraction is negative
                or would remove the entire image.
        """
        if self._original_array is None:
            raise ValueError("Image array is not loaded.")

        crop_fraction = crop_percent * PERCENT_TO_DECIMAL
        if not 0.0 <= crop_fraction < 1.0:
            raise ValueError(f"crop_percent must give a crop fraction in [0, 1), got {crop_fraction}")

        height, width = self._original_array.shape
        margin_y = int(height * crop_fraction / 2)
        margin_x = int(width * crop_fraction / 2)
        self._processed_array = np.ascontiguousarray(
            self._original_array[margin_y : height - margin_y, margin_x : width - margin_x]
        )
        pixelemon_LOG.info(f"Image cropped to {self._processed_array.shape}")
        # Relative to the original image, because that is what the crop was taken from.
        self.image_scale = 1.0 - crop_fraction
        new_fov = f"{self.horizontal_field_of_view:.2f} x {self.vertical_field_of_view:.2f} degrees"  # noqa: E231
        pixelemon_LOG.info(f"New field of view is {new_fov}")

    @staticmethod
    def _brightest_of(detections: Detections, count: int) -> Detections:
        """Return the ``count`` entries of ``detections`` with the largest ``total_flux``, brightest first."""
        ranked = sorted(detections, key=lambda det: det.total_flux, reverse=True)
        return Detections(ranked[:count])

    def get_brightest_detections(self, count: int) -> Detections:
        """Run source detection and return the ``count`` brightest sources, brightest first."""
        return self._brightest_of(self.detections, count)

    @property
    def background(self) -> sep.Background:
        """Estimate the background of the processed array with ``sep.Background``.

        The mesh cell size is the image width/height divided by
        ``background_settings.mesh_count``, but never below ``MIN_BACKGROUND_MESH_COUNT``.
        The estimate is recomputed on every access.

        Raises:
            ValueError: If no image has been loaded.
        """
        if self._processed_array is None:
            raise ValueError("Image array is not loaded.")

        bw = max(MIN_BACKGROUND_MESH_COUNT, int(self._processed_array.shape[1] / self.background_settings.mesh_count))
        bh = max(MIN_BACKGROUND_MESH_COUNT, int(self._processed_array.shape[0] / self.background_settings.mesh_count))
        pixelemon_LOG.info(f"Background mesh size: {bw}x{bh}")

        return sep.Background(
            self._processed_array,
            bw=bw,
            bh=bh,
            fw=self.background_settings.filter_size,
            fh=self.background_settings.filter_size,
            fthresh=self.background_settings.detection_threshold,
        )

    def remove_background(self):
        """Subtract the current background estimate from the processed array.

        A new array is assigned to the processed array; the original array is untouched.

        Raises:
            ValueError: If no image has been loaded.
        """
        if self._processed_array is None:
            raise ValueError("Image array is not loaded.")
        self._processed_array = self._processed_array - self.background

    def reset(self):
        """Restore the processed array to a copy of the original and reset ``image_scale`` to 1.

        Raises:
            ValueError: If no image has been loaded.
        """
        if self._original_array is None:
            raise ValueError("Image array is not loaded.")
        self._processed_array = self._original_array.copy()
        self.image_scale = 1.0
        pixelemon_LOG.info("Image reset to original")

    @property
    def horizontal_field_of_view(self) -> float:
        """Telescope horizontal field of view scaled by ``image_scale`` (degrees, per the crop log)."""
        if self.telescope is None:
            raise ValueError("Telescope is not set.")
        return self.telescope.horizontal_field_of_view * self.image_scale

    @property
    def vertical_field_of_view(self) -> float:
        """Telescope vertical field of view scaled by ``image_scale`` (degrees, per the crop log)."""
        if self.telescope is None:
            raise ValueError("Telescope is not set.")
        return self.telescope.vertical_field_of_view * self.image_scale

    @property
    def diagonal_field_of_view(self) -> float:
        """Telescope diagonal field of view scaled by ``image_scale``."""
        if self.telescope is None:
            raise ValueError("Telescope is not set.")
        return self.telescope.diagonal_field_of_view * self.image_scale

    @property
    def detections(self) -> Detections:
        """Extract sources from the processed array with ``sep.extract``.

        The threshold is ``detection_threshold_sigma`` times the background's global RMS.
        Each access re-estimates the background and re-runs the extraction.

        Raises:
            ValueError: If no image has been loaded.
        """
        if self._processed_array is None:
            raise ValueError("Image array is not loaded.")

        objects = sep.extract(
            self._processed_array,
            thresh=self.detection_settings.detection_threshold_sigma * self.background.globalrms,
            minarea=self.detection_settings.min_pixel_count,
            filter_kernel=self.detection_settings.gaussian_kernel,
            deblend_nthresh=self.detection_settings.deblend_mesh_count,
            deblend_cont=self.detection_settings.deblend_contrast,
            clean=False,
            segmentation_map=False,
        )
        pixelemon_LOG.info(f"Detected {len(objects)} objects")

        return Detections.from_sep_extract(objects)

    def plot(self):
        """Display the processed image, with every detection circled, in an interactive Plotly figure.

        Display limits come from ``ZScaleInterval``. Pixels are clipped to those limits,
        normalized to ``[0, 1]`` and passed through ``AsinhStretch``. Each detection is
        drawn as a circle of radius ``semi_major_axis`` around its centroid.

        Raises:
            ValueError: If no image has been loaded.
        """
        if self._processed_array is None:
            raise ValueError("Image array is not loaded.")

        vmin, vmax = ZScaleInterval().get_limits(self._processed_array)
        span = vmax - vmin
        shifted = np.clip(self._processed_array, vmin, vmax) - vmin
        # A flat image yields vmin == vmax; avoid a 0/0 that would turn every pixel into NaN.
        normalized = shifted / span if span > 0 else np.zeros_like(shifted)
        stretched = AsinhStretch()(normalized)

        # Flipping rows and plotting with origin="lower" puts array row 0 at the top of the
        # figure; detection y-coordinates are mirrored the same way below.
        display = np.flipud(stretched)
        height = stretched.shape[0]

        detections = self.detections
        fig = px.imshow(display, color_continuous_scale="Viridis", origin="lower", title="FITS Image")
        for det in detections:
            radius = det.semi_major_axis
            x_center = det.x_centroid
            y_center = height - 1 - det.y_centroid
            fig.add_shape(
                type="circle",
                x0=x_center - radius,
                x1=x_center + radius,
                y0=y_center - radius,
                y1=y_center + radius,
                xref="x",
                yref="y",
                line=dict(color="red", width=1),
                opacity=0.6,
            )
        fig.update_layout(xaxis_title="X (pixels)", yaxis_title="Y (pixels)")
        fig.show()

    @property
    def plate_solve(self) -> PlateSolve | None:
        """Plate-solve the processed image from its brightest detections using ``TETRA_SOLVER``.

        The ``TETRA_SOLVER.settings.verification_star_count`` brightest centroids are passed
        to the solver. The diagonal field of view is used as the FOV estimate, with an
        allowed error of 10 % of that estimate.

        Returns:
            The validated ``PlateSolve``, or ``None`` if the solver returned no RA.

        Raises:
            ValueError: If no image has been loaded or no telescope is set.
        """
        if self._processed_array is None:
            raise ValueError("Image array is not loaded.")
        if self.telescope is None:
            raise ValueError("Telescope is not set.")

        fov = f"{self.horizontal_field_of_view:.2f} x {self.vertical_field_of_view:.2f} degrees"  # noqa: E231
        # Extract once: both the log line and the solver input need the detections, and
        # each access of ``self.detections`` re-runs background estimation and extraction.
        detections = self.detections
        pixelemon_LOG.info(f"Solving {len(detections)} detections and FOV of {fov}")

        brightest = self._brightest_of(detections, TETRA_SOLVER.settings.verification_star_count)
        fov_estimate = self.diagonal_field_of_view
        tetra_solve = TETRA_SOLVER.solve_from_centroids(
            brightest.y_x_array,
            size=self._processed_array.shape,
            fov_estimate=fov_estimate,
            fov_max_error=fov_estimate * 0.1,
        )

        if tetra_solve["RA"] is None:
            pixelemon_LOG.warning("Plate solve failed.")
            return None
        return PlateSolve.model_validate(tetra_solve)
