# AUTOGENERATED! DO NOT EDIT! File to edit: notebooks/03_ctx.ipynb (unless otherwise specified).

__all__ = ['baseurl', 'storage_root', 'edrindex', 'catch_isis_error', 'CTXEDR', 'CTXEDRCollection']

# Cell

import warnings

import hvplot.xarray  # noqa
import rasterio
import rioxarray as rxr
from dask import compute, delayed
from .config import config
from .pds.apps import get_index
from .utils import file_variations, url_retrieve
from tqdm.auto import tqdm
from yarl import URL

try:
    from kalasiris.pysis import (
        ProcessError,
        ctxcal,
        ctxevenodd,
        getkey,
        mroctx2isis,
        spiceinit,
    )
except KeyError:
    warnings.warn("kalasiris has a problem initialing ISIS")

# Silence rasterio's NotGeoreferencedWarning for the whole process
# (module-level side effect of importing this module).
# NOTE(review): presumably needed because the image files opened further down
# are not georeferenced — confirm.
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)
# Root URL under which `CTXEDR.url` builds the per-product download links.
baseurl = URL(
    "https://pds-imaging.jpl.nasa.gov/data/mro/mars_reconnaissance_orbiter/ctx/"
)

# Local directory for CTX data, below the configured storage root.
storage_root = config.storage_root / "missions/mro/ctx"
# Index returned by `get_index("mro.ctx", "edr")`, fetched at import time;
# `CTXEDR.meta` queries it by PRODUCT_ID.
edrindex = get_index("mro.ctx", "edr")

# Cell
def catch_isis_error(func):
    """Decorate `func` so that ISIS `ProcessError`s are printed, not raised.

    The wrapped callable forwards all positional and keyword arguments to
    `func` and returns whatever `func` returns. If `func` raises a
    `ProcessError`, the failed command line, its stdout and its stderr are
    printed and the wrapper returns ``None``. Any other exception propagates
    unchanged.

    Parameters
    ----------
    func : callable
        The function or method to wrap.

    Returns
    -------
    callable
        Wrapper with the same name and docstring as `func`.
    """
    # Local import so this block does not depend on the module import section.
    from functools import wraps

    @wraps(func)
    def inner(*args, **kwargs):
        try:
            # Propagate the return value; callers such as `destripe` rely on
            # the result of decorated helpers.
            return func(*args, **kwargs)
        except ProcessError as err:
            print("Had ISIS error:")
            # `str` conversion guards against non-string command parts
            # (e.g. path objects) breaking the join.
            print(" ".join(map(str, err.cmd)))
            print(err.stdout)
            print(err.stderr)
            return None

    return inner

# Cell
class CTXEDR:
    """Manage a single CTX EDR data product.

    Covers locating the product in the EDR index, downloading the ``.IMG``
    file, running the ISIS calibration steps on it and opening/plotting the
    raw or calibrated image.

    Attributes
    ----------
    storage : path-like
        Class-level root folder below which each product gets its own folder.
    cub_name, cal_name, destripe_name
        File names derived from `local_path` by `file_variations` with the
        suffixes ``.cub``, ``.cal.cub`` and ``.dst.cal.cub``.
    is_read, is_calib_read : bool
        Cache flags for `read_edr` and `read_calibrated`.
    """

    storage = storage_root / "edr"

    def __init__(self, id_):
        """Set up paths and cache flags for the product with ID `id_`."""
        # The setter derives all dependent file names and resets the caches.
        self.product_id = id_

    @property
    def product_id(self):
        """The product ID this object manages."""
        return self._product_id

    @product_id.setter
    def product_id(self, value):
        self._product_id = value
        # A new ID invalidates anything cached for the previous product...
        self.is_read = False
        self.is_calib_read = False
        # ...and the derived file names, which would otherwise keep pointing
        # at the previous product's files.
        (self.cub_name, self.cal_name, self.destripe_name) = file_variations(
            self.local_path, [".cub", ".cal.cub", ".dst.cal.cub"]
        )

    @property
    def id(self):
        """Short alias for `product_id`."""
        return self.product_id

    @property
    def meta(self):
        """Index entry for this product, with lower-cased labels.

        NOTE(review): `squeeze()` yields a Series only if exactly one index
        row matches the product ID — confirm the index has unique IDs.
        """
        s = edrindex.query("PRODUCT_ID == @self.id").squeeze()
        s.index = s.index.str.lower()
        return s

    @property
    def local_folder(self):
        """Folder holding all local files of this product."""
        return self.storage / self.id

    @property
    def local_path(self):
        """Local path of the downloaded ``.IMG`` file."""
        return self.local_folder / f"{self.id}.IMG"

    @property
    def url(self):
        """Download URL, built from the volume ID found in `meta`."""
        return baseurl / self.meta.volume_id.lower() / "data" / (self.id + ".IMG")

    def download(self, overwrite=False):
        """Download the ``.IMG`` file into `local_folder`.

        Parameters
        ----------
        overwrite : bool
            If False (default) an existing local file is kept and only a
            hint is printed.
        """
        self.local_folder.mkdir(parents=True, exist_ok=True)
        if self.local_path.exists() and not overwrite:
            print("File exists. Use `overwrite=True` to download fresh.")
            return
        url_retrieve(self.url, self.local_path)

    @catch_isis_error
    def isis_import(self):
        """Run `mroctx2isis` to convert the ``.IMG`` file into `cub_name`."""
        mroctx2isis(from_=self.local_path, to=self.cub_name)

    @catch_isis_error
    def spice_init(self):
        """Run `spiceinit` on `cub_name` using the web service."""
        spiceinit(from_=self.cub_name, web="yes")

    @catch_isis_error
    def calibrate(self):
        """Run `ctxcal` on `cub_name`, writing `cal_name`."""
        ctxcal(from_=self.cub_name, to=self.cal_name)
        # The calibrated file changed; force `read_calibrated` to reopen it.
        self.is_calib_read = False

    @catch_isis_error
    def destripe(self):
        """Run `ctxevenodd` on the calibrated cube when applicable.

        The destriped output replaces the calibrated cube (`cal_name`).
        """
        # Use the undecorated helper so the decision does not depend on what
        # the `catch_isis_error` wrapper returns; an ISIS error raised by the
        # helper is still handled by this method's own decorator.
        if self._needs_destripe():
            ctxevenodd(from_=self.cal_name, to=self.destripe_name)
            self.destripe_name.rename(self.cal_name)
            # The calibrated file was replaced; drop the cached data array.
            self.is_calib_read = False

    @catch_isis_error
    def do_destripe(self):
        """Return whether `destripe` should process the calibrated cube."""
        return self._needs_destripe()

    def _needs_destripe(self):
        """Return False if the cube's SpatialSumming keyword is 2, else True.

        NOTE(review): presumably because `ctxevenodd` is not meant for
        spatially summed images — confirm against the ISIS documentation.
        """
        value = int(
            getkey(
                from_=self.cal_name,
                objname="isiscube",
                grpname="instrument",
                keyword="SpatialSumming",
            )
        )
        return value != 2

    def calib_pipeline(self, overwrite=False):
        """Run import, spiceinit, calibration and destriping in sequence.

        Parameters
        ----------
        overwrite : bool
            If False (default) nothing is done when `cal_name` already
            exists.
        """
        if self.cal_name.exists() and not overwrite:
            return
        pbar = tqdm("isis_import spice_init calibrate destripe".split())
        for name in pbar:
            pbar.set_description(name)
            getattr(self, name)()
        pbar.set_description("Done.")

    def read_edr(self):
        """Open the raw EDR image and cache it as `edr_da`.

        `da` stands for dataarray, standard abbr. within xarray.

        Raises
        ------
        FileNotFoundError
            If the ``.IMG`` file has not been downloaded yet.
        """
        if not self.local_path.exists():
            raise FileNotFoundError("EDR not downloaded yet.")
        if not self.is_read:
            self.edr_da = rxr.open_rasterio(self.local_path)
            self.is_read = True
        return self.edr_da

    def read_calibrated(self):
        """Open the calibrated cube and cache it as `cal_da`.

        `da` stands for dataarray, standard abbr. within xarray.
        """
        if not self.is_calib_read:
            self.cal_da = rxr.open_rasterio(self.cal_name)
            # Mark as cached so later calls reuse `cal_da`.
            self.is_calib_read = True
        return self.cal_da

    def plot_da(self, data=None):
        """Plot the first band of `data` with hvplot.

        Parameters
        ----------
        data : data array, optional
            Image to plot; defaults to the raw EDR, which is read on demand.
        """
        # `read_edr` opens the file if needed instead of failing on a missing
        # `edr_da` attribute.
        data = self.read_edr() if data is None else data
        return data.isel(band=0, drop=True).hvplot(
            x="y", y="x", rasterize=True, cmap="gray", data_aspect=1
        )

    def plot_calibrated(self):
        """Plot the calibrated image."""
        return self.plot_da(self.read_calibrated())

    def __str__(self):
        s = f"PRODUCT_ID: {self.product_id}\n"
        s += f"URL: {self.url}\n"
        s += f"Local: {self.local_path}\n"
        try:
            s += f"Shape: {self.read_edr().shape}"
        except FileNotFoundError:
            s += "Not downloaded yet."
        return s

    def __repr__(self):
        return self.__str__()

# Cell


class CTXEDRCollection:
    """Operate on several CTX products, given by their product IDs."""

    def __init__(self, product_ids):
        self.product_ids = product_ids

    def get_urls(self):
        """Collect the download URL of every product in the collection.

        The result is also stored on the instance as `urls`.

        Returns
        -------
        List[yarl.URL]
            One PDS download URL per product ID, in input order.
        """
        self.urls = [CTXEDR(pid).url for pid in self.product_ids]
        return self.urls

    def download_collection(self):
        """Download all products, scheduled as dask delayed tasks."""
        tasks = [delayed(CTXEDR(pid).download)() for pid in self.product_ids]
        print("Launching parallel download...")
        compute(*tasks)
        print("Done.")

    def calibrate_collection(self):
        """Run the calibration pipeline for all products as dask delayed tasks."""
        tasks = [delayed(CTXEDR(pid).calib_pipeline)() for pid in self.product_ids]
        print("Launching parallel calibration...")
        compute(*tasks)
        print("Done.")

    def calib_exist_check(self):
        """Return ``(product_id, exists)`` pairs for the calibrated files."""
        status = []
        for pid in self.product_ids:
            has_cal = CTXEDR(pid).cal_name.exists()
            status.append((pid, has_cal))
        return status