#! /usr/bin/env python3

import pandas as pd
from astropy.coordinates import SkyCoord
import astropy.units as u
from collections import namedtuple
import numpy as np
from astropy.coordinates import SkyCoord
import astropy.units as u

import argparse

def plate2sky(x, y, linear=False):
    """
    Convert position on plate to position on sky, relative to plate centre.

    Parameters
    ----------
    x, y : float or array-like
        Positions on the plate in microns, with (0, 0) at the plate centre.
        Scalars and arrays may be mixed; they are broadcast together.
    linear : bool, optional
        If True a simple linear plate scale is used, otherwise the
        pincushion distortion model is inverted.

    Returns
    -------
    AngularCoords
        Named tuple (xi, eta) with the angular coordinates in arcseconds,
        relative to plate centre.

    Notes
    -----
    The distortion model is solved along the radial direction, so points
    with x == 0 (on the plate's y axis) give finite results instead of the
    NaN produced by scaling eta as ``y * xi / x``.
    """

    # Define the return named tuple type
    AngularCoords = namedtuple('AngularCoords', ['xi', 'eta'])

    # Make sure we're dealing with floats
    x = np.array(x, dtype='d')
    y = np.array(y, dtype='d')

    if np.size(x) == 1 and np.size(y) == 1 and x == 0.0 and y == 0.0:
        # Plate centre: kept as an early return so that this case still
        # yields plain Python floats, as it always has.
        return AngularCoords(0.0, 0.0)

    if linear:
        # Just do a really simple transformation
        plate_scale = 15.22 / 1000.0   # arcseconds per micron
        return AngularCoords(x * plate_scale, y * plate_scale)

    # Include pincushion distortion, found by inverting:
    #    x = xi * f * P(xi, eta)
    #    y = eta * f * P(xi, eta)
    # where:
    #    P(xi, eta) = 1 + p * (xi**2 + eta**2)
    #    p = 191.0
    #    f = 13.515e6 microns, the telescope focal length
    #
    # (xi, eta) is parallel to (x, y), so with r = hypot(x, y) and
    # rho = hypot(xi, eta) the two equations collapse to one depressed cubic
    #    (p * f) * rho**3 + f * rho - r = 0
    # which has exactly one real root, given by Cardano's formula.
    p = 191.0
    f = 13.515e6
    r = np.hypot(x, y)

    cubic_a = p * f
    q = 27.0 * cubic_a**2 * r
    root = np.sqrt(q**2 + 4.0 * (3.0 * cubic_a * f)**3)
    # root >= q, so both cube-root arguments are non-negative
    rho = (np.cbrt(0.5 * (root + q)) - np.cbrt(0.5 * (root - q))) / (3.0 * cubic_a)

    # Convert radians to arcseconds and split rho back into components.
    # At r == 0 both x and y are zero, so any finite divisor gives 0.
    safe_r = np.where(r == 0.0, 1.0, r)
    scale = (rho / safe_r) * (180.0 / np.pi) * 3600.0

    return AngularCoords(x * scale, y * scale)

def print_offsets(offset_x, offset_y, hexabundle_1, hexabundle_2):
    """
    Print the move from hexabundle_1 to hexabundle_2 as compass offsets.

    offset_x and offset_y are the offsets to report; their absolute values
    are printed to one decimal place, labelled as arcseconds. An offset_y
    that is zero or negative is labelled 'S', otherwise 'N'; an offset_x
    that is zero or negative is labelled 'E', otherwise 'W'.
    hexabundle_1 and hexabundle_2 are the start and end names; the string
    'CENTRE' is described as the centre of the plate. Returns nothing.
    """

    # Compass labels from the sign of each offset
    NS_direction = 'S' if offset_y <= 0 else "N"
    EW_direction = 'E' if offset_x <= 0 else 'W'

    # Heading line: the start point is checked first, so a CENTRE start wins
    if hexabundle_1 == 'CENTRE':
        heading = f"\nMoving from the centre of the plate to Hexabundle {hexabundle_2}:"
    elif hexabundle_2 == 'CENTRE':
        heading = f"\nMoving from Hexabundle {hexabundle_1} to the centre of the plate:"
    else:
        heading = f"\nMoving from Hexabundle {hexabundle_1} to Hexabundle {hexabundle_2}:"
    print(heading)

    size_ns = np.abs(offset_y)
    size_ew = np.abs(offset_x)
    print(f"\t--> offset {size_ns:.1f} arcseconds {NS_direction} and {size_ew:.1f} arcseconds {EW_direction}\n")

def _sky_offset_from_centre(df, hexabundle, which):
    """
    Return the (x, y) sky offset in arcseconds of one end-point of the move.

    df is the tile table, hexabundle is the upper-cased hexabundle name (or
    'CENTRE' for the plate centre) and which is the word ('first' or
    'second') used in the error message.

    Raises ValueError if the hexabundle has no row in df.
    """
    if hexabundle == 'CENTRE':
        return 0.0, 0.0

    # A boolean-mask selection never raises KeyError for a missing name, it
    # simply returns an empty frame, so the check has to be explicit.
    bundle = df.loc[df.Hexabundle == hexabundle]
    if bundle.empty:
        raise ValueError(f"The {which} Hexabundle is {hexabundle} but it doesn't appear in the file!")

    # Use the first matching row. plate2sky takes microns; the factor of
    # 1000 implies the tile file stores x/y in millimetres (TODO confirm).
    row = bundle.iloc[0]
    delta_x, delta_y = plate2sky(row['x'] * 1000, row['y'] * 1000, linear=True)

    # Plain floats, so that every combination of end-points (including
    # centre to centre) can be subtracted and printed the same way.
    return float(delta_x), float(delta_y)


def main():
    """Parse the command line, read the tile file and print the offset between two hexabundles."""
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', help='A Hector Tile file, usually called Tile_FinalFormat_...',)
    parser.add_argument('hexabundle_1', type=str, help='The name of the hexabundle to start from, or the string "centre" to use the plate centre')
    parser.add_argument('hexabundle_2', type=str, help='The name of the hexabundle to finish at, or the string "centre" to use the plate centre')

    args = parser.parse_args()

    # The first 11 lines of the file are skipped before the CSV header row
    df = pd.read_csv(args.filename, skiprows=11)
    # Keep rows with a non-negative "type" (presumably the rows that
    # carry a hexabundle; verify against the tile file format)
    df = df.loc[df["type"] >= 0]

    # Capitalise each Hexabundle name
    hexabundle_1 = args.hexabundle_1.upper()
    hexabundle_2 = args.hexabundle_2.upper()

    if len(df) >= 28:
        raise ValueError(f"We seem to have {len(df)} bundles but Hector only has 27!")

    delta_x1, delta_y1 = _sky_offset_from_centre(df, hexabundle_1, "first")
    delta_x2, delta_y2 = _sky_offset_from_centre(df, hexabundle_2, "second")

    # Calculate the offsets
    offset_x = delta_x2 - delta_x1
    offset_y = delta_y2 - delta_y1

    print_offsets(offset_x, offset_y, hexabundle_1, hexabundle_2)


if __name__ == "__main__":
    main()

