import logging
from graph_add_op import CHGraph
from clickhouse_driver import Client


# Module-wide logging setup. Records at WARNING and above go to ./CHGraph.log,
# so logger.info(...) calls in this module are not recorded by default.
# NOTE: basicConfig is a no-op if the root logger already has handlers.
LOG_FORMAT = "%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s"
logging.basicConfig(
    filename='./CHGraph.log',
    format=LOG_FORMAT,
    level=logging.WARNING,
)
logger = logging.getLogger('ModelService')

def get_client(clickhouse_connect):
    """Create a clickhouse_driver ``Client`` for ``clickhouse_connect["ip"]``.

    Only the ``"ip"`` entry is read; any other connection settings in the
    mapping are ignored, so the driver's defaults apply for port/user/password.
    """
    return Client(host=clickhouse_connect["ip"])

class ModelService(object):
    """Graph query entry points driven by request dicts.

    Each public query method reads the sub-graph name from ``data["subGraph"]``
    and expects *config_param* to contain ``"db"`` (an object with a
    ``use_tables(graph_name)`` method returning the graph configuration) and
    ``"clickhouse_connect"`` (passed to ``get_client``).  The graph
    configuration is read for ``"edges"`` and ``"vertexes"`` mappings of
    type name -> table description.
    """

    # Switch the active graph space.
    def use_graph(self, graph_name, db):
        """Make *graph_name* the active graph.

        :param graph_name: graph name looked up with ``db.use_tables``.
        :param db: object providing ``use_tables(graph_name)``.
        :return: the graph configuration, or None when the graph does not
            exist.  On failure ``graph_name``/``graph_cfg`` are reset to None
            so a previously selected graph's configuration is never reused
            for a different (unknown) graph.
        """
        res = db.use_tables(graph_name)
        logger.debug("use_tables(%s) returned %s", graph_name, res)
        if res is None:
            logger.warning("graph name [%s] does not exist", graph_name)
            self.graph_name = None
            self.graph_cfg = None
            return None
        self.graph_name = graph_name
        self.graph_cfg = res
        logger.info("use graph [%s] done", graph_name)
        return res

    def _open_graph(self, data, config_param):
        """Resolve ``data["subGraph"]`` and open a CHGraph for it.

        :return: ``(graph, graph_cfg)``, or None when *data* has no
            ``subGraph`` key or the sub-graph is unknown.
        """
        if "subGraph" not in data:
            return None
        graph_cfg = self.use_graph(data["subGraph"], config_param["db"])
        if graph_cfg is None:
            return None
        graph_client = get_client(config_param["clickhouse_connect"])
        return CHGraph(graph_client), graph_cfg

    @staticmethod
    def _conditions(conditions, schemas):
        """Return ``(condition_dict, order_dict)`` for a request condition list.

        A missing or empty list yields two empty dicts instead of calling
        ConditionOperation, which cannot unpack an empty payload.
        """
        if not conditions:
            return {}, {}
        condition_dict, order_dict, _ = ConditionOperation().conditionalOperation(conditions, schemas)
        return condition_dict, order_dict

    @staticmethod
    def _format_edge_counts(edge_data_list):
        """Convert execute_graph_sql results into ``{"columns", "rowList", "type"}`` dicts.

        A new dict is built per result; reusing a single dict would make every
        list entry alias the last result.
        """
        return [
            {
                "columns": item["data"]["schema"],
                "rowList": item["data"]["detail"],
                "type": item["type"],
            }
            for item in edge_data_list or []
        ]

    # Query a sub-graph filtered by the request conditions.
    def search_subgraph_by_condition(self, data,
                                     config_param):
        """Fetch rows of the requested edge and vertex types.

        Only types listed in ``data["edgeTypes"]`` / ``data["nodeTypes"]`` are
        queried; when a list is absent or empty nothing is queried for it
        (querying every table was disabled in this service).

        :return: ``{"pathList": {"graphEdges": [...], "graphNodes": [...]}}``
            where each list holds execute_graph_sql result dicts, or None when
            the sub-graph is missing or unknown.
        """
        opened = self._open_graph(data, config_param)
        if opened is None:
            return None
        graph, graph_cfg = opened
        edges = graph_cfg["edges"]
        vertexes = graph_cfg["vertexes"]

        edge_condition_dict, edge_order_dict = self._conditions(data.get("edgeConditions"), edges)
        node_condition_dict, node_order_dict = self._conditions(data.get("nodeConditions"), vertexes)

        edge_types = data.get("edgeTypes")
        node_types = data.get("nodeTypes")
        edge_data_list = []
        vertexes_data_list = []
        if edge_types:
            edge_data_list = executeController(edge_types, edges, edge_condition_dict, edge_order_dict,
                                               graph, "edge")
        if node_types:
            vertexes_data_list = executeController(node_types, vertexes, node_condition_dict, node_order_dict,
                                                   graph, "vertexes")

        return {"pathList": {"graphEdges": edge_data_list, "graphNodes": vertexes_data_list}}

    # Fetch the time-line (count per time bucket) model for edges/vertices.
    def time_line_search(self, data, config_param):
        """Run the aggregation query produced by a ``timeType`` condition.

        Edge conditions are processed before node conditions; when both yield
        an aggregation query, both run and the node result is returned.

        :return: ``{"columns": [...], "rowList": [...]}``, ``{}`` when no
            ``timeType`` condition is present, or None when the sub-graph is
            missing or unknown.
        """
        opened = self._open_graph(data, config_param)
        if opened is None:
            return None
        graph, graph_cfg = opened

        res = None  # stays None when no condition produces an aggregation query
        for condition_key, schema_key in (("edgeConditions", "edges"), ("nodeConditions", "vertexes")):
            conditions = data.get(condition_key)
            if not conditions:
                continue
            _, _, special_sql_dict = ConditionOperation().conditionalOperation(conditions, graph_cfg[schema_key])
            if "sql" in special_sql_dict:
                res = execute_graph_sql(graph, special_sql_dict["sql"])

        res_data = {}
        if res:
            res_data["columns"] = res["data"]["schema"]
            res_data["rowList"] = res["data"]["detail"]
        return res_data

    # Count the edges between every (src, dst) pair.
    def count_src_dst_round(self, data, config_param):
        """Count edges per (src, dst) pair for each edge type.

        Uses ``data["edgeTypes"]`` when non-empty, otherwise every edge type of
        the graph.

        :return: list of ``{"columns", "rowList", "type"}`` dicts (one per
            queried edge type), or None when the sub-graph is missing or unknown.
        """
        opened = self._open_graph(data, config_param)
        if opened is None:
            return None
        graph, graph_cfg = opened
        edges = graph_cfg["edges"]

        edge_condition_dict, edge_order_dict = self._conditions(data.get("edgeConditions"), edges)
        edge_types = data.get("edgeTypes") or list(edges.keys())
        edge_data_list = executeController(edge_types, edges, edge_condition_dict, edge_order_dict, graph,
                                           "edgeCount")
        return self._format_edge_counts(edge_data_list)


# Run one SQL query per requested type and collect the results.
def executeController(array, schema, condition_dict, order_dict, graph, type):
    """Build and execute one SELECT per type name in *array*.

    :param array: iterable of type names (keys of *schema*) to query.
    :param schema: mapping of type name -> table description.
    :param condition_dict: type name -> where-clause fragment (key may be absent).
    :param order_dict: type name -> order-by fragment (key may be absent or None).
    :param graph: object whose ``execute(sql)`` runs the query.
    :param type: ``"edge"`` (edge rows), ``"edgeCount"`` (per src/dst counts,
        grouped and limited to 10 rows), anything else for vertex rows.
    :return: list of execute_graph_sql result dicts tagged with their type
        name; types whose table description lacks required columns are skipped.
    """
    result = []
    for key in array:
        data = schema[key]
        group = ""
        if type == "edge":
            main_sql = edges_splice(data)
        elif type == "edgeCount":
            main_sql = edges_count_splice(data)
            group = edges_group_splice(data)
        else:
            main_sql = vertexes_splice(data)
        if not main_sql:
            continue
        parts = [
            main_sql,
            condition_splice(key, condition_dict),
            group,
            condition_splice(key, order_dict),
        ]
        if type == "edgeCount":
            parts.append("limit 10")
        # Fragments come from different helpers with inconsistent surrounding
        # whitespace (and may be None); joining the non-empty ones with single
        # spaces keeps e.g. "<table>" and "order by" from fusing together.
        sql = " ".join(part.strip() for part in parts if part)
        result.append(execute_graph_sql(graph, sql, key))
    return result


def edges_group_splice(edges_schema):
    """Return the `` group by <src>,<dst> `` clause for an edge table description.

    The clause is padded with a space on both sides so it can be concatenated
    directly. Returns None when either ``src`` or ``dst`` is missing.
    """
    if "src" not in edges_schema or "dst" not in edges_schema:
        return None
    return " group by " + edges_schema["src"] + "," + edges_schema["dst"] + " "


# Execute a SQL statement on the graph and package the tabular result.
def execute_graph_sql(graph, sql, type=None):
    """Run *sql* via ``graph.execute`` and return its columns and rows.

    The object returned by ``graph.execute`` must expose ``columns.values``
    and ``values`` with ``tolist()`` (presumably a pandas DataFrame — confirm
    against CHGraph).

    :return: ``{"data": {"schema": [column names], "detail": [rows]}}``, plus
        ``"type": type`` when *type* is truthy.
    """
    frame = graph.execute(sql)
    payload = {
        "data": {
            "schema": frame.columns.values.tolist(),
            "detail": frame.values.tolist(),
        }
    }
    if type:
        payload["type"] = type
    return payload


# Look up the SQL fragment for one type.
def condition_splice(type, condition=None):
    """Return the SQL fragment stored for *type* in *condition*, or ``""``.

    A missing/empty mapping, a missing key, or a stored None all yield ``""``
    so the result can always be concatenated into a query string.
    """
    if not condition:
        return ""
    return condition.get(type) or ""


# Build the main SELECT for a vertex table.
def vertexes_splice(vertexes_schema):
    """Return ``select <id>[,<label>] from <db>.<table>`` for a vertex table description.

    ``db`` defaults to ``"default"``. Returns None when ``table`` or ``id`` is
    missing. ``label`` is optional and is selected only when present and non-empty.
    """
    if "table" not in vertexes_schema or "id" not in vertexes_schema:
        return None
    db = vertexes_schema.get("db", "default")
    columns = vertexes_schema["id"]
    label = vertexes_schema.get("label")
    if label:
        # Only add the comma when there is a label: "select id, from ..." is invalid SQL.
        columns = columns + "," + label
    return "select " + columns + " from " + db + "." + vertexes_schema["table"]


# Build the main SELECT for an edge table.
def edges_splice(edges_schema):
    """Return ``select <src>, <dst>,<rank>  from  <db>.<table>`` for an edge table description.

    ``db`` defaults to ``"default"``. Returns None when any of ``table``,
    ``src``, ``dst`` or ``rank`` is missing.
    """
    required = ("table", "src", "dst", "rank")
    if any(name not in edges_schema for name in required):
        return None
    database = edges_schema.get("db", "default")
    return "".join([
        "select ", edges_schema["src"],
        ", ", edges_schema["dst"],
        ",", edges_schema["rank"],
        "  from  ", database, ".", edges_schema["table"],
    ])


# Build the aggregate SELECT that counts edges per (src, dst) pair.
def edges_count_splice(edges_schema):
    """Return ``select <src>, <dst>, count(1) count  from  <db>.<table>``.

    ``db`` defaults to ``"default"``. Returns None when ``table``, ``src``,
    ``dst`` or ``rank`` is missing. ``rank`` is required even though it is not
    selected, mirroring ``edges_splice``.
    """
    for required in ("table", "src", "dst", "rank"):
        if required not in edges_schema:
            return None
    database = edges_schema.get("db", "default")
    return "".join([
        "select ", edges_schema["src"],
        ", ", edges_schema["dst"],
        ", count(1) count  from  ", database, ".", edges_schema["table"],
    ])


# Build an ORDER BY clause.
def orderAttributeOperation(orderAttribute):
    """Return `` order by a,b,...`` for the attribute list, or None when it is empty.

    The clause starts with a space so it can be appended directly after a
    table name or a where-clause without fusing tokens together.
    """
    if not orderAttribute:
        return None
    return " order by " + ",".join(orderAttribute)


# Build a GROUP BY clause.
# Currently unused: aggregate functions are not handled yet.
def groupAttributeOperation(groupAttribute):
    """Return `` group by a,b,...`` for the attribute list, or None when it is empty.

    The clause starts with a space so it can be appended directly to a query.
    """
    if not groupAttribute:
        return None
    return " group by " + ",".join(groupAttribute)


# Quote every element as a ClickHouse string literal and join with commas.
def foreachStringArray2String(array=None):
    """Return ``'a','b',...`` for the values in *array*, or None when it is empty.

    Backslashes and single quotes inside each value are escaped, so a value
    cannot terminate the literal early. The values originate from
    caller-supplied condition payloads. Non-string values are converted
    with ``str``. No trailing comma is emitted, so the result can be wrapped
    directly as ``in(<result>)``.
    """
    if not array:
        return None
    return ",".join(
        "'" + str(item).replace("\\", "\\\\").replace("'", "\\'") + "'"
        for item in array
    )


class ConditionOperation(object):
    """Translate request condition payloads into ClickHouse SQL fragments.

    A payload is a list of entries, each targeting one key of the *schemas*
    mapping and shaped like::

        {"type": <schema key>,
         "srcNodes": [...], "destNodes": [...],
         "conditions": [{"attribute", "symbol", "type", "conditional"}, ...],
         "orderAttribute": [...]}

    NOTE(review): attribute names, order attributes and numeric values are
    interpolated into the SQL verbatim; only string literals are escaped here.
    They must be validated against the graph schema before reaching this class.
    """
    # Operators comparing an attribute with one value: "attr <op> v".
    singular = ["=", ">", "<", "like"]
    # Operators taking a value list: "attr in (v1, v2, ...)".
    allCondition = ["in"]
    # Operators taking two values: "attr between v1 and v2".
    twoCondition = ["between"]
    # Pseudo-operators that build a separate time-line aggregation query.
    special = ["timeType"]
    attribute_type_num = "num"
    attribute_type_time = "time"

    # Time-line granularity -> aggregation query template.
    # {a}: time attribute, {t}: "<db>.<table>", {w}: optional where-clause.
    # Unknown granularities fall back to "day".
    _time_line_templates = {
        "year": "select toYear({a}) year,count(1) count from {t}{w} group by year order by year ",
        "quarter": "select concat(toString(year),' ',toString(quarter)) yearQuarter ,count from "
                   "(select toYear({a}) year,toQuarter({a}) quarter ,count(1) count from {t}{w}"
                   " group by year,quarter order by year,quarter) ",
        "month": "select concat(toString(year),' ',toString(month)) yearMonth, count from "
                 "( select toYear({a}) year,toMonth({a}) month ,count(1) count from {t}{w}"
                 " group by year,month order by year,month)",
        "week": "select concat(toString(yw),' ',toString(week)) ywWeek, count from "
                "(select toYearWeek({a}) yw, toDayOfWeek({a}) week ,count(1) count from {t}{w}"
                " group by yw,week order by yw,week )",
        "day": "select formatDateTime({a},'%F') day ,count(1) count from {t}{w} group by day order by day",
        "hours": "select formatDateTime({a},'%F %H') h ,count(1) count from {t}{w} group by h order by h",
        "minute": "select formatDateTime({a},'%F %R') m ,count(1) count from {t}{w} group by m order by m",
    }

    # Merge all conditions of a payload.
    def conditionalOperation(self, typeConditions, schemas=None):
        """Build per-type where/order fragments from a condition payload.

        :param typeConditions: list of per-type condition entries (see class doc).
        :param schemas: mapping of type name -> table description; raises
            KeyError for a type that is not in it.
        :return: ``(condition_dict, order_dict, special_sql)``:
            ``condition_dict`` maps type -> `` where ... 1=1`` fragment,
            ``order_dict`` maps type -> order-by fragment, and ``special_sql``
            holds ``{"sql": ...}`` when a ``timeType`` condition was present
            (the last one wins). An empty payload yields three empty dicts.
        """
        typeConditionDict = {}
        typeOrderDict = {}
        typeSpecialSql = {}
        if not typeConditions:
            return typeConditionDict, typeOrderDict, typeSpecialSql
        if schemas is None:
            schemas = {}

        for typeCondition in typeConditions:
            type_name = typeCondition["type"]
            schema = schemas[type_name]

            val = self._node_filter(schema, "src", "src_data_type", typeCondition.get("srcNodes"), None)
            val = self._node_filter(schema, "dst", "dst_data_type", typeCondition.get("destNodes"), val)

            if "conditions" in typeCondition:
                val_dict = self.splice_condition(typeCondition["conditions"], schema, val)
                if "sql" in val_dict:
                    typeSpecialSql["sql"] = val_dict["sql"]
                val = val_dict["val"]

            if val:
                # Every fragment ends in "and "; "1=1" closes the conjunction.
                typeConditionDict[type_name] = val + " 1=1"

            if "orderAttribute" in typeCondition:
                typeOrderDict[type_name] = orderAttributeOperation(typeCondition["orderAttribute"])
        return typeConditionDict, typeOrderDict, typeSpecialSql

    def _node_filter(self, schema, column_key, type_key, nodes, val):
        """Append ``<column> in(...) and `` for *nodes* to the partial where-clause *val*.

        Returns *val* unchanged when *nodes* is empty or the schema has no
        *column_key* column. Values are quoted unless the schema marks the
        column's data type as numeric.
        """
        if not nodes or column_key not in schema:
            return val
        if schema.get(type_key) == self.attribute_type_num:
            values = ",".join(str(node) for node in nodes)
        else:
            values = foreachStringArray2String(nodes)
        clause = schema[column_key] + " in(" + values + ") and "
        return val + clause if val else " where " + clause

    @staticmethod
    def _quote(value):
        """Render *value* as a single-quoted ClickHouse string literal (escaping \\ and ')."""
        return "'" + str(value).replace("\\", "\\\\").replace("'", "\\'") + "'"

    @staticmethod
    def _time_column(attribute):
        """Format a time attribute as "YYYY-MM-DD hh:mm:ss" so it compares against string bounds."""
        return " formatDateTime(" + attribute + ",'%F %T')"

    def _literal(self, value, attribute_type):
        """Render an operand: verbatim for numeric attributes, quoted otherwise."""
        if attribute_type == self.attribute_type_num:
            return str(value)
        return self._quote(value)

    # Splice attribute conditions onto a where-clause.
    def splice_condition(self, conditions, schema, val_sql=None):
        """Append attribute filters to the where-clause prefix *val_sql*.

        :param conditions: list of ``{"attribute", "symbol", "type", "conditional"}``
            dicts; ``conditional`` holds the operand value(s).
        :param schema: table description of the filtered type (used only for
            ``timeType`` aggregation queries).
        :param val_sql: where-clause prefix built so far (`` where ... and ``) or None.
        :return: ``{"val": prefix}`` where the prefix still ends in ``"and "``,
            plus ``"sql"`` with an aggregation query when a ``timeType``
            condition is present. Empty *conditions* return *val_sql* unchanged.
            Conditions with an unknown symbol are ignored.
        """
        if not conditions:
            return {"val": val_sql}
        tem_dict = {}
        val = val_sql if val_sql else " where "
        for condition in conditions:
            symbol = condition["symbol"]
            attribute_type = condition["type"]

            if symbol in self.special:
                # The aggregation query is filtered by the incoming prefix only,
                # not by the other attribute conditions of this list.
                tem_dict["sql"] = self.time_line_count_sql(condition, schema, val_sql)
                continue

            attribute = condition["attribute"]
            values = condition["conditional"]
            if symbol in self.singular:
                if attribute_type == self.attribute_type_time:
                    left = self._time_column(attribute)
                else:
                    left = attribute
                # Spaces around the operator are required for "like".
                val += left + " " + symbol + " " + self._literal(values[0], attribute_type) + " and "
            elif symbol in self.allCondition:
                if attribute_type == self.attribute_type_num:
                    items = ",".join(str(value) for value in values)
                else:
                    items = foreachStringArray2String(values)
                val += attribute + " " + symbol + " (" + items + ") and "
            elif symbol in self.twoCondition:
                if attribute_type == self.attribute_type_time:
                    left = self._time_column(attribute)
                else:
                    left = attribute
                val += (left + " between " + self._literal(values[0], attribute_type)
                        + " and " + self._literal(values[1], attribute_type) + " and ")
        tem_dict["val"] = val
        return tem_dict

    # Build the special time-line aggregation query.
    def time_line_count_sql(self, condition, schema, val_sql):
        """Return a count-per-time-bucket query for ``condition["attribute"]``.

        :param condition: dict whose ``conditional[0]`` is the granularity
            (year, quarter, month, week, day, hours, minute; anything else
            falls back to day).
        :param schema: table description; ``db`` defaults to ``"default"``.
        :param val_sql: optional where-clause prefix ending in ``"and "``;
            it is closed with ``1=1``.
        """
        granularity = condition["conditional"][0]
        template = self._time_line_templates.get(granularity, self._time_line_templates["day"])
        where = val_sql + " 1=1 " if val_sql else ""
        table = schema.get("db", "default") + "." + schema["table"]
        return template.format(a=condition["attribute"], t=table, w=where)


def main():
    """Manual smoke test: run a sample query against a hard-coded ClickHouse host.

    Prints the result rows, then the column names.
    NOTE(review): host, port and credentials are hard-coded here; move them to config.
    """
    from clickhouse_driver import Client
    client = Client(host="10.202.255.93", port="9090", user="default", password="root")
    from graph_op import CHGraph
    frame = CHGraph(client).execute("select * from anti_money_launder.transactions limit 10")
    column_names = frame.columns.values.tolist()
    print(frame.values.tolist())
    print(column_names)


def test():
    """Manual smoke test: query the ``anti_money_launder`` sub-graph and print the result.

    Requires the project's ``data_load_config`` module and a reachable ClickHouse server.
    """
    import data_load_config as config
    config_params = config.load_config()
    print(str(config_params["db"]))
    dict_data = {"subGraph": "anti_money_launder"}
    model_service = ModelService()
    result_dict = model_service.search_subgraph_by_condition(dict_data, config_params)
    print(result_dict)


if __name__ == '__main__':
    # Scratch checks. Call main() or test() here to run the manual smoke
    # tests; both need a live ClickHouse server.
    # NOTE(review): with a mapping operand and no %(...)s placeholders, "%"
    # returns the string unchanged, so this prints "ery5".
    formatted = "ery5" % {"er": "rte"}
    print(formatted)
    original = 'ew'
    doubled = "ew" + original
    print(original, "========", doubled)
