import os
import tempfile
import subprocess
import copy
import json
import logging
import shutil

# Module-level logger used by every AnsiblePlaybook instance in this file.
LOGGER = logging.getLogger(name='playbook_runner')


class AnsiblePlaybook(object):
    """Run ansible-playbook commands and collect the per-host JSON they emit.

    Every instance owns a temporary output directory. Its path is passed to
    each playbook as the ``output_path`` extra var, and ``run_playbook``
    seeds ``<output_path>/<host>.json`` with ``[`` for every host it targets.
    The playbooks are presumably expected to append entries to those files
    (TODO confirm against the playbooks); ``get_output`` closes the JSON
    arrays and parses them.
    """

    def __init__(
            self,
            ansible_playbook_inventory,
            ansible_playbook_directory,
            cleanup_output=False):
        """
        :param ansible_playbook_inventory: inventory passed to ``-i``. The
            ansible-playbook subprocesses run with
            ``ansible_playbook_directory`` as their working directory, so a
            relative inventory path is resolved against that directory.
        :param ansible_playbook_directory: directory holding the playbooks.
        :param cleanup_output: if True, the temporary output directory is
            deleted when this object is garbage-collected.
        """
        self._ansible_playbook_inventory = ansible_playbook_inventory
        self._ansible_playbook_directory = ansible_playbook_directory

        # NOTE: this mutates the environment of the whole Python process, so
        # every subprocess started afterwards inherits these settings.
        os.environ['ANSIBLE_SSH_RETRIES'] = '15'
        os.environ['ANSIBLE_INVENTORY_UNPARSED_FAILED'] = 'true'

        self._path_str = tempfile.mkdtemp(prefix='play_book_runner_')

        # Every host targeted by any run_playbook call so far.
        self._hosts = set()
        self._cleanup_output = cleanup_output

        LOGGER.info('Output path: %s', self._path_str)
        LOGGER.info('CWD: %s', self._ansible_playbook_directory)

    def __del__(self):
        """Remove the output directory if ``cleanup_output`` was requested."""
        # getattr: if __init__ raised early, __del__ still runs on the
        # partially-initialised object and these attributes may be missing.
        path = getattr(self, '_path_str', None)
        if path is None or not getattr(self, '_cleanup_output', False):
            return
        try:
            shutil.rmtree(path)
        except OSError as e:
            LOGGER.error('Error removing: {0} ({1})'.format(path, e))

    @staticmethod
    def _get_ansible_cmd(inventory_file, playbook_file,
                         extra_vars_dict=None):
        """
        Return process args list for ansible-playbook run.

        :param inventory_file: value passed to ``-i``.
        :param playbook_file: playbook path; always the last argument.
        :param extra_vars_dict: optional mapping rendered as one
            ``--extra-vars`` argument of the form ``k1="v1" k2="v2"``. Every
            value is formatted with ``str.format`` and double-quoted; Ansible
            treats key=value extra vars as strings (``False`` arrives as the
            string ``"False"``).
        :return: argument list suitable for ``subprocess.run``.
        """
        ansible_command = [
            "ansible-playbook",
            "-vv",
            "--forks", "50",
            "--timeout", "60",
            "-i", inventory_file,
        ]
        if extra_vars_dict:
            # NOTE(review): values containing '"' are not escaped and will
            # break the key=value parsing on the Ansible side.
            extra_vars = ' '.join(
                '{}="{}"'.format(k, v) for k, v in extra_vars_dict.items()
            )
            ansible_command += ['--extra-vars', extra_vars]
        ansible_command.append(playbook_file)
        return ansible_command

    @staticmethod
    def _parse_list_hosts_output(list_hosts_output, group):
        """
        Extract host names from ``ansible-playbook --list-hosts`` output.

        Every whitespace-separated token after the first ``hosts`` token is
        taken as a host name, except the ``(N):`` count token that follows
        ``hosts``. NOTE(review): with several plays in one playbook, tokens
        of later play headers would also be returned.

        :param list_hosts_output: decoded stdout of the listing command.
        :param group: group name, used only in the error message.
        :return: list of host names (may be empty).
        :raises RuntimeError: if no ``hosts`` token is present.
        """
        tokens = list_hosts_output.split()
        try:
            first = tokens.index('hosts') + 1
        except ValueError:
            raise RuntimeError(
                'No hosts were found for group: {0}'.format(group)
            ) from None

        # The header reads "hosts (N):"; drop the "(N):" count token.
        return [
            token for token in tokens[first:]
            if not (token.startswith('(') and token.endswith('):')
                    and token[1:-2].isdigit())
        ]

    def _get_hosts_by_group(self, inventory, group, some_playbook):
        """
        List the hosts ``some_playbook`` would target for ``group``.

        Runs ``ansible-playbook --list-hosts`` with the extra var
        ``play_host_groups=<group>`` in the playbook directory.

        :return: list of host names.
        :raises RuntimeError: if the command exits non-zero or its output has
            no ``hosts`` section.
        :raises subprocess.TimeoutExpired: if listing takes over 60 seconds.
        """
        get_hosts_by_group_cmd = [
            "ansible-playbook",
            "--timeout",
            "60",
            '--list-hosts',
            "-i",
            inventory,
            some_playbook,
            '-e',
            '{0}={1}'.format('play_host_groups', group)
        ]

        result = subprocess.run(
            get_hosts_by_group_cmd,
            cwd=self._ansible_playbook_directory,
            timeout=60,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE
        )

        if result.returncode != 0:
            raise RuntimeError(
                'Failed to run list-hosts. stdout: {0}, stderr: {1}'.format(
                    result.stdout,
                    result.stderr
                )
            )

        return self._parse_list_hosts_output(
            result.stdout.decode('utf8'), group)

    def run_playbook(self, play_filename, extra_vars_dict=None, timeout=120):
        """
        Run a playbook and return the ansible-playbook exit code.

        :param play_filename: playbook file name, relative to the playbook
            directory given to the constructor.
        :param extra_vars_dict: optional extra vars; the caller's dict is not
            modified. ``play_host_groups`` defaults to ``'localhost'``,
            ``skip_errors`` and ``gather_facts_for_pb`` default to ``False``,
            and ``output_path`` is always set to this runner's output
            directory.
        :param timeout: seconds to wait for ansible-playbook before
            ``subprocess.TimeoutExpired`` is raised (default 120).
        :return: the ansible-playbook return code. A non-zero code is logged
            unless ``skip_errors`` is truthy; it is never raised.
        :raises RuntimeError: if listing the target hosts fails.
        """
        # Work on a copy so neither the caller's dict nor a default is mutated.
        extra_vars = dict(extra_vars_dict or {})
        extra_vars.setdefault('play_host_groups', 'localhost')
        extra_vars['output_path'] = self._path_str
        extra_vars.setdefault('skip_errors', False)
        extra_vars.setdefault('gather_facts_for_pb', False)

        # Absolute path: the subprocess runs with cwd set to the playbook
        # directory, so "<relative dir>/<file>" would resolve to
        # "<dir>/<dir>/<file>".
        playbook_path = os.path.join(
            os.path.abspath(self._ansible_playbook_directory), play_filename)

        cmd = self._get_ansible_cmd(
            self._ansible_playbook_inventory,
            playbook_path,
            extra_vars_dict=extra_vars)

        hosts = self._get_hosts_by_group(
            self._ansible_playbook_inventory,
            extra_vars['play_host_groups'],
            playbook_path,
        )
        # update(), not add(*hosts): set.add accepts exactly one element.
        self._hosts.update(hosts)

        # Seed each host's output file with the opening bracket of a JSON
        # array; existing files are left alone so output accumulates.
        for host in self._hosts:
            file_path = os.path.join(self._path_str, '{0}.json'.format(host))
            if not os.path.isfile(file_path):
                with open(file_path, 'w') as f:
                    f.write('[')

        LOGGER.info('Command is about to be run:\n%s', ' '.join(cmd))
        # Combined stdout/stderr of the latest run (overwritten every run).
        ansible_output_path = os.path.join(
            self._path_str, 'ansible_output_path.txt')
        with open(ansible_output_path, "w") as f_ansible_output_path:
            result = subprocess.run(
                cmd,
                cwd=self._ansible_playbook_directory,
                timeout=timeout,
                stdout=f_ansible_output_path,
                stderr=subprocess.STDOUT,
            )

        if result.returncode != 0 and not extra_vars.get('skip_errors'):
            LOGGER.error('Failed to run: %s', cmd)

        return result.returncode

    def get_output(self):
        """
        Parse the per-host JSON output files written by the playbooks.

        Each ``<output_path>/<host>.json`` starts with the ``[`` written by
        ``run_playbook`` followed by whatever the playbooks appended. The
        last three characters of the final line are dropped and ``]`` is
        appended to close the array. Presumably this removes a trailing
        separator the playbooks write after each entry (TODO confirm against
        the playbooks).

        The files are parsed in memory and left untouched, so this method
        may be called repeatedly and later ``run_playbook`` calls can keep
        appending.

        :return: dict mapping host -> list holding the first element of each
            entry in that host's array. Hosts whose file contains nothing
            beyond the opening ``[`` are omitted.
        :raises ValueError: (``json.JSONDecodeError``) if a file's content
            does not form valid JSON after the transformation above.
        """
        host_output_dict = {}
        for host in self._hosts:
            file_path = os.path.join(self._path_str, '{0}.json'.format(host))
            with open(file_path) as f:
                lines = f.readlines()

            # Only if the playbook has written something beyond the seed '['.
            if ''.join(lines).strip() in ('', '['):
                continue

            lines[-1] = lines[-1][:-3]
            entries = json.loads(''.join(lines) + ']')
            host_output_dict[host] = [entry[0] for entry in entries]

        return host_output_dict
