#!/usr/bin/env python3
#-*- coding: UTF-8 -*-

"""
Automatically fill the given database table with completely random data.
"""

import os
import sys
import logging
import argparse
import concurrent.futures
from mysql import connector
from datetime import datetime
from collections import namedtuple
from mtls.values import InsertSQL,TableMeta
from concurrent.futures import ThreadPoolExecutor,ProcessPoolExecutor


# File name of this script; passed to argparse as the program name.
name = os.path.basename(__file__)

# Result record of a single insert() run: wall-clock start/end timestamps,
# the value stored as `tps` by insert(), and the elapsed time in seconds.
InsertStat = namedtuple(
    'InsertStat',
    ['start_at', 'end_at', 'tps', 'cost_time'],
)

def parser_cmd_args():
    """
    Build the command-line parser and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` from ``parse_args()`` carrying ``host``,
        ``port``, ``user``, ``password``, ``database``, ``table``, ``rows``,
        ``process`` and the required positional ``action`` (either
        ``'review'`` or ``'execute'``).
    """
    parser = argparse.ArgumentParser(name)

    # (flag, type, default, help text) for each optional argument, kept in
    # the order they are registered on the parser.
    options = (
        ("--host", str, "127.0.0.1", "MySQL 主机 IP "),
        ("--port", int, 3306, "MySQL 端口"),
        ("--user", str, "appuser", "用户名"),
        ("--password", str, "mtls@0352", "密码"),
        ("--database", str, "tempdb", "库名"),
        ("--table", str, "t", "表名"),
        ("--rows", int, 100, "要插入的行数"),
        ("--process", int, 1, "并发的进程数"),
    )
    for flag, kind, default, text in options:
        parser.add_argument(flag, type=kind, default=default, help=text)

    parser.add_argument("action", choices=('review', 'execute'))
    return parser.parse_args()


def insert(host="127.0.0.1",port=3306,user="apuser",password="mtls@0352",database="tempdb",table="t",rows=100):
    """
    Insert `rows` rows into `database`.`table`, committing after every row.

    The table's metadata is loaded through ``TableMeta`` and handed to
    ``InsertSQL``, which supplies the ``(sql, args, ...)`` pair executed for
    each row index.

    Args:
        host, port, user, password: MySQL connection parameters.
        database, table: target schema and table.
        rows: number of statements to execute.

    Returns:
        * ``InsertStat`` on success. Its ``tps`` field holds the number of
          rows inserted (not a rate); ``create_report`` divides it by the
          elapsed time.
        * ``[]`` when the table metadata could not be loaded or is empty.
        * ``None`` when connecting or inserting raised; the error is logged.

    NOTE(review): the default ``user`` here is "apuser" while the CLI default
    is "appuser" -- looks like a typo, left unchanged to keep the signature
    stable; confirm with the author.
    """
    t_meta = TableMeta(host, port, user, password, database, table)

    # The metadata lookup failed. This is not inside an ``except`` block, so
    # use logging.error rather than logging.exception (no active traceback).
    if t_meta.err is not None:
        logging.error("%s", t_meta.err)
        return []

    # Without column metadata there is nothing to generate statements from.
    meta = list(t_meta.meta)
    if not meta:
        logging.error("no column metadata found for %s.%s", database, table)
        return []

    # Reaching this point means the table exists and its metadata is usable.
    ist = InsertSQL(database, table, meta)

    cnx = None
    start_at = datetime.now()
    try:
        cnx = connector.connect(host=host, port=port, user=user, password=password)
        cursor = cnx.cursor()

        for i in range(rows):
            # Only the statement and its bound values are needed here.
            sql, args, *_ = ist[i]
            cursor.execute(sql, args)
            # One transaction per row.
            cnx.commit()

    except Exception as err:
        # Worker boundary: log with traceback and report failure as None so
        # one failing worker does not abort the whole run.
        logging.exception(str(err))
        return None
    finally:
        if cnx is not None:
            cnx.close()
    end_at = datetime.now()

    # total_seconds() keeps microsecond precision.
    cost_time = (end_at - start_at).total_seconds()
    tps = rows
    return InsertStat(start_at=start_at, end_at=end_at, tps=tps, cost_time=cost_time)

def create_report(stats:"list[InsertStat] | None"=None):
    """
    Print the aggregated throughput report for a collection of insert results.

    Only entries that expose a ``tps`` attribute count as successful runs;
    anything else (e.g. ``None`` or ``[]`` returned by a failed ``insert``)
    is ignored and does not affect the averages.

    The report shows:
        * cost_time -- mean ``cost_time`` of the successful runs;
        * tps       -- sum of the runs' ``tps`` values divided by that mean.

    When no run succeeded, or the mean elapsed time is 0, the affected
    values are printed as 0 instead of raising ZeroDivisionError.

    Args:
        stats: iterable of insert results; must not be None.

    Raises:
        AssertionError: if ``stats`` is None.
    """
    assert stats is not None, "stats must not be None"

    # Keep only successful runs so failures do not dilute the average time.
    valid = [s for s in stats if hasattr(s, 'tps')]

    sum_tps = sum(s.tps for s in valid)
    sum_cost_time = sum(s.cost_time for s in valid)

    cost_time = sum_cost_time / len(valid) if valid else 0
    # Guard the division: the elapsed time may legitimately be zero.
    tps = sum_tps / cost_time if cost_time else 0

    print("-"*36)
    print(f"|tps = {tps}")
    print(f"|cost_time = {cost_time}")
    print("-"*36)


def main():
    """
    Script entry point: parse the CLI arguments, run the inserts, print a report.

    With ``--process`` greater than 1, one ``insert`` task per process is
    submitted to a process pool; otherwise ``insert`` runs once in the
    current process. Every result is collected and passed to
    ``create_report``.

    NOTE(review): ``args.action`` is parsed but never read here, so 'review'
    and 'execute' currently behave the same -- confirm the intended behaviour.
    """
    args = parser_cmd_args()

    # Positional arguments shared by every insert() call.
    job = (args.host, args.port, args.user, args.password,
           args.database, args.table, args.rows)

    stats = []
    if args.process <= 1:
        # Single-process run.
        stats.append(insert(*job))
    else:
        # Multi-process stress run: one task per worker process.
        with ProcessPoolExecutor(max_workers=args.process) as pool:
            pending = [pool.submit(insert, *job) for _ in range(args.process)]

            # result() blocks until the task has finished; results are
            # gathered in completion order.
            for done in concurrent.futures.as_completed(pending):
                stats.append(done.result())

    print("\nReport:")
    create_report(stats)
    print("Compelete.\n")



# Run the load generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
    
