#!/usr/bin/env python3
"""
创建指定结构的表、并向其插件数据、主要用于分析数据库实例的写性能
"""
import os
import sys
import time
import mysql
import random
import string
import logging
import argparse
import threading
import concurrent.futures
from mysql import connector
from collections import namedtuple
from mysql.connector import errorcode
from concurrent.futures import ThreadPoolExecutor,ProcessPoolExecutor


# File name of this script; used below as the argparse program name and as the logger name.
name = os.path.basename(__file__)

def check_python_version() -> None:
    """
    Refuse to run on Python 2 or older.

    When the interpreter's major version is 2 or lower, a message is written
    to stderr and the process exits with status 1; otherwise nothing happens.
    """
    major_version = sys.version_info[0]
    if major_version > 2:
        return
    print("only support python-3.x",file=sys.stderr)
    sys.exit(1)

def parse_cmd_arags() -> argparse.Namespace:
    """
    Define the command line interface and parse the process arguments.

    Returns:
        The argparse.Namespace produced by ``parse_args()``. It carries the
        connection settings (host, port, user, password, database, table),
        the workload settings (parallel, rows, log_level), the table layout
        (auto_primary_key plus the per-type column counts and sizes) and the
        positional ``action``, one of 'create', 'drop' or 'insert'.
    """
    def to_bool(value):
        # argparse passes the raw string; YES/TRUE/1/ON (any case) mean True, everything else False.
        return value.upper() in ['YES','TRUE','1','ON']

    parser = argparse.ArgumentParser(name)

    # Connection settings.
    parser.add_argument('--host',type=str,default="127.0.0.1",help="mysql host")
    parser.add_argument('--port',type=int,default=3306,help="mysql port")
    parser.add_argument('--user',type=str,default='appuser',help="mysql user")
    parser.add_argument('--password',type=str,default='apps_352',help="mysql user's passowrd ")
    parser.add_argument('--database',type=str,default="tempdb",help="work schema(database)")
    parser.add_argument('--table',type=str,default="t",help="work table")

    # Workload settings.
    parser.add_argument('--parallel',type=int,default=1,help="parallel workers")
    parser.add_argument('--rows',type=int,default=100,help="rows")
    parser.add_argument('--log-level',type=str,choices=['info','debug','error'],default="info")

    # Table layout: the choices are checked after to_bool has converted the value.
    parser.add_argument('--auto-primary-key',type=to_bool,default=True,choices=[False,True],help="whether table has primary key")
    parser.add_argument('--ints',type=int,default=0,help="int column counts")
    parser.add_argument('--floats',type=int,default=0,help="float column counts")
    parser.add_argument('--doubles',type=int,default=0,help='double column counts')
    parser.add_argument('--texts',type=int,default=0,help="text column counts")
    parser.add_argument('--varchars',type=int,default=0,help="varchar column counts")
    parser.add_argument('--varchar-length',type=int,default=128,help="varchar column length default 128")
    parser.add_argument('--decimals',type=int,default=0,help="decimal column counts")
    parser.add_argument('--decimal-precision',type=int,default=12,help="total digits length")
    parser.add_argument('--decimal-scale',type=int,default=2,help="the scale of decimal(the number of digits to the right of the decimal point)")

    # What to do with the table.
    parser.add_argument('action',type=str,choices=['create','drop','insert'])
    return parser.parse_args()

def config_logger(args: argparse.Namespace) -> None:
    """
    Configure the script's logger: level from ``args.log_level``, output to stderr.

    Args:
        args: parsed command line arguments; only ``log_level``
            ('debug', 'info' or 'error') is read.

    The stderr handler is attached only once, so calling this function again
    updates the level without duplicating every log line.
    """
    logger = logging.getLogger(name)

    levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "error": logging.ERROR,
    }
    level = levels.get(args.log_level)
    if level is not None:
        logger.setLevel(level)

    # A logger that already has a handler would emit each record twice if another one were added.
    if logger.handlers:
        return

    handler = logging.StreamHandler(sys.stderr)
    formater = logging.Formatter("%(asctime)s  %(name)s  %(process)d  %(threadName)s  %(levelname)s  %(message)s")
    handler.setFormatter(formater)
    logger.addHandler(handler)

def create(args: argparse.Namespace):
    """
    Create the work table described by *args*.

    Column order is: optional auto-increment ``id`` primary key, then the
    int (i0..), varchar (c0..), float (f0..), double (d0..), decimal (dm0..)
    and text (t0..) columns, all ``not null``.

    Args:
        args: parsed command line arguments (connection settings, database,
            table, auto_primary_key and the per-type column counts/sizes).

    Exits the process with:
        1 - no data column was requested (with or without the primary key),
        2 - decimal precision is smaller than decimal scale,
        3 - decimal precision is not positive,
        4 - the database reported an error while connecting or creating the table.
    """
    logger = logging.getLogger(name)

    # Validate the decimal definition before any SQL is built.
    if args.decimals >= 1:
        if args.decimal_precision < args.decimal_scale:
            logger.error("decimal-precision argument must big then decimal-scale")
            sys.exit(2)
        if args.decimal_precision <= 0:
            logger.error("decimal-precision argument must big then 0")
            sys.exit(3)

    # range() is empty for counts <= 0, so no explicit ">= 1" guards are needed.
    data_columns = []
    data_columns.extend(f"i{i} int not null" for i in range(args.ints))
    data_columns.extend(f"c{c} varchar({args.varchar_length}) not null" for c in range(args.varchars))
    data_columns.extend(f"f{f} float not null" for f in range(args.floats))
    data_columns.extend(f"d{d} double not null" for d in range(args.doubles))
    data_columns.extend(
        f"dm{d} decimal({args.decimal_precision},{args.decimal_scale}) not null"
        for d in range(args.decimals)
    )
    data_columns.extend(f"t{t} text not null" for t in range(args.texts))

    # Without any data column there is nothing to benchmark; this must hold
    # whether or not the primary key is added, otherwise "( )" would be sent to the server.
    if not data_columns:
        logger.error(f"do not have any columns in table {args.database}.{args.table}")
        sys.exit(1)

    columns = list(data_columns)
    if args.auto_primary_key:
        columns.insert(0, "id int not null auto_increment primary key")

    sql = f"create table {args.database}.{args.table} ( {','.join(columns)});"

    logger.info(f"create table sql statement: {sql}")
    cnx = None
    try:
        cnx = connector.connect(host=args.host,port=args.port,user=args.user,password=args.password,database=args.database)
        cursor = cnx.cursor()
        cursor.execute(sql)
        cnx.commit()
    except connector.Error as err:
        if err.errno == errorcode.ER_ACCESS_DENIED_ERROR:
            # The password is deliberately left out of the log.
            logger.error(f"host={args.host} port={args.port} user={args.user}")
        # Report every database error instead of silently claiming success.
        logger.error(str(err))
        sys.exit(4)
    finally:
        if cnx is not None:
            cnx.close()

    logger.info("complete")
    

def get_int_value():
    """
    Endless generator of random integers between 1 and 200000000 (both inclusive).
    """
    lowest, highest = 1, 200000000
    while True:
        yield random.randint(lowest, highest)


# Replace the function by one shared generator instance; callers use next(get_int_value).
get_int_value = get_int_value()


def get_str_value(length=64):
    """
    Endless generator of random alphanumeric strings.

    The character pool is the ASCII letters and digits, each present twice;
    every step shuffles the pool in place and yields its first *length*
    characters, so a character appears at most twice in one string.

    Args:
        length: number of characters per generated string.

    Raises:
        RuntimeError: on the first ``next()`` call when *length* is negative
            or larger than the pool.
    """
    alphabet = string.ascii_letters + string.digits
    pool = list(alphabet + alphabet)
    if length < 0 or length > len(pool):
        # Report the real limit, not the value that was passed in.
        raise RuntimeError(f"varchar length must be between 0 and {len(pool)}, got {length}")
    while True:
        random.shuffle(pool)
        yield ''.join(pool[:length])


# Replace the function by one shared generator instance (default length 64); callers use next(get_str_value).
get_str_value = get_str_value()


def get_float_value():
    """
    Endless generator of random non-negative numbers below 10000000000,
    rendered as strings with two digits after the decimal point.
    """
    upper_bound = 10000000000
    while True:
        sample = random.random() * upper_bound
        yield f"{sample:.2f}"


# Replace the function by one shared generator instance; callers use next(get_float_value).
get_float_value = get_float_value()


# Mapping from column type to the generator that produces values for it.
value_generates = {
    # Integer columns take their values from the integer generator, not the float one.
    'ints': get_int_value,
    'varchar': get_str_value,
    'float': get_float_value,
    'double': get_float_value,
    'decimal': get_float_value,
}

def insert_sql(args: argparse.Namespace) -> tuple:
    """
    Build the parameterised INSERT statement and the matching value generators.

    Columns follow the naming used when the table is created: i0.. (int),
    c0.. (varchar), f0.. (float), d0.. (double), dm0.. (decimal), t0.. (text).

    Args:
        args: parsed command line arguments (database, table, the per-type
            column counts, varchar_length, decimal_precision, decimal_scale).

    Returns:
        ``(sql, generators)``: the INSERT statement with one ``%s``
        placeholder per column, and a list with one iterator per column in
        the same order; ``next()`` on each iterator yields that column's value.

    Exits the process with status 2 when no column was requested.
    """
    logger = logging.getLogger(name)

    def sized_strings(limit):
        # The shared string generator always yields 64 characters; cut them
        # so they also fit a varchar column declared shorter than that.
        for text in get_str_value:
            yield text[:limit]

    def decimal_strings(precision, scale):
        # Draw the integer and fractional digits separately so the value
        # always fits decimal(precision, scale) and can never round up out of range.
        scale = max(scale, 0)
        whole_limit = 10 ** max(precision - scale, 0)
        fraction_limit = 10 ** scale
        while True:
            whole = random.randrange(whole_limit)
            if scale:
                yield f"{whole}.{random.randrange(fraction_limit):0{scale}d}"
            else:
                yield str(whole)

    varchar_limit = max(args.varchar_length, 0)

    # range() is empty for counts <= 0, so no explicit ">= 1" guards are needed.
    columns = []
    columns.extend((f"i{i}", get_int_value) for i in range(args.ints))
    columns.extend((f"c{v}", sized_strings(varchar_limit)) for v in range(args.varchars))
    columns.extend((f"f{f}", get_float_value) for f in range(args.floats))
    columns.extend((f"d{d}", get_float_value) for d in range(args.doubles))
    columns.extend(
        (f"dm{dm}", decimal_strings(args.decimal_precision, args.decimal_scale))
        for dm in range(args.decimals)
    )
    columns.extend((f"t{t}", get_str_value) for t in range(args.texts))

    if not columns:
        logger.error("columns counts equal 0")
        sys.exit(2)

    # SQL template: one placeholder per column.
    names = [column_name for column_name, _ in columns]
    placeholders = ['%s'] * len(names)
    sql = f"insert into {args.database}.{args.table} ({','.join(names)}) values({','.join(placeholders)})"

    # One value source per column, in column order.
    sql_args = [generator for _, generator in columns]

    return (sql,sql_args)

def insert(args, rows):
    """
    Insert *rows* rows into the work table over a single connection.

    Args:
        args: parsed command line arguments (connection settings plus the
            table layout needed to build the INSERT statement).
        rows: number of rows this call inserts.

    Every row is committed on its own. On a database error the error is
    logged and the process exits with status 2; the connection is closed in
    either case.
    """
    logger = logging.getLogger(name)
    sql, generators = insert_sql(args)
    logger.info(f"sql statement: {sql}")

    cnx = None
    try:
        cnx = connector.connect(host=args.host,port=args.port,user=args.user,password=args.password,database=args.database)
        cursor = cnx.cursor()
        for _ in range(rows):
            values = [next(generator) for generator in generators]
            cursor.execute(sql, values)
            # One commit per row: each inserted row is its own transaction.
            cnx.commit()
    except connector.Error as err:
        # Log as an error so the failure is visible with --log-level error too.
        logger.error(str(err))
        sys.exit(2)
    finally:
        if cnx is not None:
            cnx.close()

def drop(args):
    """
    Drop the work table ``args.database``.``args.table``.

    A database error is logged and terminates the process with status 3.
    The connection, when one was opened, is closed in every case.
    """
    log = logging.getLogger(name)
    connection = None
    try:
        settings = {
            'host': args.host,
            'port': args.port,
            'user': args.user,
            'password': args.password,
            'database': args.database,
        }
        connection = connector.connect(**settings)
        cur = connection.cursor()
        statement = f"drop table {args.database}.{args.table};"
        cur.execute(statement)
        connection.commit()
    except connector.Error as failure:
        log.error(str(failure))
        sys.exit(3)
    finally:
        closable = connection and hasattr(connection,'close')
        if closable:
            connection.close()


    

def main():
    """
    Entry point: parse the command line, configure logging and run the action.

    'create' and 'drop' manage the work table. 'insert' runs the write
    benchmark, in one process or in ``--parallel`` worker processes, and logs
    the duration and the rows-per-second figure (TPS).
    """
    check_python_version()
    args = parse_cmd_arags()
    config_logger(args)
    logger = logging.getLogger(name)

    if args.action == 'create':
        create(args)
    elif args.action == 'insert':
        start = time.time()
        logger.info(f"start time = {start}")
        logger.info("****")
        logger.info("****")

        if args.parallel > 1:
            # Spread the rows over the workers and hand out the remainder one
            # row at a time, so the total matches --rows (and the TPS figure).
            base, extra = divmod(args.rows, args.parallel)
            shares = [base + 1 if i < extra else base for i in range(args.parallel)]
            with ProcessPoolExecutor(max_workers=args.parallel) as e:
                # Workers that would get no rows are not started.
                futures = [e.submit(insert, args, share) for share in shares if share > 0]
                for future in concurrent.futures.as_completed(futures):
                    # Fetching the result re-raises anything that failed in the worker.
                    future.result()
        else:
            # Single-process run.
            insert(args, args.rows)

        # Summarise the run.
        stop = time.time()
        elapsed = stop - start
        duration = "{0:.2f}".format(elapsed)
        # Guard against a zero (or negative, after a clock adjustment) interval.
        tps = "{0:.2f}".format(args.rows / elapsed if elapsed > 0 else 0.0)
        logger.info("****")
        logger.info("****")
        logger.info(f"stop time = {stop}")
        logger.info(f"TPS:{tps} duration {duration}(s)")

    elif args.action == 'drop':
        drop(args)


# Run the tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

