#!/usr/bin/env python3

import argparse
import matplotlib.pyplot as plt  # type: ignore
import numpy as np
import os
import slack  # type: ignore
import tempfile
import time
from datetime import datetime, timedelta
from tqdm import tqdm

# Command-line interface.
#
# argparse lists every `--flag` under "optional arguments" by default, so the
# mandatory flags live in their own "required arguments" group to make the
# `--help` output honest. Flags that may be omitted stay in the default group.
parser = argparse.ArgumentParser(
    description="Get reports for your training process via Slack"
)
required = parser.add_argument_group("required arguments")
required.add_argument("--run_dir", type=str, help="tensorboard run dir", required=True)
required.add_argument("--interval_hour", type=float, help="report interval hour", required=True)
required.add_argument("--slack_channels", type=str, required=True,
                      help="slack channels (comma separated) which messages will be sent (can be user ids)",
                      )
# Optional flags: registered on the parser itself (not on `required`) so that
# they are not advertised as mandatory in the help text.
# `--tag` may be repeated; when it is never given, `args.tag` is None.
parser.add_argument("--tag", type=str, action="append", help="tags to report. Default: all tags")
parser.add_argument("--all_time", action='store_true', help="include all summaries (even old ones). Default:False")

args = parser.parse_args()

# The Slack bot token is taken from the environment, not from the command line.
SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN")
if not SLACK_BOT_TOKEN:
    # Explicit check instead of `assert`: assertions are stripped when Python
    # runs with -O, which would let a missing token reach the Slack client.
    # parser.error() prints the usage line plus this message and exits.
    parser.error("missing `SLACK_BOT_TOKEN` env variable")

# Client used by the reporting loop to upload the rendered figure.
client = slack.WebClient(token=SLACK_BOT_TOKEN)

# these imports take time, keep them after argparse in case of bad args
from tensorboard_reporter.loader import load_summaries, group_by_tag
from tensorboard_reporter.stats import current_mean
from tensorboard_reporter.summary import tags, wall_time_in_range

# Main reporting loop. Each iteration:
#   1. loads the summaries found in `--run_dir`,
#   2. keeps the requested tags (and, unless `--all_time` is set, only the
#      summaries whose wall time falls within the last `--interval_hour` hours),
#   3. draws one subplot per tag and computes min/max/mean per tag,
#   4. uploads the figure and the stats text to the Slack channels,
#   5. waits `--interval_hour` hours while showing a countdown bar.
while True:
    print('Time for a report!')

    # Materialized once because the summaries are traversed several times
    # below (tag discovery, filtering, counting).
    summaries = list(load_summaries(args.run_dir))

    # Fallback to all tags; dict.fromkeys() de-duplicates while keeping the
    # first-seen order, so there is one entry per tag rather than per summary.
    tags_to_report = args.tag or list(dict.fromkeys(s.tag for s in summaries))

    # Keep real lists instead of lazy `filter` objects: counting a lazy
    # iterator with len(list(...)) consumes it, leaving nothing to group/plot.
    summaries = list(filter(tags(tags_to_report), summaries))
    print(f'found {len(summaries)} summaries')

    if not args.all_time:
        window_start = datetime.now() - timedelta(hours=args.interval_hour)
        summaries = list(
            filter(wall_time_in_range(min=window_start, max=None), summaries)
        )
        print(f'found {len(summaries)} summaries matching walltime')

    # Used below as a mapping of tag -> iterable of summaries.
    summaries_by_tag = group_by_tag(summaries)
    print(f'{len(summaries_by_tag)} summaries groupted by tags')

    stats_by_tag = dict()
    summary_text = ""

    # plt.subplots() rejects zero rows, so only build a figure when there is
    # something to plot.
    if summaries_by_tag:
        # One row per tag that actually has data. squeeze=False always yields
        # a 2-D array of axes, so the single-tag case can be indexed as well.
        fig, axs = plt.subplots(len(summaries_by_tag), sharex=True, squeeze=False)
        try:
            for ax, (tag, tag_summaries) in zip(axs[:, 0], summaries_by_tag.items()):
                ax.set_title(tag)

                steps = []
                values = []
                for summary in tag_summaries:
                    steps.append(summary.step)
                    values.append(summary.value)

                values_np = np.array(values)

                stats_by_tag[tag] = dict(
                    mean=values_np.mean(), min=values_np.min(), max=values_np.max()
                )
                ax.plot(steps, values)

            # One line of Slack-formatted (bold) stats per tag.
            summary_text = "".join(
                "*{}*, min=*{:.2f}*, max=*{:.2f}*, mean=*{:.2f}* \n".format(
                    tag, stats["min"], stats["max"], stats["mean"]
                )
                for tag, stats in stats_by_tag.items()
            )

            # The figure is written to a temporary file that is removed as
            # soon as the upload call returns.
            with tempfile.NamedTemporaryFile(suffix=".png") as file:
                fig.savefig(file.name, format="png")
                client.files_upload(
                    channels=args.slack_channels,
                    file=file.name,
                    title="Report_{}".format(str(datetime.now())),
                    initial_comment=summary_text,
                )
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # the endless loop accumulates one figure per report.
            plt.close(fig)

    print(summary_text)

    # waiting with progress bar: one tick per second until the next report
    wait_s = int(args.interval_hour * 60 * 60)
    for _ in tqdm(range(wait_s), 'next report in', wait_s, bar_format='{desc}: {remaining} |{bar}|'):
        time.sleep(1)
