#! /usr/bin/python3
# -*- coding: utf-8 -*-


"""
    
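    Meridian transit times, elevation-crossing times and elevation
    tracks of a sky target as seen from the NenuFAR site.

    Example: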

    >>> casa_icrs = SkyCoord(350.850000, +58.815000, unit='deg', frame='icrs')
    >>> vira_icrs = SkyCoord(187.70593075958, +12.39112329392, unit='deg', frame='icrs')
    >>> cyga_icrs = SkyCoord(299.86815190954, +40.73391573600, unit='deg', frame='icrs')
    >>> taua_icrs = SkyCoord(83.63308333333, +22.0145, unit='deg', frame='icrs')
    >>> 
    >>> target = ExtraSolar_Target(
    >>>    taua_icrs
    >>> )
    >>> #target = SolarSystem_Target(
    >>> #    'jupiter'
    >>> #)
    >>> 
    >>> target.precision = 'low'
    >>> transits = target.timeAtMeridianTransit(
    >>>     tMin=Time('2020-09-09 00:00:00'),
    >>>     duration=TimeDelta(86400*1, format='sec')
    >>> )
    >>> 
    >>> for transit in transits:
    >>>     crosstimes = target.timeAtElevation(
    >>>         tCenter=transit,
    >>>         elevation=np.array([1, 40])*u.deg,
    >>>         duration=TimeDelta(86400*1, format='sec')
    >>>     )
    >>>     print('transit', transit)
    >>>     print('elevation crosses', crosstimes)
    >>>     print('elev at transit', target.elevationAtMeridianTransit(transit))
"""


__author__ = 'Alan Loh'
__copyright__ = 'Copyright 2020, nenupy'
__credits__ = ['Alan Loh']
__maintainer__ = 'Alan'
__email__ = 'alan.loh@obspm.fr'
__status__ = 'Production'


from astropy.coordinates import (
    EarthLocation,
    SkyCoord,
    AltAz,
    Angle,
    FK5,
    solar_system_ephemeris,
    get_body
)
import astropy.units as u
from astropy.time import Time, TimeDelta
import numpy as np
#import more_itertools as mit
import argparse


# Observer position shared by the target classes below: it provides the
# longitude for the local sidereal time, the latitude for the elevation
# and azimuth formulas, and the location for AltAz / get_body calls.
nenufarLocation = EarthLocation(
    lat=47.376511 * u.deg,
    lon=2.192400 * u.deg,
    height=150 * u.m
)

# Lower-case body names accepted by the ``SolarSystem_Target.name``
# setter; the validated name is later handed to ``get_body``.
allowedSolarSystem = [
    'sun',
    'moon',
    'mercury',
    'venus',
    'mars',
    'jupiter',
    'saturn',
    'uranus',
    'neptune'
]


# ============================================================= #
# ============================================================= #
class NenuFAR_Target(object):
    """ Base class for a target observed from ``nenufarLocation``.

        It provides the meridian-transit and elevation-crossing
        searches. The methods rely on ``self.sourceAtTime(times,
        precession=...)``, which is not defined here: subclasses
        supply it and return the target equatorial coordinates.
    """


    def __init__(self):
        """ Start with the ``'low'`` sidereal-time precision. """
        self.precision = 'low' # other accepted values: 'mean', 'apparent'


    # --------------------------------------------------------- #
    # --------------------- Getter/Setter --------------------- #
    @property
    def precision(self):
        return self._precision
    @precision.setter
    def precision(self, p):
        allowedValues = ['low', 'mean', 'apparent']
        if not isinstance(p, str):
            raise TypeError(
                'precision should be a string'
            )
        p = p.lower()
        if not p in allowedValues:
            raise ValueError(
                'precision should be in {}'.format(
                    allowedValues
                )
            )
        self._precision = p


    # --------------------------------------------------------- #
    # ------------------------ Methods ------------------------ #
    def timeAtElevation(self, tCenter, elevation, duration=None):
        """ Find when the target crosses the given elevation(s).

            The search window is ``duration`` long and centred on
            ``tCenter``. Crossings are bracketed on three successively
            finer time grids (3600 s, 600 s, then 60 s steps) and
            reported at the middle of the finest bracketing step.

            :param tCenter: Centre of the search window.
            :type tCenter: :class:`~astropy.time.Time`
            :param elevation: Elevation threshold(s), scalar or 1-D.
            :type elevation: :class:`~astropy.units.Quantity`
            :param duration: Window length (default: 86400 s).
            :type duration: :class:`~astropy.time.TimeDelta`

            :returns: One list per threshold, in input order:

                * ``['None', 'None']`` (strings) if the target stays
                  below the threshold on the whole coarse grid;
                * ``[firstGridTime, lastGridTime]`` if it stays above;
                * otherwise one :class:`~astropy.time.Time` per
                  crossing found, in chronological order.
            :rtype: list

            :raises TypeError: If an argument has the wrong type.
        """
        if not isinstance(tCenter, Time):
            raise TypeError(
                'tCenter should be an astropy.Time instance.'
            )
        if duration is None:
            duration = TimeDelta(86400, format='sec')
        elif not isinstance(duration, TimeDelta):
            raise TypeError(
                'duration should be an astropy.TimeDelta instance.'
            )
        if not isinstance(elevation, u.Quantity):
            raise TypeError(
                'elevation should be an astropy.Quantity instance.'
            )
        if elevation.isscalar:
            # Promote to 1-D so that the threshold loop also works
            elevation = np.array([elevation.value]) * elevation.unit

        largeDt = TimeDelta(3600, format='sec')
        mediumDt = TimeDelta(600, format='sec')
        smallDt = TimeDelta(60, format='sec')

        def _elevationsAt(timeGrid):
            """ Target elevation at every time of ``timeGrid``. """
            source = self.sourceAtTime(timeGrid, precession=False)
            hourAngle = self._localHourAngle(time=timeGrid)
            return self._toElevation(hourAngle, source.dec)

        def _firstBracket(timeGrid, threshold):
            """ First grid interval bracketing ``threshold``, or None. """
            found = self._boundaryIndices(
                _elevationsAt(timeGrid),
                threshold
            )
            if found.size == 0:
                return None
            first = found[0]
            return timeGrid[first], timeGrid[first + 1]

        # Broad search
        times = self._timeRange(
            tMin=tCenter - duration/2.,
            tMax=tCenter + duration/2.,
            dt=largeDt
        )
        elevations = _elevationsAt(times)

        elevationCrossTimes = []

        # Iterate over the elevation thresholds
        for el in elevation:
            crossings = []
            elevationCrossTimes.append(crossings)

            if np.all(elevations < el):
                # Never reaches the threshold: string placeholders
                crossings.extend(['None', 'None'])
                continue
            if np.all(elevations > el):
                # Always above: report the window edges
                crossings.extend([times[0], times[-1]])
                continue

            # Iterate over the coarse intervals containing a crossing
            for index in self._boundaryIndices(elevations, el):
                tLow, tHigh, step = times[index], times[index + 1], largeDt
                # Medium, then finest search inside the current bracket
                for dt in (mediumDt, smallDt):
                    bracket = _firstBracket(
                        self._timeRange(tMin=tLow, tMax=tHigh, dt=dt),
                        el
                    )
                    if bracket is None:
                        # The finer grid lost the crossing: keep the
                        # last bracket instead of raising IndexError.
                        break
                    tLow, tHigh = bracket
                    step = dt
                crossings.append(tLow + step/2.)

        return elevationCrossTimes


    # def timeAtElevation_old(self, tCenter, elevation, duration=None):
    #     """
    #     """
    #     elevationCrossTimes = []

    #     if not isinstance(tCenter, Time):
    #         raise TypeError(
    #             'tMin should be an astropy.Time instance.'
    #         )
    #     if duration is None:
    #         duration = TimeDelta(86400, format='sec')
    #     elif not isinstance(duration, TimeDelta):
    #         raise TypeError(
    #             'duration should be an astropy.TimeDelta instance.'
    #         )
    #     if not isinstance(elevation, u.Quantity):
    #         raise TypeError(
    #             'elevation should be an astopy.Quantity instance.'
    #         )
    #     if elevation.isscalar:
    #         elevation = np.array([elevation.value]) * elevation.unit 
        
    #     largeDt = TimeDelta(3600, format='sec')
    #     mediumDt = TimeDelta(600, format='sec')
    #     smallDt = TimeDelta(60, format='sec')

    #     # Broad search
    #     times = self._timeRange(
    #         tMin=tCenter - duration/2.,
    #         tMax=tCenter + duration/2.,
    #         dt=largeDt
    #     )
    #     sourceEq = self.sourceAtTime(times, precession=False)
    #     sourceHo = self._toAltaz(sourceEq, times)
        
    #     elevations = sourceHo.alt
        
    #     # Iterate over the elevation thresholds
    #     for el in elevation:
            
    #         elevationCrossTimes.append([])
            
    #         if all(elevations < el):
    #             elevationCrossTimes[-1].append(times[0])
    #             elevationCrossTimes[-1].append(times[0])
    #         elif all(elevations > el):
    #             elevationCrossTimes[-1].append(times[0])
    #             elevationCrossTimes[-1].append(times[-1])
    #             continue
            
    #         indices = self._boundaryIndices(elevations, el)            
        
    #         # Iterate over the transit(s)
    #         for index in indices:
    #             # Medium search
    #             medTimes = self._timeRange(
    #                 tMin=times[index],
    #                 tMax=times[index + 1],
    #                 dt=mediumDt
    #             )
    #             sourceEq = self.sourceAtTime(medTimes, precession=False)
    #             sourceHo = self._toAltaz(sourceEq, medTimes)
    #             medIndices = self._boundaryIndices(sourceHo.alt, el)
    #             medIndex = medIndices[0]

    #             # Finest search
    #             smallTimes = self._timeRange(
    #                 tMin=medTimes[medIndex],
    #                 tMax=medTimes[medIndex + 1],
    #                 dt=smallDt
    #             )
    #             sourceEq = self.sourceAtTime(smallTimes, precession=False)
    #             sourceHo = self._toAltaz(sourceEq, smallTimes)
    #             smallIndices = self._boundaryIndices(sourceHo.alt, el)

    #             elevationCrossTimes[-1].append(
    #                 smallTimes[smallIndices[0]] + smallDt/2.
    #             )
        
    #     return elevationCrossTimes



    def timeAtMeridianTransit(self, tMin, duration=None):
        """ Find the meridian transit(s) between ``tMin`` and
            ``tMin + duration``.

            A transit is detected where the local hour angle drops
            from one time sample to the next. The drop is bracketed
            on three successively finer grids (1800 s, 60 s, then 1 s
            steps) and the transit is reported at the middle of the
            finest bracketing step. The coarse grid is extended by
            3600 s on each side; transits found outside the requested
            window are discarded.

            :param tMin: Start of the search window.
            :type tMin: :class:`~astropy.time.Time`
            :param duration: Window length (default: 86400 s).
            :type duration: :class:`~astropy.time.TimeDelta`

            :returns: Transit times, in chronological order.
            :rtype: list of :class:`~astropy.time.Time`

            :raises TypeError: If an argument has the wrong type.
        """
        if not isinstance(tMin, Time):
            raise TypeError(
                'tMin should be an astropy.Time instance.'
            )
        if duration is None:
            duration = TimeDelta(86400, format='sec')
        elif not isinstance(duration, TimeDelta):
            raise TypeError(
                'duration should be an astropy.TimeDelta instance.'
            )
        extDuration = TimeDelta(3600, format='sec')
        largeDt = TimeDelta(1800, format='sec')
        mediumDt = TimeDelta(60, format='sec')
        smallDt = TimeDelta(1, format='sec')
        tMax = tMin + duration

        def _firstJump(timeGrid):
            """ First grid interval where the hour angle drops, or None. """
            jumps = self._jumpIndices(
                self._localHourAngle(time=timeGrid)
            )
            if jumps.size == 0:
                return None
            first = jumps[0]
            return timeGrid[first], timeGrid[first + 1]

        # Broad search
        times = self._timeRange(
            tMin=tMin - extDuration,
            tMax=tMax + extDuration,
            dt=largeDt
        )
        indices = self._jumpIndices(
            self._localHourAngle(time=times)
        )

        transitTimes = []

        # Iterate over the transit(s)
        for index in indices:
            tLow, tHigh, step = times[index], times[index + 1], largeDt
            # Medium, then finest search inside the current bracket
            for dt in (mediumDt, smallDt):
                bracket = _firstJump(
                    self._timeRange(tMin=tLow, tMax=tHigh, dt=dt)
                )
                if bracket is None:
                    # The finer grid lost the jump: keep the last
                    # bracket instead of raising IndexError.
                    break
                tLow, tHigh = bracket
                step = dt
            transitTime = tLow + step/2.
            # Drop transits that only belong to the extended margins
            if (transitTime >= tMin) and (transitTime <= tMax):
                transitTimes.append(
                    transitTime
                )

        return transitTimes
    
    
    def elevationAtMeridianTransit(self, transitTime):
        """ Elevation (in degrees) of the target at ``transitTime``,
            computed as ``90 - |latitude - declination|`` with the
            declination given by ``sourceAtTime(transitTime)``.
        """
        declination = self.sourceAtTime(transitTime).dec.deg
        zenithDistance = np.abs(nenufarLocation.lat.deg - declination)
        return 90 - zenithDistance


    def elevationAroundMeridianTransit(self, transitTime, duration=86400, dt=60, returnArray=False):
        """ Sample the target elevation on a regular time grid
            centred on ``transitTime``.

            :param transitTime: Centre of the grid; anything that is
                not a :class:`~astropy.time.Time` is passed to
                ``Time(...)``.
            :param duration: Total span; a non-``TimeDelta`` value is
                interpreted as seconds.
            :param dt: Time step; a non-``TimeDelta`` value is
                interpreted as seconds.
            :param returnArray: If ``True`` return ``(times,
                elevations)``, otherwise a list of dictionaries.
            :type returnArray: bool

            :returns: ``(times, elevations)`` or a list of
                ``{'elevation': <float>, 'date': <ISOT string>}``.

            :raises TypeError: If an argument cannot be converted;
                the conversion error is chained as the cause.
        """
        # Only ordinary errors are converted to TypeError:
        # KeyboardInterrupt and SystemExit are no longer swallowed.
        if not isinstance(transitTime, Time):
            try:
                transitTime = Time(transitTime)
            except Exception as error:
                raise TypeError("Wrong type for `transitTime`") from error
        if not isinstance(duration, TimeDelta):
            try:
                duration = TimeDelta(duration, format='sec')
            except Exception as error:
                raise TypeError("Wrong type for `duration`") from error
        if not isinstance(dt, TimeDelta):
            try:
                dt = TimeDelta(dt, format='sec')
            except Exception as error:
                raise TypeError("Wrong type for `dt`") from error

        nTimes = np.floor(duration/dt)
        if nTimes % 2 == 0:
            # Make it odd to center the transit time
            nTimes += 1
        # Offsets run from -floor(nTimes/2) to +floor(nTimes/2) steps
        times = transitTime + (np.arange(nTimes) - np.floor(nTimes/2)) * dt
        source = self.sourceAtTime(times, precession=False)
        hourAngle = self._localHourAngle(
            time=times
        )
        elevations = self._toElevation(hourAngle, source.dec)

        if returnArray:
            return times, elevations
        else:
            return [{'elevation': el.value, 'date': ti.isot} for el, ti in zip(elevations, times)]


    # --------------------------------------------------------- #
    # ----------------------- Internal ------------------------ #
    def _localSiderealTime(self, time):
        """ Local sidereal time at ``nenufarLocation``.

            With ``precision == 'low'`` a fast linear approximation
            of the Greenwich mean sidereal time is used
            (http://www.roma1.infn.it/~frasca/snag/GeneralRules.pdf)
            and the result is wrapped into [0 h, 24 h). Otherwise the
            computation is delegated to ``time.sidereal_time`` with
            ``self.precision`` as the kind.

            :param time: Time(s) at which to evaluate the LST.
            :type time: :class:`~astropy.time.Time`

            :returns: Local sidereal time.
            :rtype: :class:`~astropy.coordinates.Angle`

            :raises TypeError: If ``time`` is not a ``Time`` instance.
        """
        if not isinstance(time, Time):
            raise TypeError(
                'astropy.Time instance expected.'
            )

        if self.precision == 'low':
            # Number of days since 2000 January 1, 12h UT
            nDays = time.jd - 2451545.
            # Greenwich mean sidereal time
            gmst = 18.697374558 + 24.06570982441908 * nDays
            gmst %= 24.
            # Local Sidereal Time. The modulo wraps both ends: the
            # sum can be negative (western longitudes) or reach 24 h
            # (eastern longitudes), for scalars and arrays alike.
            lst = (gmst + nenufarLocation.lon.hour) % 24.
            return Angle(lst, 'hour')

        return time.sidereal_time(self.precision, nenufarLocation.lon)


    def _localHourAngle(self, time):
        """ Local hour angle of the target: local sidereal time
            minus the right ascension from ``sourceAtTime(time)``.
            Negative values get 360 deg added, values above 360 deg
            get 360 deg subtracted.
        """
        rightAscension = self.sourceAtTime(time).ra
        hourAngle = self._localSiderealTime(time) - rightAscension
        fullTurn = Angle(360.000000 * u.deg)

        if not hourAngle.isscalar:
            # Array input: correct the offending entries in place
            hourAngle[hourAngle.deg < 0] += fullTurn
            hourAngle[hourAngle.deg > 360] -= fullTurn
            return hourAngle

        if hourAngle.deg < 0:
            hourAngle += fullTurn
        elif hourAngle.deg > 360:
            hourAngle -= fullTurn
        return hourAngle


    @staticmethod
    def _timeRange(tMin, tMax, dt):
        """ Regularly spaced times starting at ``tMin`` with step
            ``dt``. The number of samples is set by
            ``np.arange((tMax - tMin)/dt + 1)``.
        """
        boundsAreTimes = isinstance(tMin, Time) and isinstance(tMax, Time)
        if not boundsAreTimes:
            raise TypeError(
                'tMin and tMax should be astropy.Time instances.'
            )
        if not isinstance(dt, TimeDelta):
            raise TypeError(
                'dt should be an astropy.TimeDelta instance.'
            )
        nSteps = (tMax - tMin) / dt
        offsets = np.arange(nSteps + 1) * dt
        return tMin + offsets


    @staticmethod
    def _jumpIndices(angleArray):
        """ ``angleArray`` is expected to contain increasing
            angular values.
            This function enables to detect the jumps between 360
            and 0 deg for instance (or 180 and -180 deg).
        """
        indices = np.where(
            (np.roll(angleArray, -1) - angleArray)[:-1] < 0
        )[0]
        return indices
    

    @staticmethod
    def _boundaryIndices(array, value):
        """ Find the indices bounding ``value`` in ``array``.
        """
        indices = np.where(
            [
                ((array[i]<=value) & (array[i+1]>=value)) or ((array[i]>=value) & (array[i+1]<=value)) for i in range(0, array.size-1)
            ]
        )[0]
        return indices

    
    @staticmethod
    def _toAltaz(skycoord, time):
        """ Transform ``skycoord`` to the AltAz frame defined by
            ``time`` and ``nenufarLocation``.
        """
        if not isinstance(skycoord, SkyCoord):
            raise TypeError(
                'skycoord should be an astropy.SkyCoord instance.'
            )
        if not isinstance(time, Time):
            raise TypeError(
                'times should be an astropy.Time instance.'
            )
        return skycoord.transform_to(
            AltAz(obstime=time, location=nenufarLocation)
        )


    @staticmethod
    def _toElevation(hourAngle, declination):
        r""" Elevation, converted to degrees, from the local hour
            angle and the declination.

            With the declination :math:`\delta`, the observer's
            latitude :math:`l`, the local Hour Angle :math:`h`,
            the source elevation :math:`\theta` is:

            .. math::
                \sin(\theta) = \sin(\delta) \sin(l) + \cos(\delta) \cos(l) \cos(h)
        """
        dec = np.radians(declination)
        ha = hourAngle.rad
        lat = nenufarLocation.lat.rad
        polarTerm = np.sin(dec) * np.sin(lat)
        hourTerm = np.cos(dec) * np.cos(lat) * np.cos(ha)
        return np.degrees(np.arcsin(polarTerm + hourTerm))


    @staticmethod
    def _toAzimuth(hourAngle, declination):
        r""" Azimuth from the local hour angle and the declination.

            With the declination :math:`\delta`, the observer's
            latitude :math:`l` and the elevation :math:`\theta`
            (from ``_toElevation``), the azimuth :math:`\varphi` is:

            .. math::
                \cos(\varphi) = \frac{\sin(\delta) - \sin(\theta) \sin(l)}{\cos(\theta) \cos(l)}

            ``arccos`` only yields values between 0 and 180 deg, so
            the result is mirrored to :math:`360 - \varphi` wherever
            :math:`\sin(h) > 0`. The ratio is undefined when
            :math:`\cos(\theta) = 0` (target at the zenith).

            :param hourAngle: Local hour angle(s).
            :type hourAngle: :class:`~astropy.coordinates.Angle`
            :param declination: Declination(s); a plain number is
                interpreted as degrees.

            :returns: Azimuth(s) in degrees.
            :rtype: :class:`~astropy.units.Quantity`
        """
        # A static method has no ``self``: go through the class.
        elDeg = NenuFAR_Target._toElevation(hourAngle, declination)
        # Work on plain floats so that no unit arithmetic can fail
        decRad = u.Quantity(declination, u.deg).to_value(u.rad)
        elRad = u.Quantity(elDeg, u.deg).to_value(u.rad)
        latRad = nenufarLocation.lat.rad
        cosAz = (np.sin(decRad) - np.sin(elRad) * np.sin(latRad))/\
            (np.cos(elRad) * np.cos(latRad))
        # Rounding may push the ratio marginally outside [-1, 1]
        cosAz = np.clip(cosAz, -1., 1.)
        azDeg = np.degrees(np.arccos(cosAz))
        # np.where handles scalar and array inputs alike
        posMask = np.sin(hourAngle.rad) > 0
        azDeg = np.where(posMask, 360. - azDeg, azDeg)
        return azDeg * u.deg


# ============================================================= #
# ============================================================= #


# ============================================================= #
# ============================================================= #
class SolarSystem_Target(NenuFAR_Target):
    """ Target that is a Solar System body, identified by one of
        the names listed in ``allowedSolarSystem``.
    """


    def __init__(self, name):
        self.name = name

        super().__init__()


    # --------------------------------------------------------- #
    # --------------------- Getter/Setter --------------------- #
    @property
    def name(self):
        """ Lower-cased name of the body. """
        return self._name
    @name.setter
    def name(self, n):
        if not isinstance(n, str):
            raise TypeError(
                'Name of the source should be a string.'
            )
        body = n.lower()
        if body in allowedSolarSystem:
            self._name = body
            return
        raise ValueError(
            'Solar System target should be one of {}'.format(
                allowedSolarSystem
            )
        )


    # --------------------------------------------------------- #
    # ------------------------ Methods ------------------------ #
    def sourceAtTime(self, times, precession=True):
        """ Position of the body at ``times``, computed by
            ``get_body`` with the 'builtin' ephemeris for
            ``nenufarLocation``.

            The RA/Dec returned by ``get_body`` are wrapped in a new
            ``SkyCoord`` without any frame conversion. If
            ``precession`` is true, that coordinate is transformed
            to FK5 with the equinox set to ``times`` (scalar) or to
            ``times[0]`` (array).

            :raises TypeError: If ``times`` is not a ``Time``.
        """
        if not isinstance(times, Time):
            raise TypeError(
                'times should be an astropy.Time instance.'
            )
        with solar_system_ephemeris.set('builtin'):
            body = get_body(
                self.name,
                times,
                nenufarLocation
            ) # GCRS
        position = SkyCoord(body.ra, body.dec) # ICRS
        if not precession:
            return position
        # A single equinox is used, even for an array of times
        epoch = times if times.isscalar else times[0]
        return position.transform_to(
            FK5(equinox=epoch)
        ) # FK5 precessed at the current epoch

# ============================================================= #
# ============================================================= #


# ============================================================= #
# ============================================================= #
class ExtraSolar_Target(NenuFAR_Target):
    """ Target defined by a fixed sky position given as a
        ``SkyCoord`` instance.
    """


    def __init__(self, skycoord):
        self.skycoord = skycoord

        super().__init__()


    # --------------------------------------------------------- #
    # --------------------- Getter/Setter --------------------- #
    @property
    def skycoord(self):
        """ Sky position of the target. """
        return self._skycoord
    @skycoord.setter
    def skycoord(self, s):
        if isinstance(s, SkyCoord):
            self._skycoord = s
            return
        raise TypeError(
            'astropy.SkyCoord instance expected.'
        )


    # --------------------------------------------------------- #
    # ------------------------ Methods ------------------------ #
    def sourceAtTime(self, times, precession=True):
        """ Position of the target at ``times``.

            Without ``precession`` the stored ``skycoord`` is
            returned as is, whatever the shape of ``times``. With
            ``precession`` it is transformed to FK5 at the equinox
            ``times`` (scalar) or ``times[0]`` (array); in the array
            case the RA/Dec are repeated ``times.size`` times.

            :raises TypeError: If ``times`` is not a ``Time``.
        """
        if not isinstance(times, Time):
            raise TypeError(
                'times should be an astropy.Time instance.'
            )
        if not precession:
            return self.skycoord
        if times.isscalar:
            return self.skycoord.transform_to(
                FK5(equinox=times)
            )
        # Array of times: precess once, then replicate the position
        epoch = times[0]
        precessed = self.skycoord.transform_to(
            FK5(equinox=epoch)
        )
        nTimes = times.size
        return SkyCoord(
            ra=np.ones(nTimes) * precessed.ra,
            dec=np.ones(nTimes) * precessed.dec,
            equinox=epoch
        )

    # --------------------------------------------------------- #
    # ----------------------- Internal ------------------------ #

# ============================================================= #
# ============================================================= #


# ============================================================= #
# ============================================================= #
if __name__ == "__main__":
    # Command-line entry point.
    #   --elevations 0 (default): print, for every meridian transit in
    #       the 24 h following --time, the transit time, a fixed
    #       '180.000' field, the transit elevation, then the crossing
    #       times of --vok and of --vmin ('null null' when missing).
    #   --elevations 1: print the elevation samples around --time,
    #       which is then taken as the transit time.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-t',
        '--time',
        type=str,
        help='Start time from which to search for horizontal coordinates (or transit time if --elevations==1)',
        required=True
    )
    parser.add_argument(
        '-v1',
        '--vmin',
        type=float,
        help='Minimal elevation (in degrees) to display',
        default=1.,
        required=False
    )
    parser.add_argument(
        '-v2',
        '--vok',
        type=float,
        help='Standard elevation (in degrees) to display',
        default=40.,
        required=False
    )
    # 1000 is the "not provided" sentinel for --ra and --dec
    parser.add_argument(
        '-r',
        '--ra',
        type=float,
        help='Right Ascension in degrees',
        default=1000,
        required=False
    )
    parser.add_argument(
        '-d',
        '--dec',
        type=float,
        help='Declination in degrees',
        default=1000,
        required=False
    )
    parser.add_argument(
        '-p',
        '--planet',
        type=str,
        help='Solar System object',
        default='None',
        required=False
    )
    # NOTE: --npoints is parsed but not used by the code below
    parser.add_argument(
        '-n',
        '--npoints',
        type=int,
        help='Number of points to look for a transit',
        default=200,
        required=False
    )
    # ``choices`` rejects any other value up front; previously such a
    # value left ``response`` undefined and the final print crashed.
    parser.add_argument(
        '-e',
        '--elevations',
        type=int,
        choices=[0, 1],
        help='Get elevations around transit time if 1',
        default=0,
        required=False
    )
    args = parser.parse_args()


    if (args.ra != 1000) and (args.dec != 1000):
        source = SkyCoord(
            args.ra,
            args.dec,
            unit='deg'
        )
        target = ExtraSolar_Target(
            skycoord=source
        )
    elif args.planet.lower() != 'none':
        target = SolarSystem_Target(
            name=args.planet
        )
    else:
        raise ValueError(
            'Ambiguous between SolarSystem or ExtraSolar object.'
        )


    target.precision = 'low' # 'low'

    if args.elevations == 0:

        transits = target.timeAtMeridianTransit(
            tMin=Time(args.time),
            duration=TimeDelta(86400, format='sec')
        )

        fields = []
        for transit in transits:
            elTransit = target.elevationAtMeridianTransit(transit)
            fields.append(
                '{:.19} 180.000 {:.3f} '.format(
                    transit.isot,
                    elTransit
                )
            )

            crosstimes = target.timeAtElevation(
                tCenter=transit,
                elevation=np.array([args.vmin, args.vok])*u.deg,
                duration=TimeDelta(86400, format='sec')
            )
            # Output order: vok (ie. 40 degrees) first, then vmin
            # (ie. 1 degree).
            for crossing in (crosstimes[1], crosstimes[0]):
                try:
                    fields.append(
                        '{:.19} {:.19} '.format(
                            crossing[0].isot, crossing[1].isot
                        )
                    )
                except (AttributeError, IndexError):
                    # AttributeError: 'None' string placeholders;
                    # IndexError: fewer than two crossings found.
                    fields.append('null null ')
        response = ''.join(fields)

    else:
        response = target.elevationAroundMeridianTransit(
            transitTime=args.time,
            duration=86400,
            dt=60,
        )

    print(response)
# ============================================================= #
# ============================================================= #

