#!/usr/bin/env python
import pathlib

import paramiko
import requests
import sys
import time
import threading
import configparser
import os
import logging
from termcolor import colored

# Console logging for the whole tool. Each record shows the timestamp, level and
# the name of the emitting thread (worker threads are named after the server
# they talk to), followed by the message.
logging.basicConfig(
    format='[%(asctime)s][%(levelname)s] %(threadName)s: %(message)s',
    datefmt='%Y/%m/%d %H:%M:%S',
)

# Application logger; raised to DEBUG by MasterSSH when --verbose is given.
logger = logging.getLogger('MasterSSH')
logger.setLevel(logging.INFO)


class MasterSSH:
    """Interactive shell that runs commands on many SSH servers at once.

    Credentials are read as ``name,host,username,password`` lines from a local
    file or a URL. Every server is connected in its own thread, and commands
    typed at the ``master-ssh$`` prompt are executed on all connected servers,
    on a chosen subset (``#a,b command``), or on all but some
    (``#ignore:a,b command``).
    """

    # Parsed config.ini; populated by parse_config().
    config = None

    def __init__(self, args):
        """Load configuration and credentials according to the CLI arguments.

        :param args: parsed command-line arguments; the attributes read are
            ``verbose``, ``port``, ``cred_file``, ``cred_url``, ``servers`` and
            ``manual``.

        Calls ``sys.exit(1)`` when the credentials source is missing or invalid.
        """
        # Per-instance state. These used to be class attributes, which made all
        # instances share (and append to) the same dict/lists.
        self.connection_pool = {}
        self.selected_servers = []
        self.credentials = []
        self.commandPool = []

        if args.verbose:
            logger.setLevel(logging.DEBUG)

        self.parse_config(args)
        self.init_logs()

        cred_source_type = None
        cred_source_path = None

        if args.cred_file:
            if os.path.isfile(args.cred_file):
                cred_source_path = args.cred_file
                cred_source_type = 'file'
            else:
                self.print_error("%s does not exist!" % args.cred_file)
                sys.exit(1)

        # A URL overrides a file when both are given.
        if args.cred_url:
            cred_source_path = args.cred_url
            cred_source_type = 'url'

        # Restrict connections to these server names (empty list = all servers).
        if args.servers:
            self.selected_servers = [name.strip() for name in args.servers.split(',')]

        # Without a credentials source there is nothing to connect to, unless the
        # user explicitly chose manual mode.
        if not cred_source_path and not args.manual:
            self.print_error("You need to specify a credentials getting method [--cred-url, --cred-file, --manual]")
            sys.exit(1)

        if cred_source_type == 'file':
            self.credentials.extend(self._load_credentials_from_file(cred_source_path))
        elif cred_source_type == 'url':
            self.credentials.extend(self._load_credentials_from_url(cred_source_path))

    @staticmethod
    def _parse_credential_line(line):
        """Parse one ``name,host,username,password`` line.

        Returns a credential dict, or ``None`` for lines that do not contain
        exactly four comma-separated fields (blank or malformed lines).
        """
        fields = line.strip().split(',')
        if len(fields) != 4:
            return None
        name, host, username, password = fields
        return {'name': name, 'host': host, 'username': username, 'password': password}

    def _load_credentials_from_file(self, path):
        """Return the credential dicts parsed from the local file *path*."""
        with open(path, 'r') as f:
            parsed = (self._parse_credential_line(line) for line in f)
            return [cred for cred in parsed if cred is not None]

    def _load_credentials_from_url(self, url):
        """Download *url* and return the credential dicts parsed from its body.

        Calls ``sys.exit(1)`` on network errors, unexpected status codes or an
        empty response.
        """
        try:
            request = requests.get(url)
        except requests.ConnectionError:
            self.print_error("Failed to download data! Please check your URL.")
            sys.exit(1)
        except requests.HTTPError:
            self.print_error("Bad HTTP request!")
            sys.exit(1)
        except requests.Timeout:
            self.print_error("Connection to your domain timed out, please check your server!")
            sys.exit(1)
        except requests.TooManyRedirects:
            self.print_error("There are too many redirects for your URL, check your server configuration!")
            sys.exit(1)
        except Exception as e:
            self.print_error("Something went wrong!")
            self.print_exception(e)
            sys.exit(1)

        response = request.text.strip()

        if request.status_code not in [200, 301, 302]:
            self.print_error("%s did not respond correctly: %s!" % (url, request.status_code))
            sys.exit(1)

        if response == "":
            self.print_error("%s does not contain any data!" % url)
            sys.exit(1)

        parsed = (self._parse_credential_line(line) for line in response.split('\n'))
        return [cred for cred in parsed if cred is not None]

    @staticmethod
    def _run_threads(jobs):
        """Run each ``(thread_name, target, args)`` job in a daemon thread; wait for all."""
        threads = [
            threading.Thread(target=target, name=name, args=args, daemon=True)
            for name, target, args in jobs
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def create_connections(self):
        """Connect, in parallel, to every credential (or only the selected ones)."""
        if self.credentials:
            self.print_message("Connecting to servers...")
            port = self.config.getint('connection', 'port')
        else:
            port = None

        selected = self.selected_servers
        jobs = [
            (cred['name'], self.connect,
             (cred['name'], cred['host'], cred['username'], cred['password'], port))
            for cred in self.credentials
            if not selected or cred['name'] in selected
        ]
        self._run_threads(jobs)

        self.print_success("Welcome to master-ssh!")

    def connect(self, name, host, username, password, port=None):
        """Open an SSH connection and store the client as ``connection_pool[name]``.

        Makes at most ``[connection] max_retries`` attempts in total (at least
        one), sleeping ``[connection] delay`` seconds between attempts.
        Authentication failures are not retried. *port* defaults to
        ``[connection] port`` from the config.
        """
        if port is None:
            port = self.config.getint('connection', 'port')
        max_tries = max(self.config.getint('connection', 'max_retries'), 1)
        delay = self.config.getfloat('connection', 'delay')

        for attempt in range(1, max_tries + 1):
            client = paramiko.SSHClient()
            client.set_log_channel('SSHClient')
            # NOTE: AutoAddPolicy accepts unknown host keys without verification.
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            try:
                client.connect(host, username=username, password=password, port=port)
                client.get_transport().set_keepalive(30)
            except paramiko.ssh_exception.AuthenticationException:
                client.close()
                self.print_error("Failed to connect - Wrong login details!")
                return
            except Exception as e:
                # Release the half-open client before retrying.
                client.close()
                self.print_error("Failed to connect! Retrying (%s/%s)" % (attempt, max_tries))
                self.print_exception(e)
                if attempt < max_tries:
                    time.sleep(delay)
                continue

            self.connection_pool[name] = client
            self.print_success("Successfully connected!")
            return

        self.print_error("Unable to connect, giving up!")

    def close_connections(self):
        """Disconnect every pooled server in parallel."""
        self._run_threads([(name, self.exit, (name,)) for name in list(self.connection_pool)])
        self.print_success("Bye, bye!")

    def exit(self, name):
        """Send ``exit`` to server *name* if still active, then close its client."""
        connection = self.connection_pool[name]
        try:
            transport = connection.get_transport()
            if transport and transport.is_active():
                connection.exec_command('exit', timeout=self.config.getint('timeout', 'exit_cmd'))
        except Exception as e:
            # Best effort: the client is closed regardless.
            self.print_exception(e)

        connection.close()
        self.print_action("Disconnected!")

    def listen(self):
        """Read and dispatch commands until ``exit``, end of input or Ctrl-C."""
        while True:
            try:
                cmd = input('master-ssh$ ').strip()
            except (EOFError, KeyboardInterrupt):
                # Treat Ctrl-D / Ctrl-C like "exit" so connections are closed cleanly.
                print()
                self.close_connections()
                break

            if cmd == "":
                self.print_error('Please enter your command!')
            elif cmd == "exit":
                self.close_connections()
                break
            elif not cmd.startswith('#'):
                self._broadcast(cmd)
            elif cmd in ("#help", "#"):
                self._print_help()
            elif cmd == "#list-connections":
                self._list_connections()
            elif cmd.startswith('#connect:'):
                self._connect_command(cmd)
            elif cmd.startswith('#disconnect:'):
                self._disconnect_command(cmd)
            else:
                self._server_command(cmd)

    def _print_help(self):
        """Print usage information for the internal ``#`` commands."""
        print("Welcome to master-ssh!")
        print("Made by Jaan Porkon <jaantrill[at]gmail.com>")
        print("https://github.com/JaanPorkon/master-ssh")
        print("")
        print("Example syntax:")
        print("#[command]")
        print("")
        print("Example command:")
        print("#help")
        print("")
        print("Below you can find available commands that you can use:")
        print("")
        print("help                                        - Displays this message")
        print("list-connections                            - Lists the state of all connections")
        print("disconnect:[hostname,hostname2,...]         - Disconnects a specific host and removes it from the pool")
        print("                                            - Example:")
        print("                                              #disconnect:host1")
        print("                                            - or multiple servers:")
        print("                                              #disconnect:host2,host3")
        print("ignore:[hostname,hostname2] [command]       - Ignores servers that are listed")
        print("[hostname,hostname2,...] [command]          - Runs a command only on the listed servers")
        print("connect:[host]:[username]:[password]:[name] - Connects to a server. [name] is optional, but suggested")

    def _list_connections(self):
        """Log whether each pooled connection's transport is still active."""
        for name, connection in self.connection_pool.items():
            # get_transport() returns None once the client has been closed.
            transport = connection.get_transport()
            if transport is not None and transport.is_active():
                self.print_action("%s is active!" % name)
            else:
                self.print_server_error("%s is disconnected!" % name)

    def _connect_command(self, cmd):
        """Handle ``#connect:host:username:password[:name]``; *name* defaults to *host*."""
        data = cmd[len('#connect:'):].split(':')

        if len(data) == 3:
            host, username, password = data
            name = host
        elif len(data) == 4:
            host, username, password, name = data
        else:
            self.print_error("Not enough arguments!")
            return

        self.connect(name, host, username, password, self.config.getint('connection', 'port'))

    def _disconnect_command(self, cmd):
        """Handle ``#disconnect:name[,name2,...]``.

        Sends ``exit`` to each active server, closes it and removes it from the
        pool. Exits the program when the pool becomes empty.
        """
        disconnected = False
        servers = cmd.split(':')[-1]

        for server in (name.strip() for name in servers.split(',')):
            if server not in self.connection_pool:
                self.print_error(server + " does not exist!")
                continue

            connection = self.connection_pool[server]
            transport = connection.get_transport()

            try:
                if transport is not None and transport.is_active():
                    connection.exec_command('exit', timeout=self.config.getint('timeout', 'exit_cmd'))
                    connection.close()
                    self.connection_pool.pop(server)
                    self.print_success(server + " successfully disconnected!")
                else:
                    # Already dead: still release the client's resources.
                    connection.close()
                    self.connection_pool.pop(server)

                disconnected = True
            except Exception as e:
                self.print_error('Unable to disconnect: ' + server + ' (' + str(e) + ')')

        if not self.connection_pool and disconnected:
            self.print_message("Connection pool is empty, closing the program..")
            sys.exit(0)

    @staticmethod
    def _split_server_command(cmd):
        """Split ``#a,b command`` or ``#ignore:a,b command``.

        Returns ``(ignore, names, command)``, where *ignore* says whether the
        ``#ignore:`` prefix was used and *names* are the server names before the
        first space. Returns ``None`` when no command follows the names.
        """
        head, separator, command = cmd.partition(' ')
        command = command.strip()
        if not separator or not command:
            return None

        ignore = head.startswith('#ignore:')
        names = head[len('#ignore:'):] if ignore else head[1:]
        return ignore, [name.strip() for name in names.split(',')], command

    def _server_command(self, cmd):
        """Run a command on the listed servers, or on all but the ignored ones."""
        parsed = self._split_server_command(cmd)
        if parsed is None:
            self.print_error("Unknown command or missing server command! See #help.")
            return

        ignore, names, command = parsed

        if ignore:
            jobs = [
                (name, self.execute, (command, name, connection))
                for name, connection in self.connection_pool.items()
                if name not in names
            ]
        else:
            jobs = []
            for name in names:
                if name not in self.connection_pool:
                    self.print_error("%s does not exist!" % name)
                else:
                    jobs.append((name, self.execute, (command, name, self.connection_pool[name])))

        self._run_threads(jobs)

    def _broadcast(self, cmd):
        """Run *cmd* on every pooled server in parallel."""
        if not self.connection_pool:
            self.print_error('There are no active connections!')
            return

        self._run_threads([
            (name, self.execute, (cmd, name, connection))
            for name, connection in self.connection_pool.items()
        ])

    def execute(self, cmd, name, connection):
        """Run *cmd* on *connection* and log its non-empty stderr and stdout."""
        try:
            _stdin, stdout, stderr = connection.exec_command(cmd)
        except Exception as e:
            self.print_error("Failed to execute command!")
            self.print_exception(e)
            return

        # paramiko's file objects yield lines that keep their trailing newline,
        # so join them as-is (adding another "\n" double-spaced the output).
        error = "".join(stderr).strip()
        if error:
            self.print_server_error(error)

        response = "".join(stdout).strip()
        if response:
            self.print_server_action(response)

    def print_message(self, message):
        """Log *message* at INFO in green."""
        logger.info(colored(message, color='green'))

    def print_action(self, message):
        """Log *message* at INFO in green."""
        logger.info(colored(message, color='green'))

    def print_server_action(self, message):
        """Log server output at INFO in green, starting on its own line."""
        logger.info('\n%s' % colored(message, color='green'))

    def print_success(self, message):
        """Log *message* at INFO in green."""
        logger.info(colored(message, color='green'))

    def print_error(self, message):
        """Log a tool error at ERROR, bold on red and prefixed with ``master-ssh:``."""
        logger.error('master-ssh: %s' % colored(message, on_color='on_red', attrs=['bold']))

    def print_server_error(self, message):
        """Log server-side error output at ERROR, bold on red."""
        logger.error(colored(message, on_color='on_red', attrs=['bold']))

    def print_exception(self, err):
        """Log *err* with its traceback; call from inside an ``except`` block."""
        logger.exception(err)

    def parse_config(self, args):
        """Load ``config.ini`` from this script's directory and apply ``--port``.

        Calls ``sys.exit(1)`` when the config file cannot be read.
        """
        self.config = configparser.ConfigParser()
        config_path = pathlib.Path(__file__).parent.absolute() / 'config.ini'

        # read() returns the list of files it parsed successfully.
        if not self.config.read(str(config_path)):
            self.print_error("%s does not exist!" % config_path)
            sys.exit(1)

        if args.port:
            # ConfigParser.set only accepts string values.
            self.config.set('connection', 'port', str(args.port))

    def init_logs(self):
        """Send paramiko's log output to the ``[log] ssh_client`` file."""
        log_file = self.config.get('log', 'ssh_client')

        if not os.path.isfile(log_file):
            open(log_file, 'w').close()

        paramiko.util.log_to_file(log_file)
