#!/usr/bin/env python3

import os
import argparse
from datetime import datetime, timedelta
import time
import tempfile

import numpy as np
import matplotlib.pyplot as plt  # type: ignore
import slack  # type: ignore

from tensorboard_reporter.loader import load_summaries, group_by_tag
from tensorboard_reporter.summary import tags, wall_time_in_range
from tensorboard_reporter.stats import current_mean


# Command-line interface. Every option is mandatory, so they are all registered
# under a dedicated "required arguments" group (argparse would otherwise list
# them under its default optional-arguments heading in --help).
parser = argparse.ArgumentParser(
    description="Get reports for your training process via Slack"
)
required = parser.add_argument_group("required arguments")

# (flag, argparse keyword options) pairs, in the order they appear in --help.
_REQUIRED_OPTIONS = (
    ("--run_dir", dict(type=str, help="tensorboard run dir")),
    # May be given several times; the values accumulate into a list.
    ("--tag", dict(type=str, action="append", help="tags to report")),
    ("--interval_hour", dict(type=int, help="report interval hour")),
    (
        "--slack_channels",
        dict(
            type=str,
            help="slack channels (comma seperated) which messages will be sent (can be user ids)",
        ),
    ),
)
for _flag, _options in _REQUIRED_OPTIONS:
    required.add_argument(_flag, required=True, **_options)

# Parse the command line before touching the environment so that `--help` and
# argument errors are reported first.
args = parser.parse_args()

# The Slack bot token is read from the environment, not from a flag.
SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN")
if not SLACK_BOT_TOKEN:
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let a missing token reach the Slack client.
    # An empty value is treated as missing as well.
    raise SystemExit("missing `SLACK_BOT_TOKEN` env variable")

client = slack.WebClient(token=SLACK_BOT_TOKEN)

# Reporting loop: load the scalars recorded during the last `--interval_hour`
# hours, draw one subplot per requested tag, and upload the figure to Slack
# together with a per-tag min/max/mean summary.
#
# NOTE(review): the periodic sleep below is commented out and the loop ends
# with an unconditional `break`, so the script currently sends exactly one
# report and exits. Re-enable the sleep and drop the `break` to report once
# per interval.
while True:
    # time.sleep(args.interval_hour * 60 * 60)

    tags_to_report = args.tag

    # Keep only the requested tags, restricted to the reporting window.
    summaries = load_summaries(args.run_dir)
    summaries = filter(tags(tags_to_report), summaries)
    window_start = datetime.now() - timedelta(hours=args.interval_hour)
    summaries = filter(wall_time_in_range(min=window_start, max=None), summaries)

    summaries_by_tag = group_by_tag(summaries)

    # squeeze=False always returns a 2-D array of axes. With the default,
    # a single `--tag` makes plt.subplots return a bare Axes object and
    # indexing it fails.
    fig, axs = plt.subplots(len(tags_to_report), sharex=True, squeeze=False)

    try:
        stats_by_tag = dict()

        for i, (tag, tag_summaries) in enumerate(summaries_by_tag.items()):
            ax = axs[i, 0]
            ax.set_title(tag)

            # Single pass, in case the group is a one-shot iterator.
            points = [(s.step, s.value) for s in tag_summaries]
            if not points:
                # min()/max() of an empty array raise; leave the axis empty.
                continue

            steps, values = zip(*points)
            values_np = np.array(values)

            stats_by_tag[tag] = dict(
                mean=values_np.mean(), min=values_np.min(), max=values_np.max()
            )
            ax.plot(steps, values)

        # One Slack-formatted line per tag that had data in the window.
        summary_text = "".join(
            "*{}*, min=*{:.2f}*, max=*{:.2f}*, mean=*{:.2f}* \n".format(
                tag, stats["min"], stats["max"], stats["mean"]
            )
            for tag, stats in stats_by_tag.items()
        )

        # The temporary file only needs to live for the duration of the
        # upload; it is deleted when the `with` block exits.
        with tempfile.NamedTemporaryFile() as file:
            fig.savefig(file.name, format="png")
            client.files_upload(
                channels=args.slack_channels,
                file=file.name,
                title="Report_{}".format(str(datetime.now())),
                initial_comment=summary_text,
            )
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated reports do not accumulate figures.
        plt.close(fig)

    print(summary_text)
    break
