#!/usr/bin/env python

from collections import defaultdict
from datetime import datetime
from datetime import timedelta
import os
import sys
import argparse

import numpy as np
from peewee import fn, JOIN
from astropy.time import Time

from sdssdb.peewee.sdss5db import opsdb, targetdb

from kronos.designCompletion import checker


# Short aliases for the targetdb/opsdb models used throughout this script.
Design = targetdb.Design
Field = targetdb.Field
Cadence = targetdb.Cadence
Conf = opsdb.Configuration
Exp = opsdb.Exposure
DesignToStatus = opsdb.DesignToStatus
CompletionStatus = opsdb.CompletionStatus
Quickred = opsdb.Quickred
CameraFrame = opsdb.CameraFrame
Camera = opsdb.Camera

# Primary keys of the cameras whose SN2 is collated below.
# NOTE: these are database lookups executed at import time, so importing
# this module requires a reachable opsdb.
r1_db = Camera.get(label="r1").pk
b1_db = Camera.get(label="b1").pk
ap_db = Camera.get(label="APOGEE").pk


def brightCollate(expDicts, camera_pk=None):
    """Sum APOGEE SN2 over exposures that belong to a dither pair.

    Exposures are split by dither position. "A" is every exposure whose
    ``dither_pixpos`` is within 0.2 pixel of the smallest position seen.
    "B" is everything else. An exposure counts only if an exposure at the
    other dither position started within 7200 s (2 h) of it. The SN2 of
    every such paired exposure is summed, counting each exposure once.

    Parameters
    ----------
    expDicts : iterable of dict
        Rows with keys ``pk`` (camera pk), ``sn2``, ``start_time``
        (datetime) and ``dither_pixpos``.
    camera_pk : int, optional
        Camera pk whose rows are collated. Defaults to the module-level
        APOGEE camera pk ``ap_db``.

    Returns
    -------
    float
        Summed SN2 of paired exposures; 0 when no usable exposure exists.
    """
    if camera_pk is None:
        camera_pk = ap_db

    # Materialize the matching rows once. A row whose dither position or SN2
    # is None (e.g. no quickred row yet) cannot be classified or summed; such
    # rows used to crash np.min / np.sum, so they are skipped.
    rows = [e for e in expDicts
            if e["pk"] == camera_pk
            and e["sn2"] is not None
            and e["dither_pixpos"] is not None]
    if not rows:
        return 0

    sn2 = np.array([e["sn2"] for e in rows], dtype=float)
    times = np.array([e["start_time"].timestamp() for e in rows])
    pos = np.array([e["dither_pixpos"] for e in rows], dtype=float)

    isA = np.abs(pos - pos.min()) < 0.2
    isB = ~isA

    # paired[i, j] is True when A exposure i and B exposure j started
    # less than 7200 s apart.
    paired = np.abs(times[isA][:, np.newaxis]
                    - times[isB][np.newaxis, :]) < 7200

    # An exposure counts if it has at least one partner. any() counts it once
    # no matter how many partners it has.
    goodA = np.sum(sn2[isA][paired.any(axis=1)])
    goodB = np.sum(sn2[isB][paired.any(axis=0)])

    return goodA + goodB


def darkCollate(exps, r_pk=None, b_pk=None, min_sn2=0.2):
    """Sum per-camera SN2 for the r1 and b1 cameras.

    Only frames with SN2 strictly greater than ``min_sn2`` contribute.
    Frames from any other camera are ignored.

    Parameters
    ----------
    exps : iterable of dict
        Rows with keys ``pk`` (camera pk, from the join to Camera) and
        ``sn2``.
    r_pk, b_pk : int, optional
        Camera pks treated as r1 / b1. Default to the module-level
        ``r1_db`` / ``b1_db``.
    min_sn2 : float, optional
        Frames with SN2 at or below this value are skipped (default 0.2).

    Returns
    -------
    tuple
        ``(r, b)``: summed SN2 for the r1 and b1 cameras; 0 when none.
    """
    if r_pk is None:
        r_pk = r1_db
    if b_pk is None:
        b_pk = b1_db

    r = 0
    b = 0
    for e in exps:
        s = e["sn2"]
        # A frame with no SN2 recorded (None) contributes nothing. Comparing
        # None > float would raise TypeError. "not s > min_sn2" also drops
        # NaN, as the original strict ">" test did.
        if s is None or not s > min_sn2:
            continue
        if e["pk"] == r_pk:
            r += s
        elif e["pk"] == b_pk:
            b += s

    return r, b


def _currentEpoch(nexp, exposure):
    """Locate the cadence epoch that contains a design's exposure number.

    Parameters
    ----------
    nexp : sequence of int
        Number of exposures in each epoch of the cadence.
    exposure : int
        The design's 0-indexed exposure number within the whole cadence.

    Returns
    -------
    epoch : int
        Index of the epoch containing ``exposure``. Exposure numbers at or
        past the end of the cadence are clamped to the last epoch.
    expCount : numpy.ndarray
        Cumulative exposure counts. Epoch ``i`` covers exposure numbers
        ``expCount[i-1]`` (0 for ``i == 0``) through ``expCount[i] - 1``.
    """
    expCount = np.cumsum(nexp)
    # Take the first epoch whose cumulative count is strictly greater than
    # the exposure number. The first exposure of epoch i equals expCount[i-1],
    # so a ">=" test would wrongly assign it to epoch i-1.
    epoch = int(np.searchsorted(expCount, exposure, side="right"))
    return min(epoch, len(expCount) - 1), expCount


def _setStatus(design_id, status, mjd):
    """Set ``design_id``'s DesignToStatus row to ``status``, stamped at ``mjd``."""
    design_status = DesignToStatus.get(design=design_id)
    design_status.status = status
    design_status.mjd = mjd
    design_status.save()


def checkDone(hours=1, test=False, daemon=False):
    """Update the completion status of recently observed designs.

    For every design whose status is not "done" and that has a science
    exposure started within the last ``hours``, this collates the SN2 of its
    exposures:

    - APOGEE SN2 for "bright_time" epochs.
    - r1/b1 camera SN2 otherwise. For the last design of a non-bright epoch,
      this covers every exposure on the field during that epoch.

    The kronos ``checker`` then decides whether the design (or epoch) is
    complete. The design's DesignToStatus is set to "done" or "started", and
    a per-design summary plus the elapsed time are printed.

    Parameters
    ----------
    hours : float
        How far back (in hours) to look for exposures.
    test : bool
        If True, only print the summary; nothing is written to the database.
    daemon : bool
        If True, return without doing anything when no science exposure
        started in the last 30 minutes.
    """
    start = datetime.now()
    useTime = start - timedelta(hours=hours)
    db_flavor = opsdb.ExposureFlavor.get(pk=1)  # science
    doneStatus = CompletionStatus.get(label="done")

    mostRecent = Exp.select(fn.MAX(Exp.start_time))\
                    .where(Exp.exposure_flavor == db_flavor)\
                    .scalar()

    # Use total_seconds() rather than .seconds: .seconds drops the days
    # component (and wraps negative deltas), so an exposure from days ago
    # could look recent. No science exposure at all also means nothing to do.
    if daemon and (mostRecent is None
                   or (start - mostRecent).total_seconds() > 1800):
        # no exposures in the last 30 minutes, don't do anything.
        # has to be 30 min b/c apogee creates exposure entry at the beginning
        # of exposure so could be 15-20 min, better safe probably?
        return

    recent_designs = Design.select(Design.design_id, Design.exposure,
                                   Cadence.nexp, Cadence.obsmode_pk,
                                   Cadence.max_length, Field.pk,
                                   fn.MAX(Exp.start_time))\
                           .join(Conf, on=(Design.design_id == Conf.design_id))\
                           .join(Exp)\
                           .switch(Design)\
                           .join(DesignToStatus, on=(Design.design_id == DesignToStatus.design_id))\
                           .switch(Design)\
                           .join(Field, on=(Design.field_pk == Field.pk))\
                           .join(Cadence, on=(Field.cadence_pk == Cadence.pk))\
                           .where(Exp.start_time > useTime,
                                  Exp.exposure_flavor == db_flavor,
                                  DesignToStatus.status != doneStatus.pk)\
                           .group_by(Design.design_id, Design.exposure,
                                     Cadence.nexp, Cadence.obsmode_pk,
                                     Cadence.max_length, Field.pk)

    print(start.strftime("%Y-%m-%dT%H:%M:%S"), recent_designs.count())

    r1 = dict()
    b1 = dict()
    apsn2 = dict()
    modes = dict()
    lastInEpoch = dict()
    for d in recent_designs.dicts():
        if d["obsmode_pk"] is None:
            continue
        design_id = d["design_id"]
        # nexp is the cadence's per-epoch exposure count array
        current_epoch, expCount = _currentEpoch(d["nexp"], d["exposure"])
        mode = d["obsmode_pk"][current_epoch]
        modes[design_id] = mode
        # at the moment all bright designs are treated individually
        bright = mode == "bright_time"

        epochMaxLength = d["max_length"][current_epoch]
        if epochMaxLength < 0.1:
            epochMaxLength = 0.1
        # d["max"] is the latest exposure start on this design. Only
        # exposures started within max_length days before it are collated.
        windowStart = d["max"] - timedelta(days=epochMaxLength)

        isLast = d["exposure"] == expCount[current_epoch] - 1
        lastInEpoch[design_id] = bool(isLast and not bright)
        if lastInEpoch[design_id]:
            # Last design of the epoch: collect every exposure on the field
            # (d["pk"] is Field.pk) whose design belongs to this epoch.
            # Cast to int so a plain Python int (not numpy.int64) reaches the
            # DB driver.
            beginExp = int(expCount[current_epoch - 1]) if current_epoch > 0 else 0
            exps = Exp.select(Camera.pk, CameraFrame.sn2)\
                      .join(Conf)\
                      .join(Design)\
                      .switch(Exp)\
                      .join(CameraFrame)\
                      .join(Camera)\
                      .where(Design.field_pk == d["pk"],
                             Design.exposure >= beginExp,
                             Design.exposure <= d["exposure"],
                             Exp.start_time > windowStart)
        else:
            exps = Exp.select(Exp.start_time, Camera.pk,
                              CameraFrame.sn2, Quickred.dither_pixpos)\
                      .join(Conf)\
                      .join(Design)\
                      .switch(Exp)\
                      .join(Quickred, JOIN.LEFT_OUTER)\
                      .switch(Exp)\
                      .join(CameraFrame)\
                      .join(Camera)\
                      .where(Design.design_id == design_id,
                             Exp.start_time > windowStart)
        if bright:
            apsn2[design_id] = brightCollate(exps.dicts())
        else:
            r1[design_id], b1[design_id] = darkCollate(exps.dicts())

    # NOTE(review): `start` is a naive local datetime; astropy Time treats it
    # in its default scale (UTC), which assumes the host clock runs on UTC.
    mjd_now = Time(start).mjd
    startedStatus = None if test else CompletionStatus.get(label="started")
    summary = []

    for design_id, ap in apsn2.items():
        if modes[design_id] != "bright_time":
            continue
        done = checker[modes[design_id]].design(ap=ap)
        statusString = "done" if done else "partial"
        summary.append(f"{design_id:7d}: {statusString:7s} with {ap:7.1f} \n")
        if not test:
            _setStatus(design_id, doneStatus if done else startedStatus, mjd_now)

    for design_id in r1:
        mode = modes[design_id]
        if mode == "bright_time":
            continue
        r, b = r1[design_id], b1[design_id]
        if lastInEpoch[design_id]:
            done = checker[mode].epoch(r=r, b=b)
        else:
            done = checker[mode].design(r=r, b=b)
        statusString = "done" if done else "partial"
        summary.append(f"{design_id:7d}: {statusString:7s} with {r:4.1f}, {b:4.1f} \n")
        if not test:
            _setStatus(design_id, doneStatus if done else startedStatus, mjd_now)

    # total_seconds(): .microseconds is only the sub-second component
    elapsed_ms = int((datetime.now() - start).total_seconds() * 1000)
    print("TOOK ", elapsed_ms, " ms")
    print("".join(summary))


if __name__ == "__main__":
    # Command-line entry point: check recently observed designs and update
    # their completion status (or just print the summary with --test).
    parser = argparse.ArgumentParser(
        prog=os.path.basename(sys.argv[0]),
        description='Check recently observed designs and update their '
                    'completion status in opsdb')

    parser.add_argument('-s', '--since', type=float, default=1,
                        help='how many hours to look back (default: 1)',
                        dest='since')
    parser.add_argument('-t', '--test', action="store_true",
                        help='only a test, no db commits')
    parser.add_argument('-d', '--daemon', action="store_true",
                        help='assume running as daemon, allow exit if no recent exps')

    args = parser.parse_args()

    checkDone(args.since, test=args.test, daemon=args.daemon)
