#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May 12 15:23:49 2020

@author: afeher, mrojas, jwelsh
"""

import sys
import configparser
import threading
import queue
import snowflake.connector
import time
import json
import os
import fileinput
import re
import argparse
from datetime import datetime
from os import path
import math
from collections import Counter
import mmap
import itertools


def thread_function(con, index, max_stmnt, stmnt_q, created_q, failed_q, done_q):
    """Worker thread: execute queued statements until the queue is drained.

    A cursor is opened on ``con``, two session settings are applied, and then
    ``(file, statement)`` tuples are pulled from ``stmnt_q`` for as long as the
    combined size of ``created_q`` and ``failed_q`` stays below ``max_stmnt``.

    Args:
        con: Connection object; only ``con.cursor()`` is used here.
        index: Identifier of this worker; put on ``done_q`` when it exits.
        max_stmnt: Upper bound on the number of processed statements
            (created + failed).
        stmnt_q: Queue of ``(file, statement)`` tuples to execute.
        created_q: Receives ``{"stmnt", "file"}`` for each statement that ran.
        failed_q: Receives ``{"error_msg", "statement", "file"}`` for each
            statement that raised ``ProgrammingError``.
        done_q: Receives ``index`` once this worker has finished.
    """
    cur = con.cursor()
    try:
        cur.execute("set quoted_identifiers_ignore_case = TRUE")
        cur.execute("alter session set TIMESTAMP_TYPE_MAPPING = TIMESTAMP_LTZ")
        while (created_q.qsize() + failed_q.qsize()) < max_stmnt:
            # Non-blocking get: checking qsize() and then calling a blocking
            # get() lets another worker steal the last item in between, which
            # would leave this worker waiting forever.
            try:
                file, stmnt = stmnt_q.get_nowait()
            except queue.Empty:
                break
            try:
                cur.execute(stmnt)
            except snowflake.connector.errors.ProgrammingError as e:
                failed_q.put({"error_msg": e, "statement": stmnt, "file": file})
            else:
                created_q.put({"stmnt": stmnt, "file": file})
    finally:
        # Always release the cursor and report completion; the progress thread
        # waits until every worker has put its index on done_q.
        cur.close()
        done_q.put(index)


    
def msg_thread_function(parallelism, msg_freq, session_id, no_of_stmnts, created_q, failed_q, done_q):
    """Progress reporter: print counters until all workers have finished.

    The loop runs while ``done_q`` holds fewer than ``parallelism`` entries.
    Between refreshes it sleeps ``msg_freq`` seconds; when ``msg_freq`` is
    falsy it does not sleep at all and polls continuously.

    When the loop ends, ``done_q`` is emptied and a single summary dict
    (start/end time, session id, statement / created / failed counts) is put
    on it.

    Args:
        parallelism: Number of entries expected on ``done_q`` before stopping.
        msg_freq: Seconds between progress lines (falsy disables sleeping).
        session_id: Value stored under ``"session_id"`` in the summary.
        no_of_stmnts: Value stored under ``"number_of_statements"``.
        created_q: Queue whose size is reported as the created count.
        failed_q: Queue whose size is reported as the failed count.
        done_q: Completion queue; receives the summary dict at the end.
    """
    time_format = "%d/%m/%Y %H:%M:%S"
    run_dict = dict(
        start_time=datetime.now().strftime(time_format),
        session_id=session_id,
        number_of_statements=no_of_stmnts,
        number_of_created=0,
        number_of_failed=0,
        end_time="",
    )

    def show_progress(created_count, failed_count):
        # "\r" keeps rewriting the same console line.
        print("Created.   : ", f"{created_count:5d}", " Failed In Run:", f"{failed_count:5d}", end="\r", flush=True)

    def pause():
        if msg_freq:
            time.sleep(msg_freq)

    print("Start time : ", run_dict["start_time"], "\n")
    print("Session ID : ", run_dict["session_id"])
    print("Parallelism: ", f"{parallelism:5d}")
    print("# of stmts : ", f"{no_of_stmnts:5d}")
    show_progress(0, 0)
    pause()

    while done_q.qsize() < parallelism:
        show_progress(created_q.qsize(), failed_q.qsize())
        pause()

    total_created = created_q.qsize()
    total_failed = failed_q.qsize()
    run_dict["number_of_created"] = total_created
    run_dict["number_of_failed"] = total_failed
    run_dict["end_time"] = datetime.now().strftime(time_format)
    print("                Total Created: ", f"{total_created:5d}", " Failed In Run:", f"{total_failed:5d}", end="\r", flush=True)
    print("\n")
    print("            End time: ", run_dict["end_time"])
    print("\n")

    # Discard whatever is on done_q, then leave only the summary there.
    while not done_q.empty():
        done_q.get()
    done_q.put(run_dict)
    return

def get_object_key(file):
    """Extract the object key wrapped in an ``<sc-...>`` tag.

    Args:
        file: Text to search (the caller in this file passes statement text).

    Returns:
        The text captured between ``<sc-xxx>`` and ``</sc-xxx>`` for the first
        match, or ``"@@nokey"`` when no such tag is found.
    """
    # Raw string so that \w and \s reach the regex engine without triggering
    # Python's invalid-escape-sequence warning.
    # NOTE(review): the "." is unescaped, so it matches any character (which
    # also lets unqualified names of 3+ characters match) -- confirm whether
    # a literal dot was intended before tightening it.
    pattern = r"<sc-\w+>\s*(\w+.\w+)\s*</sc-\w+>"
    found = re.search(pattern, file)
    return found.group(1) if found else "@@nokey"

def remove_error_msg(msg):
    """Strip an ``Error ...`` note from inside the ``<sc...>`` tag region.

    The region runs from the first ``<sc`` to the first ``</sc``. If the
    region contains ``"Error"``, everything from there up to the closing tag
    is removed. Afterwards the last newline left inside the region (if any)
    is removed as well.

    Args:
        msg: Statement text, optionally carrying an ``<sc...>...</sc...>`` tag.

    Returns:
        The cleaned text. ``msg`` is returned unchanged when the opening or
        closing tag is missing or they appear in the wrong order.
    """
    start = msg.find("<sc")
    end = msg.find("</sc")
    # Without a well-formed tag region there is nothing to clean; slicing with
    # a -1 index would otherwise cut into the statement body.
    if start == -1 or end == -1 or end < start:
        return msg

    ret = msg
    err_at = msg[start:end].find("Error")
    if err_at > -1:
        ret = msg[:start + err_at] + msg[end:]
        end = ret.find("</sc")

    nl_at = ret[start:end].rfind("\n")
    if nl_at != -1:
        # nl_at is relative to the region, so shift it by the region start
        # before cutting; otherwise an unrelated character gets deleted.
        cut = start + nl_at
        ret = ret[:cut] + ret[cut + 1:]
    return ret



def init(input_directory, workspace, split, splitpattern, object_type):
    """Build the queue of ``(path, sql_text)`` tuples to be executed.

    Walks ``input_directory`` and queues every ``.sql`` file (extension match
    is case-insensitive). Directories listed in the module-level
    ``exclude_dirs`` (relative to ``input_directory``) are skipped; only the
    listed directory itself is compared, not its sub-directories.

    Args:
        input_directory: Root directory to scan.
        workspace: Not used by this function; kept for interface
            compatibility.
        split: When truthy, files with more than one ``splitpattern`` match
            are split into one queue entry per statement.
        splitpattern: Regex used to count and split statements. It needs a
            capture group so that ``re.split`` keeps the separator.
        object_type: When non-empty, only files whose parent directory name
            equals this value are queued.

    Returns:
        A ``queue.Queue`` of ``(full_path, sql_text)`` tuples.
    """
    stmnt_q = queue.Queue()
    qualified_exclude_dirs = [os.path.join(input_directory, x) for x in exclude_dirs]
    for dirpath, dirnames, files in os.walk(input_directory):
        print(f'Processing directory: {dirpath}')
        if dirpath in qualified_exclude_dirs:
            print(f'Directory {dirpath} was excluded')
            continue

        for file_name in files:
            if os.path.splitext(file_name)[1].lower() != ".sql":
                continue
            full_path = os.path.join(dirpath, file_name)

            # Truthiness (rather than len()) also tolerates object_type=None.
            if object_type and full_path.split(os.sep)[-2] != object_type:
                continue

            print(f"Processing file {full_path}")
            # Context manager so the handle is released even if read() fails.
            with open(full_path) as f:
                contents = f.read()

            creates_count = len(re.findall(splitpattern, contents))
            if not (creates_count > 1 and split):
                stmnt_q.put((full_path, contents))
                continue

            print(f"File {full_path} seems to have more than one script")
            # Split and drop empty pieces (re.split may also yield None for
            # capture groups that did not participate in a match).
            pieces = [s for s in re.split(splitpattern, contents) if s and s.strip() != '']
            # Re-attach each separator to the text that follows it. Text that
            # is not preceded by a separator (e.g. before the first match) is
            # not queued.
            i = 0
            while i < len(pieces):
                piece = pieces[i]
                if re.match(splitpattern, piece):
                    if i + 1 < len(pieces):
                        stmnt_q.put((full_path, piece + pieces[i + 1]))
                        i += 1
                    else:
                        # Trailing separator with nothing after it: queue it
                        # on its own instead of failing with IndexError.
                        stmnt_q.put((full_path, piece))
                i += 1
    return stmnt_q


def calc_par(no_of_stmnts, parallelism):
    """Reduce the worker count for small batches.

    Returns ``ceil(no_of_stmnts / 30)`` when ``no_of_stmnts / 30`` does not
    exceed ``parallelism``; otherwise returns ``parallelism`` unchanged.
    """
    batches_of_thirty = no_of_stmnts / 30
    if batches_of_thirty <= parallelism:
        return math.ceil(batches_of_thirty)
    return parallelism



def decode_error(argument):
    """Translate an error number into a symbolic name.

    Args:
        argument: Error number to look up.

    Returns:
        The name mapped to ``argument``, or ``"nothing"`` when the number is
        not in the table.
    """
    names_by_errno = {
        603: "PROCESS_ABORTED_DUE_TO_ERROR",
        900: "EMPTY_SQL_STATEMENT",
        904: "INVALID_IDENTIFIER",
        939: "TOO_MANY_ARGUMENTS_FOR_FUNCTION",
        979: "INVALID_GROUP_BY_CLAUSE",
        1002: "SYNTAX_ERROR_1",
        1003: "SYNTAX_ERROR_2",
        1007: "INVALID_TYPE_FOR_PARAMETER",
        1038: "CANNOT_CONVERT_PARAMETER",
        1044: "INVALID_ARG_TYPE_FOR_FUNCTION",
        1104: "COLUMN_IN_SELECT_NOT_AGGREGATE_OR_IN_GROUP_BY",
        1789: "INVALID_RESULT_COLUMNS_FOR_SET_OPERATION",
        2001: "OBJECT_DOES_NOT_EXIST_1",
        2003: "OBJECT_DOES_NOT_EXIST_2",
        2016: "EXTRACT_DOES_NOT_SUPPORT_VARCHAR",
        2022: "MISSING_COLUMN_SPECIFICATION",
        2025: "DUPLICATE_COLUMN_NAME",
        2026: "INVALID_COLUMN_DEFINITION_LIST",
        2028: "AMBIGUOUS_COLUMN_NAME",
        2040: "UNSUPPORTED_DATA_TYPE",
        2140: "UNKNOWN_FUNCTION",
        2141: "UNKNOWN_USER_DEFINED_FUNCTION",
        2143: "UNKNOWN_USER_DEFINED_TABLE_FUNCTION",
        2151: "INVALID_COMPONENT_FOR_FUNCTION_TRUNC",
        2212: "MATERIALIZED_VIEW_REFERENCES_MORE_THAN_1_TABLE",
        2262: "DATA TYPE MISMATCH WITH DEFAULT VALUE",
        2401: "LIKE_ANY_DOES_NOT_SUPPORT_COLLATION",
        2402: "LTRIM_WITH_COLLATION_REQUIRES_WHITESPACE_ONLY",
        90105: "CANNOT_PERFORM_CREATE_VIEW",
        90216: "INVALID_UDF_FUNCTION_NAME",
    }
    try:
        return names_by_errno[argument]
    except KeyError:
        # Unmapped numbers fall back to a fixed placeholder name.
        return "nothing"



def main(input_script, workspace, split, splitpattern, object_type, sf_authenticator):
    """Deploy all SQL statements, retrying failures, then write the reports.

    Connects to Snowflake, queues the statements found by ``init`` and runs
    them on worker threads. Statements that failed are re-queued and run again
    for as long as a run creates at least one new object. Finally the
    created / failed statement files, the error summaries and the execution
    summary are written to ``out_path``, and the process exits with the number
    of failed statements as its status (capped at 255).

    Reads these module-level settings: ``out_path``, ``sf_account``,
    ``sf_user``, ``sf_password``, ``sf_db``, ``sf_role``, ``sf_warehouse``,
    ``max_stmnt``, ``msg_freq`` and ``parallelism`` (which is also updated).

    Args:
        input_script: Directory containing the ``.sql`` files.
        workspace: Passed through to ``init``.
        split: Passed through to ``init`` (truthy enables splitting).
        splitpattern: Passed through to ``init``.
        object_type: Passed through to ``init``.
        sf_authenticator: Optional authenticator name; only sent to the
            connector when truthy.

    Raises:
        RuntimeError: If the progress thread did not leave a summary.
        SystemExit: Always, at the end of a completed run.
    """
    global parallelism

    os.makedirs(out_path, exist_ok=True)

    print(f"Connecting account:{sf_account} database: {sf_db} role: {sf_role} warehouse: {sf_warehouse} user: {sf_user} auth: {sf_authenticator}")
    connect_args = {
        "account": sf_account,
        "user": sf_user,
        "password": sf_password,
        "database": sf_db,
        "role": sf_role,
        "warehouse": sf_warehouse,
        "application": "mobilize.net",
    }
    if sf_authenticator:
        connect_args["authenticator"] = sf_authenticator
    con = snowflake.connector.connect(**connect_args)

    created_q = queue.Queue()
    failed_q = queue.Queue()
    done_q = queue.LifoQueue()
    try:
        stmnt_q_cur_run = init(input_script, workspace, split, splitpattern, object_type)
        no_of_stmnts_cur_run = stmnt_q_cur_run.qsize()
        tot_created_last_run_end = 0
        run_num = 1

        while True:
            print("\n")
            print("Recursive Run", run_num, "...")
            parallelism = calc_par(no_of_stmnts_cur_run, parallelism)
            # Fresh completion queue per run: the progress thread counts its
            # entries, so a summary left over from the previous run would make
            # it stop one worker early.
            done_q = queue.LifoQueue()
            threads = []
            for index in range(parallelism):
                worker = threading.Thread(target=thread_function, args=(con, index, max_stmnt, stmnt_q_cur_run, created_q, failed_q, done_q))
                threads.append(worker)
                worker.start()
            reporter = threading.Thread(target=msg_thread_function, args=(parallelism, msg_freq, con.session_id, no_of_stmnts_cur_run, created_q, failed_q, done_q))
            threads.append(reporter)
            reporter.start()
            for thread in threads:
                thread.join()

            if failed_q.qsize() == 0:
                print("All objects successfully created.")
                break
            if created_q.qsize() == tot_created_last_run_end:
                print("No new objects created in previous run. Ending recursive runs.")
                break

            # Something new was created, so statements that failed (e.g. on a
            # missing dependency) get another chance in the next run.
            tot_created_last_run_end = created_q.qsize()
            stmnt_q_cur_run = queue.Queue()
            for failed in list(failed_q.queue):
                stmnt_q_cur_run.put((failed["file"], remove_error_msg(failed["statement"])))
            no_of_stmnts_cur_run = failed_q.qsize()
            failed_q = queue.Queue()
            run_num = run_num + 1
    finally:
        con.close()

    print("Creating output files...")
    # Worker indexes (ints) may share the queue with the summary dict.
    summaries = [info for info in list(done_q.queue) if isinstance(info, dict)]
    for info in summaries:
        for item, val in info.items():
            print("{} : {}".format(item.ljust(20), val))
    if not summaries:
        raise RuntimeError("No execution summary was produced by the progress thread.")
    # LifoQueue.get() would return the most recently added entry.
    execution_summary = summaries[-1]
    session_id = execution_summary["session_id"]

    list_created = list(created_q.queue)
    if list_created:
        sql_name = os.path.join(out_path, f"created_{session_id}.sql")
        csv_name = os.path.join(out_path, f"created_{session_id}.csv")
        with open(sql_name, "w") as f, open(csv_name, "w") as fcsv:
            for e in list_created:
                fcsv.write(e["file"] + "\n")
                f.write(remove_error_msg(e["stmnt"]) + "\n")

    error_list = []
    list_failed = list(failed_q.queue)
    if list_failed:
        # Group by decoded name (not errno): every unmapped errno decodes to
        # the same name, and each name must map to exactly one output file.
        failed_by_name = {}
        with open(os.path.join(out_path, f"failed_{session_id}.sql"), "w") as f:
            for em in list_failed:
                f.write(remove_error_msg(em["statement"]))
                f.write("\n")
                err = em["error_msg"]
                error_name = str(decode_error(err.errno))
                failed_by_name.setdefault(error_name, []).append(em)
                error_list.append((em["file"], error_name, err.errno, err.sqlstate, err.msg, err.sfqid, em["statement"]))

        for error_name, entries in failed_by_name.items():
            # os.path.join keeps these files inside out_path.
            error_path = os.path.join(out_path, f"error_{session_id}_{error_name}.sql")
            with open(error_path, "w") as error_file:
                for em in entries:
                    err = em["error_msg"]
                    statement = em["statement"]
                    # The error note goes right before the closing tag, or at
                    # the very end when the statement carries no tag.
                    close_at = statement.find("</sc")
                    if close_at == -1:
                        close_at = len(statement)
                    error_file.write(statement[0:close_at] + "\n")
                    error_file.write('Error {0} ({1}): {2} ({3})'.format(err.errno, err.sqlstate, err.msg, err.sfqid))
                    error_file.write(statement[close_at:])

    freq = Counter(error_list)
    with open(os.path.join(out_path, f"error_summary_{session_id}.sql"), "w") as f:
        f.write(str(freq))

    with open(os.path.join(out_path, "error_list_summary.txt"), "w") as f:
        for (src_file, error_name, errno, sqlstate, msg, sfqid, statement) in error_list:
            msg = msg.replace('\n', ' ')
            report = f"ERROR ({error_name}) in FILE:[{src_file}] ERR:{errno} SQLSTATE: {sqlstate} MSG: {msg} QUERYID: {sfqid}"
            f.write(report + "\n")

    with open(os.path.join(out_path, f"error_list_summary_{session_id}.csv"), "w") as f:
        f.write("key,path,error,message\n")
        for (src_file, error_name, errno, sqlstate, msg, sfqid, statement) in error_list:
            # Commas inside the message would break the hand-written CSV row.
            msg = msg.replace('\n', ' ').replace(",", "-")
            key = get_object_key(statement)
            f.write(f"{key},{src_file},{error_name},{msg}" + "\n")

    with open(os.path.join(out_path, f"execution_summary_{session_id}.json"), "w") as f:
        f.write(json.dumps(execution_summary, indent=2, sort_keys=True))

    ## Process done
    print("\nDone")
    # Exit statuses wrap modulo 256 on POSIX, so an uncapped count such as 256
    # would be reported as success.
    sys.exit(min(len(error_list), 255))

def str2bool(v):
    """Convert a command-line token into a boolean.

    Args:
        v: A bool (returned unchanged) or a string, compared
            case-insensitively.

    Returns:
        True for yes/true/t/y/1 and False for no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: For any other string.
    """
    if isinstance(v, bool):
        return v

    recognised = {
        'yes': True, 'true': True, 't': True, 'y': True, '1': True,
        'no': False, 'false': False, 'f': False, 'n': False, '0': False,
    }
    parsed = recognised.get(v.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return parsed

if __name__ == "__main__":
    # Command-line entry point: parse the arguments, resolve the connection
    # settings (arguments first, then config_snowsql.ini, then environment
    # variables), publish them as module-level names read by main()/init(),
    # and start the deployment.

    # The backslash in "workspace\\config_snowsql.ini" is escaped so that the
    # literal no longer contains an invalid escape sequence; the text shown
    # to the user is unchanged.
    parser = argparse.ArgumentParser(description="""
    SnowConvertStudio Deployment Script
    ===================================

    This script helps you to deploy a collection of .sql files to a Snowflake Account.

    The tool will look for settings like:
    - Snowflake Account
    - Snowflake Warehouse
    - Snowflake Role
    - Snowflake Database

    If the tool can find a config_snowsql.ini file in the current directory or in the workspace\\config_snowsql.ini location
    it will read those parameters from there.""",formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("-A","--Account",dest="Account", help = "Snowflake Account")
    parser.add_argument("-D","--Database",dest="Database", help = "Snowflake Database")
    parser.add_argument("-WH","--Warehouse",dest="Warehouse", help = "Snowflake Warehouse")
    parser.add_argument("-R","--Role",dest="Role", help = "Snowflake Role")
    parser.add_argument("-U","--User",required=True,dest="User", help = "Snowflake User")
    parser.add_argument("-P","--Password",dest="Password", help = "Password")
    parser.add_argument("-W","--Workspace",dest="Workspace", help = "Path for workspace root. Defaults to current dir", default=os.getcwd())
    parser.add_argument("-I","--InPath",required=True, dest="InPath", help = "Path for SQL scripts")
    parser.add_argument("--activeConn",required=False, dest="ActiveConn", help = "When given, it will be used to select connection parameters forn config_snowsql.ini")
    parser.add_argument("--authenticator",dest="authenticator", help = "Use the authenticator with you want to use a different authentication mechanism")
    parser.add_argument("-L","--LogPath",dest="LogPath", help = "Path for process logs. Defaults to current dir",default=os.path.join(os.getcwd(),"logs"))
    parser.add_argument("--SplitPattern", help = "When provided it should be Regex Pattern to use to split scripts. Use capture groups to keep separator. E.g: (CREATE OR REPLACE)", default=r"(CREATE OR REPLACE)")
    parser.add_argument("--ObjectType", help = "Object Type to deploy table,view,procedure,function,macro", nargs='?', default="")

    args = parser.parse_args()

    if args.InPath and path.exists(args.InPath):
        input_script = args.InPath
        print(f"Using InputPath = {input_script}")
    else:
        print("Input Path for SQL scripts does not exist.")
        # Non-zero status: a missing input path is a failure, and callers
        # (e.g. CI jobs) must be able to detect it.
        sys.exit(1)

    config = configparser.ConfigParser()

    # Module-level settings read by main(), init() and the worker threads.
    max_stmnt    = sys.maxsize
    msg_freq     = int(os.getenv('CI_MSG_FREQ') or 0)
    parallelism  = 20
    exclude_dirs = []
    out_path     = args.LogPath
    ini_path     = None

    # declare and initialize variables
    sf_account       = None
    sf_warehouse     = None
    sf_role          = None
    sf_db            = None
    sf_user          = None
    sf_password      = None
    sf_authenticator = None
    sf_application   = None

    # First try to find a config_snowsql.ini file
    if args.Workspace and os.path.exists(os.path.join(args.Workspace,'config_snowsql.ini')):
        ini_path = os.path.join(args.Workspace,'config_snowsql.ini')
    if not ini_path and os.path.exists('config_snowsql.ini'):
        ini_path = 'config_snowsql.ini'
    # if an ini file was found read some settings from there
    # we avoid reading user and password from this file
    if ini_path and config.read(ini_path):
        connectionSettingsSection = "connections"
        if args.ActiveConn and args.ActiveConn != "default":
            connectionSettingsSection = "connections." + args.ActiveConn
        if not config.has_section(connectionSettingsSection):
            print(f"Section [{connectionSettingsSection}] was not found in {ini_path}")
        # fallback=None: a missing section or key must not abort the run,
        # because arguments and environment variables can still supply it.
        sf_account       = config.get(connectionSettingsSection, 'accountname', fallback=None)
        sf_warehouse     = config.get(connectionSettingsSection, 'warehousename', fallback=None)
        sf_role          = config.get(connectionSettingsSection, 'rolename', fallback=None)
        sf_db            = config.get(connectionSettingsSection, 'dbname', fallback=None)
        sf_authenticator = config.get(connectionSettingsSection, 'authenticator', fallback=None)

    # Precedence: command-line argument, then ini file, then environment.
    sf_account       = args.Account or sf_account or os.getenv("SNOW_ACCOUNT")
    sf_warehouse     = args.Warehouse or sf_warehouse or os.getenv("SNOW_WAREHOUSE")
    sf_role          = args.Role or sf_role or os.getenv("SNOW_ROLE")
    sf_db            = args.Database or sf_db or os.getenv("SNOW_DATABASE")
    sf_user          = args.User or os.getenv("SNOW_USER")
    sf_password      = args.Password or os.getenv("SNOW_PASSWORD")
    sf_authenticator = args.authenticator or sf_authenticator

    # NOTE(review): the split pattern is passed both as the `split` flag and
    # as the pattern itself, so splitting is enabled whenever the pattern is
    # non-empty -- confirm whether a separate on/off switch was intended.
    # `or ""` covers a bare --ObjectType, which argparse stores as None.
    main(input_script, args.Workspace, args.SplitPattern, args.SplitPattern, args.ObjectType or "", sf_authenticator)
