from tkinter import *
import matplotlib.pyplot as plt
from matplotlib.mathtext import math_to_image
from io import BytesIO
from PIL import ImageTk, Image
from globalemu.eval import evaluate
import numpy as np
import argparse
import os


def variable(name, lower, upper, position_x, position_y,
             initial, img_x_resize, img_y_resize, tickinterval,
             resolution=0.1):
    """Build one parameter slider and the image used as its label.

    The slider is created on the module-level ``window`` and wired to the
    module-level ``signal`` callback. The label image is returned rather
    than placed; the caller is expected to keep a reference to it and put
    it in a widget.

    Parameters
    ----------
    name : str
        Mathtext expression rendered into the label image.
    lower, upper : float
        Slider bounds (``from_`` and ``to`` of the ``Scale``).
    position_x, position_y : int
        Coordinates passed to ``place`` for the slider.
    initial : float
        Value the slider is set to after creation.
    img_x_resize, img_y_resize : int
        Width and height the rendered label image is resized to.
    tickinterval : float
        Spacing of the tick marks shown under the slider.
    resolution : float, optional
        Step size of the slider.

    Returns
    -------
    tuple
        ``(slider, label_image)``: the ``Scale`` widget and the
        ``ImageTk.PhotoImage`` holding the rendered label.
    """
    # Render the mathtext into an in-memory PNG, then rewind for reading.
    png_stream = BytesIO()
    math_to_image(name, png_stream, dpi=200, format='png')
    png_stream.seek(0)

    rendered = Image.open(png_stream)
    label_size = (img_x_resize, img_y_resize)
    label_image = ImageTk.PhotoImage(rendered.resize(label_size))

    slider_options = dict(
        from_=lower,
        to=upper,
        orient=HORIZONTAL,
        length=325,
        resolution=resolution,
        tickinterval=tickinterval,
        background='white',
        command=signal,
    )
    slider = Scale(window, **slider_options)
    slider.place(x=position_x, y=position_y)
    slider.set(initial)
    return slider, label_image


def init_signal(mode='null'):
    """Plot the signal for the default parameters and save it to disk.

    The figure is written to ``img/img.png``. The axis labels and limits
    match those used by ``signal`` so the picture does not rescale between
    the initial/reset view and the first slider-driven redraw.

    Parameters
    ----------
    mode : str, optional
        When equal to ``'reset'`` the saved image is also loaded into the
        module-level ``panel``. Any other value only writes the file.
    """
    # Defaults, in the same order as the list assembled in signal():
    # f_star, Vc, f_x, tau, alpha, nu_min, rmfp.
    params = [1e-3, 46.5, 1e-2, 0.0775, 1.25, 1.5, 30]
    if xHI is True:
        # The xHI network is fed the same values in a different order.
        idx = [0, 1, 2, 5, 3, 4, 6]
        params = [params[i] for i in idx]
    # Evaluate before opening a figure so a failing emulator call does not
    # leave an unclosed matplotlib figure behind.
    res = evaluate(params, base_dir=base_dir, xHI=xHI)
    plt.figure(figsize=(4, 3))
    plt.plot(res.z, res.signal, c='k')
    plt.xlabel('z')
    if xHI is False:
        plt.ylabel(r'$\delta T$ [mK]')
        plt.ylim([-250, 30])
    else:
        plt.ylabel(r'$x_{HI}$')
        # Same fixed range that signal() applies in xHI mode.
        plt.ylim([0, 1])
    plt.tight_layout()
    plt.savefig('img/img.png', dpi=100)
    plt.close()
    if mode == 'reset':
        new_img = ImageTk.PhotoImage(Image.open("img/img.png"))
        panel.configure(image=new_img)
        # Keep a reference on the widget so the image is not garbage
        # collected while it is displayed.
        panel.image = new_img


def signal(_):
    """Redraw the plot from the current slider positions.

    Installed as the ``command`` callback of every slider created by
    ``variable``; the single argument it receives is ignored and all
    values are read back from the slider widgets instead. The plot is
    saved to ``img/img.png`` and then loaded into the module-level
    ``panel``.
    """
    # These two sliders hold base-10 exponents; convert back to linear.
    f_star, f_x = (
        10**float(slider.get()) for slider in (f_star_entry, f_x_entry))
    # The remaining sliders are used as-is.
    Vc, tau, alpha, nu_min, rmfp = (
        float(slider.get())
        for slider in (Vc_entry, tau_entry, alpha_entry,
                       nu_min_entry, rmfp_entry))

    ordered = (f_star, Vc, f_x, tau, alpha, nu_min, rmfp)
    # In xHI mode the same values are passed in a different order.
    order = (0, 1, 2, 5, 3, 4, 6) if xHI is True else range(len(ordered))
    params = [ordered[i] for i in order]

    res = evaluate(params, base_dir=base_dir, xHI=xHI)

    if xHI is False:
        y_label, y_limits = r'$\delta T$ [mK]', [-250, 30]
    else:
        y_label, y_limits = r'$x_{HI}$', [0, 1]

    plt.figure(figsize=(4, 3))
    plt.plot(res.z, res.signal, c='k')
    plt.xlabel('z')
    plt.ylabel(y_label)
    plt.ylim(y_limits)
    plt.tight_layout()
    plt.savefig('img/img.png', dpi=100)
    plt.close()

    refreshed = ImageTk.PhotoImage(Image.open("img/img.png"))
    panel.configure(image=refreshed)
    # Hold a reference on the widget so the image stays alive.
    panel.image = refreshed


def reset():
    """Restore the default plot and move every slider to its default.

    The default image is redrawn first via ``init_signal('reset')``; the
    sliders are then set one after another in a fixed order.
    """
    init_signal('reset')
    defaults = (
        # The first and third sliders store base-10 exponents.
        (f_star_entry, np.log10(1e-3)),
        (Vc_entry, 46.5),
        (f_x_entry, np.log10(1e-2)),
        (tau_entry, 0.0775),
        (alpha_entry, 1.25),
        (nu_min_entry, 1.5),
        (rmfp_entry, 30),
    )
    for slider, value in defaults:
        slider.set(value)


def _str_to_bool(text):
    """Convert a command-line token into a real ``bool``.

    ``type=bool`` is unsuitable for argparse because ``bool('False')`` is
    ``True``; this converter interprets the usual spellings instead. The
    empty string maps to ``False``, as it did with ``bool``.

    Raises
    ------
    argparse.ArgumentTypeError
        If the token is not a recognised boolean spelling.
    """
    token = text.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(text))


parser = argparse.ArgumentParser(
    description='Interactive exploration of the Global 21-cm Signal')
# A bare ``--xHI`` yields True (const); ``--xHI <value>`` is parsed by
# _str_to_bool; omitting the flag yields False (default).
parser.add_argument(
    '--xHI', type=_str_to_bool, const=True, default=False, nargs='?',
    help='a boolean indicating if the network is for ' +
    'neutral fraction history.')
args = parser.parse_args()
# Always a genuine bool: the functions in this file compare it with
# ``is True`` / ``is False``.
xHI = args.xHI
# Directory handed to the emulator as ``base_dir``.
if xHI is False:
    base_dir = 'T_release/'
else:
    base_dir = 'xHI_release/'

# Main window; variable() attaches every slider to this object.
window = Tk()
window.geometry("800x450")
window.configure(background='white')

window.title('GlobalEmu GUI')

# init_signal() and signal() save the plot into this directory. Creating
# it with exist_ok avoids the check-then-create race of exists() + mkdir().
os.makedirs('img/', exist_ok=True)

# Write the default plot to disk so it can be shown before any interaction.
init_signal()

# ``img`` stays bound at module level so the displayed image keeps a live
# reference; ``panel`` is the widget signal()/init_signal() update.
img = ImageTk.PhotoImage(Image.open("img/img.png"))
panel = Label(window, image=img)
panel.place(x=10, y=10)

# One slider per parameter, each with a rendered mathtext label to its
# right. The *_label names keep the label images referenced.
f_star_entry, fstar_label = variable(
    r'$\log(f_*)$', -6, -0.3, 410, 10, -3, 50, 15, 1)
label = Label(window, image=fstar_label)
label.place(x=740, y=10)

# NOTE(review): signal() uses this slider's value directly (no 10**
# transform) and the range is 16.5-76.5, so the "log" in the label text
# looks wrong -- confirm before changing the rendered text and its size.
Vc_entry, Vc_label = variable(
    r'$\log(V_c)$', 16.5, 76.5, 410, 70, 46.5, 50, 15, 5)
label = Label(window, image=Vc_label)
label.place(x=740, y=70)

f_x_entry, f_x_label = variable(
    r'$\log(f_x)$', -6, 2, 410, 130, -2, 50, 15, 1)
label = Label(window, image=f_x_label)
label.place(x=740, y=130)

tau_entry, tau_label = variable(
    r'$\tau$', 0.055, 0.1, 410, 190, 0.0775, 15, 15, 0.01, resolution=0.001)
label = Label(window, image=tau_label)
label.place(x=740, y=190)

alpha_entry, alpha_label = variable(
    r'$\alpha$', 1, 1.5, 410, 250, 1.25, 15, 15, 0.1, resolution=0.01)
label = Label(window, image=alpha_label)
label.place(x=740, y=250)

nu_min_entry, numin_label = variable(
    r'$\nu_\mathrm{min}$', 0.1, 3, 410, 310, 1.5, 50, 15, 0.5)
label = Label(window, image=numin_label)
label.place(x=740, y=310)

rmfp_entry, rmfp_label = variable(
    r'$R_\mathrm{mfp}$', 10, 50, 410, 370, 30, 40, 15, 10, resolution=1)
label = Label(window, image=rmfp_label)
label.place(x=740, y=370)

# Restores the default plot and slider positions.
btn = Button(window, text='Reset', command=reset)
btn.place(x=180, y=360)

window.mainloop()
