#!/usr/bin/env python3

"""
实例单实例架构下的 MySQL 安装与卸载
"""

import re
import os
import sys
import time
import json
import shutil
import psutil
import logging
import argparse
import threading
from socketserver import ThreadingMixIn
from http.server import HTTPServer, BaseHTTPRequestHandler

from dbma import common
from dbma import monitor
from dbma import checkings
from dbma.initialization import is_root, get_uid_gid
from dbma.daemon import start_daemon, stop_daemon

# Global registry of monitored instances.
# Keys are the instance port as a string (e.g. "3306"); values are the
# monitor.MySQLMonitor objects created in main(). It is filled by main() and
# its periodic discovery thread and read by the HTTP request handler.
instances = {}


def parser_cmd_args():
    """
    Parse the command line of the monitor gateway.

    The program name shown in the usage text is this file's base name up to
    its first dot. Returns the ``argparse.Namespace`` produced from
    ``sys.argv``; the positional ``action`` is either 'start' or 'stop'.
    """
    script_name = os.path.basename(__file__)
    prog = script_name.split('.')[0]
    parser = argparse.ArgumentParser(prog=prog)

    # (flag, keyword arguments) pairs, registered in declaration order so the
    # generated help text keeps its layout.
    option_table = (
        ('--bind-port', dict(type=int, default=8080,
                             help="http-server listening port")),
        ('--bind-ip', dict(type=str, default="127.0.0.1",
                           help="http-server listening ip")),
        ('--monitor-user', dict(type=str, default='monitor',
                                help="mysql monitor user")),
        ('--monitor-password', dict(type=str, default="dbma@0352",
                                    help="password of mysql monitor user")),
        ('--expire-seconds', dict(type=int, default=7,
                                  help="expire secondes")),
        ('--log', dict(type=str, default='info',
                       choices=['debug', 'info', 'warning', 'error'])),
        ('--log-file', dict(
            type=str,
            default="/usr/local/dbm-agent/logs/dbm-monitor-gateway.log")),
        ('action', dict(type=str, choices=['start', 'stop'])),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    return parser.parse_args()


def config_log(log_file: str, log_level='info'):
    """
    Configure the 'dbm-agent' logger to write to a rotating log file.

    log_file:  path of the log file; it is rotated at 20 MiB and 5 backups
               are kept.
    log_level: 'debug', 'info', 'warning' or 'error' (case-insensitive).
               Any other value falls back to ERROR.

    Every call attaches one more file handler to the logger, so it is meant
    to be called once per process.
    """
    # ``import logging`` alone does not load the ``logging.handlers``
    # submodule, so import the handler class explicitly.
    from logging.handlers import RotatingFileHandler

    levels = {
        'DEBUG': logging.DEBUG,
        'INFO': logging.INFO,
        'WARNING': logging.WARNING,
        'ERROR': logging.ERROR,
    }

    logger = logging.getLogger('dbm-agent')
    logger.setLevel(levels.get(log_level.upper(), logging.ERROR))

    file_handler = RotatingFileHandler(
        filename=log_file, maxBytes=1024*1024*20, backupCount=5)
    # The handler accepts everything; filtering is done by the logger level.
    file_handler.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(threadName)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)


class ThreadingHttpServer(ThreadingMixIn, HTTPServer):
    """
    Threaded http-server: every request is handled in its own thread.

    ThreadingMixIn must come before HTTPServer in the base list so that its
    process_request() overrides the serial one.
    """
    # Same setting as the standard library's http.server.ThreadingHTTPServer:
    # request threads are daemon threads and do not block interpreter exit.
    daemon_threads = True

    logger = logging.getLogger(
        'dbm-agent').getChild("ThreadingHttpServer")


class MonitorItemHttpRequestHandler(BaseHTTPRequestHandler):
    """
    HTTP GET handler of the monitor gateway.

    Routes (matched against ``self.path`` in this order):

        /instances/                 JSON list of the monitored ports
        /instances/<port>/          JSON object with all items of an instance
        /instances/<port>/<item>/   JSON object with one item of an instance
        /metrics                    plain-text payload for prometheus
        anything else               404 with a JSON error message

    The trailing "/" is optional. Data comes from the module-level
    ``instances`` dict, keyed by the port as a string.
    """
    # /instances/ : list the ports of all monitored instances
    pattern_get_all_instances = re.compile(r"^/instances/{0,1}$")

    # /instances/3306/ : list all monitor items of the given instance
    pattern_list_all_items = re.compile(r"^/instances/(\d*)/{0,1}$")

    # /instances/3306/com_select/ : value of one monitor item.
    # The item group is non-greedy so that the optional trailing "/" is not
    # captured as part of the item key.
    pattern_get_item = re.compile(r"^/instances/(\d*)/(\S*?)/{0,1}$")

    # /metrics : prometheus scrape endpoint
    prometheus_metrics = re.compile(r"^/metrics")

    logger = logging.getLogger(
        'dbm-agent').getChild("MonitorItemHttpRequestHandler")

    def log_message(self, format, *args):
        """
        Send the access-log line to the 'dbm-agent' logger instead of stderr.

        The signature (including the ``format`` name) is the one defined by
        BaseHTTPRequestHandler. No trailing newline is added because the
        logging handler terminates each record itself.
        """
        logger = self.logger.getChild("log_message")
        logger.info("%s - - [%s] %s",
                    self.address_string(),
                    self.log_date_time_string(),
                    format % args)

    def router(self):
        """
        URL routing: return the bound view method that serves ``self.path``.
        """
        logger = self.logger.getChild("router")
        logger.debug(f"{self.path}")

        # Order matters: the first matching pattern wins.
        routes = (
            (self.pattern_get_all_instances, self._get_all_instances),
            (self.pattern_list_all_items, self._list_all_items),
            (self.pattern_get_item, self._get_item),
            (self.prometheus_metrics, self._prometheus),
        )
        for pattern, view in routes:
            if pattern.search(self.path):
                return view
        return self._get_404

    def _send_json(self, payload, status=200, indent=None):
        """
        Write ``payload`` as a JSON response with the given HTTP status.

        The body is serialized before any header is sent, so a serialization
        error does not leave a half-written response behind.
        """
        body = json.dumps(payload, indent=indent).encode("utf8")
        self.send_response(status)
        self.send_header('Content-type', 'application/json')
        self.end_headers()
        self.wfile.write(body)

    def _send_instance_missing(self, port):
        """
        Answer a request for a port that is not monitored.

        NOTE(review): the status stays 200 (with an 'error-msg' body) because
        that is what existing clients receive today; confirm before moving
        to 404.
        """
        self._send_json(
            {'error-msg': f'mysqld-{port} not exists in this host'})

    def _get_all_instances(self):
        """
        Return the ports of all monitored instances as a JSON list of ints.
        """
        # Snapshot the keys first: the discovery thread may add instances
        # while this request is being served.
        ports = [int(port) for port in list(instances)]
        self._send_json(ports)

    def _list_all_items(self):
        """
        List every monitor item of the instance named in the URL.
        """
        port = self.pattern_list_all_items.search(self.path).group(1)
        instance = instances.get(port)
        if instance is None:
            # an instance that does not exist on this host was queried
            self._send_instance_missing(port)
            return

        self._send_json(dict(instance.items()), indent=4)

    def _get_item(self):
        """
        Return the value of one monitor item of the instance named in the URL.
        """
        match = self.pattern_get_item.search(self.path)
        port, item_key = match.group(1), match.group(2)

        instance = instances.get(port)
        if instance is None:
            self._send_instance_missing(port)
            return

        try:
            item_value = instance[item_key]
        except KeyError:
            # assumes the monitor object raises KeyError for an unknown
            # item, like a mapping does -- TODO confirm
            self._send_json(
                {'error-msg':
                 f"monitor item '{item_key}' not exists in mysqld-{port}"})
            return

        self._send_json({item_key: item_value}, indent=4)

    def _get_404(self):
        """
        Answer any unknown URL with 404 and a JSON error message.
        """
        self._send_json({'error-msg': '404 not found'}, status=404)

    def _prometheus(self):
        """
        Serve the prometheus scrape endpoint as text/plain.

        The payload is the content of /tmp/data.txt when that file exists,
        otherwise a fixed sample line.
        NOTE(review): this looks like a placeholder implementation -- confirm.
        """
        logger = self.logger.getChild("_prometheus")
        logger.info("start.")

        # Read the payload before sending headers so that a read error does
        # not produce a response with headers and no body.
        if os.path.isfile("/tmp/data.txt"):
            with open("/tmp/data.txt") as f:
                content = f.read()
        else:
            content = "com_select 100\n"

        self.send_response(200)
        self.send_header('Content-type', 'text/plain')
        self.end_headers()
        self.wfile.write(content.encode("utf8"))

    def do_GET(self):
        """
        Handle a GET request: pick the view for the path and call it.
        """
        logger = self.logger.getChild("do_get")
        logger.info(f"http get query {self.path}")

        view = self.router()
        view()


def main():
    """
    Main program of the monitor gateway.

    Steps:
      1. parse the command line and run the pre-flight checks;
      2. switch the effective user/group to 'dbma';
      3. stop the running daemon, or daemonize (pid file
         /tmp/dbm-monitor-gateway.pid) according to the 'action' argument;
      4. configure logging, start one monitor per discovered MySQL instance
         and a background thread that re-runs the discovery every 60 seconds;
      5. serve HTTP requests until the server stops.

    Exit codes: 1 not run as root, 2 user 'dbma' missing, 3 log directory
    missing, 4 /usr/bin/netstat missing (or unsupported action).
    """
    args = parser_cmd_args()

    # the program must be run as root
    if not is_root():
        print("must use root user to execute this program.", file=sys.stderr)
        sys.exit(1)

    # the 'dbma' user must exist
    if not checkings.is_user_exists('dbma'):
        print("user 'dbma' not exists,use this fix it 'dbm-agent init' ")
        sys.exit(2)

    # switch to the unprivileged user 'dbma' (group first, then user)
    uid, gid = get_uid_gid('dbma')
    os.setegid(gid)
    os.seteuid(uid)

    # the directory that holds the log file must already exist
    dir_name = os.path.dirname(args.log_file)
    if not checkings.is_directory_exists(dir_name):
        print(f"diectory '{dir_name}' not exists ")
        sys.exit(3)

    # netstat must be installed
    # (presumably used by the instance discovery -- verify in dbma.monitor)
    if not checkings.is_file_exists("/usr/bin/netstat"):
        print(
            "/usr/bin/netstat not exists,please install it 'sudo yum -y install net-tools' ")
        sys.exit(4)

    # switch to daemon mode
    if args.action == 'stop':
        stop_daemon('/tmp/dbm-monitor-gateway.pid')
        # Whether stop_daemon() exits the process is not visible from here;
        # return so that 'stop' can never fall through and start the gateway.
        return
    elif args.action == 'start':
        start_daemon('/tmp/dbm-monitor-gateway.pid')
    else:
        # unreachable with the current argparse choices; kept as a safeguard
        print("not suported action")
        sys.exit(4)

    # from here on logging can be configured
    config_log(log_file=args.log_file, log_level=args.log)

    logger = logging.getLogger('dbm-agent')
    logger.info("dbm-monitor-gateway start")
    print(f"Successful start and log file save to '{args.log_file}' ")

    def start_monitor(port):
        """Create, register and start the monitor of one MySQL instance."""
        instances[str(port)] = monitor.MySQLMonitor(
            host='127.0.0.1', port=port,
            monitor_user=args.monitor_user,
            monitor_password=args.monitor_password,
            expire_seconds=args.expire_seconds)
        instances[str(port)].start()

    # discover the ports of the MySQL instances on this host
    mmps = monitor.Mmps(monitor_user=args.monitor_user,
                        monitor_password=args.monitor_password)
    ports = mmps.get_all_sql_port()
    logger.info(f"first mmps start,find {ports}")

    # create one monitor for every instance that already exists
    for port in ports:
        logger.info(f"create a monitor thread for mysqld-{port}")
        start_monitor(port)
        logger.info(f"monitor thread of mysqld-{port} started")

    def mysql_port_scan():
        """
        Periodically re-run the instance discovery and start a monitor for
        every port that is not registered yet. Runs forever.
        """
        while True:
            try:
                logger.info("start periodic mmps")
                mmps = monitor.Mmps(monitor_user=args.monitor_user,
                                    monitor_password=args.monitor_password)

                for port in mmps.get_all_sql_port():
                    if str(port) in instances:
                        # already monitored, nothing to do
                        logger.info(f"mysqld-{port} has been monitored,skip")
                    else:
                        # new instance: add it to the monitoring
                        logger.warning(f"found new instances 'mysqld-{port}' ")
                        start_monitor(port)
            except Exception:
                # Best effort: a failed scan must not kill the discovery
                # thread, but it is logged instead of silently dropped.
                logger.exception("periodic mmps failed")

            # run the discovery every 60 seconds
            time.sleep(60)

    # run the discovery loop in its own daemon thread
    t_port_scan = threading.Thread(
        target=mysql_port_scan, daemon=True, name="periodic-mmps")
    t_port_scan.start()

    # start the http-server
    ip = args.bind_ip
    port = args.bind_port
    with ThreadingHttpServer((ip, port), MonitorItemHttpRequestHandler) as httpd:
        httpd.serve_forever()

    logger.error("this ending")


# Run the gateway only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
