#!/usr/bin/env python3

import os
import sys
import docker
import subprocess
from enum import Enum
from typing import List, Tuple
from docker import DockerClient

"""
<sphinx>

tls-docker-network
------------------

Automated TLS certificate verification for docker-based flows like docker-gen.
Scans the list of Docker containers, looking for a label or environment variable that contains a domain name
(or several comma-separated domain names).

Parameters:

- parameter_type: Label or environment variable
- parameter_name: Name of the label or environment variable
- alert_days_before: Number of days before expiration date to start alerting (defaults to 3)
- docker_host: (Optional) The URL to the Docker host.
- docker_tls_verify: (Optional) Verify the host against a CA certificate.
- docker_cert_path: (Optional) A path to a directory containing TLS certificates to use when connecting to the Docker host
- debug: (Optional) Debugging mode

</sphinx>
"""


class ParamTypes(Enum):
    """
    Places in a container definition where the domain list can be declared.

    The member values are the strings compared against the configured parameter type.
    """

    # Domains are read from one of the container's environment variables
    ENVIRONMENT = 'environment'

    # Domains are read from one of the container's labels
    LABEL = 'label'


class TlsDockerNetworkCheck(object):
    """
    Runs the external ``tls`` check for every domain declared by the listed Docker containers.

    Domains are read from a container label or from a container environment variable
    (selected by ``param_type`` / ``param_name``). The value may hold several
    comma-separated domain names.
    """

    # Port paired with every discovered domain
    DEFAULT_PORT = 443

    param_type: str
    param_name: str
    alert_days_before: int
    client: DockerClient

    def __init__(self, param_type: str, param_name: str, alert_days_before: int):
        """
        :param param_type: One of the ParamTypes values ('environment' or 'label')
        :param param_name: Name of the label or environment variable holding the domains
        :param alert_days_before: Passed to the ``tls`` check as ALERT_DAYS_BEFORE
        """

        self.param_type = param_type
        self.param_name = param_name
        self.alert_days_before = alert_days_before
        self.client = docker.from_env()

    def main(self) -> Tuple[str, bool]:
        """
        Run a check for each domain

        :return: Tuple of (report text, success flag). The flag is False when at least one
                 check failed or when no domain was found at all.
        """

        domains = self.find_domains_to_check()
        checks_dir = os.path.dirname(os.path.realpath(__file__))

        # The directory of this script goes first on PATH (presumably so that a sibling `tls`
        # check is picked up). PATH may be unset, in which case only checks_dir is used.
        search_path = os.pathsep.join(filter(None, (checks_dir, os.getenv('PATH'))))

        has_at_least_one_failure = False
        messages = []

        for domain, port in domains:
            try:
                subprocess.check_output(['tls'], env={
                    'DOMAIN': domain,
                    'PORT': str(port),
                    'ALERT_DAYS_BEFORE': str(self.alert_days_before),
                    'PATH': search_path
                })

            except subprocess.CalledProcessError as e:
                # errors='replace' keeps the report going even if the check printed non-UTF-8 bytes
                details = e.output.decode('utf-8', errors='replace')
                messages.append('Domain "{}" check failure: "{}"'.format(domain, details))
                has_at_least_one_failure = True

            else:
                messages.append('Domain {} is OK'.format(domain))

        if not messages:
            return 'No any domains found, maybe the containers are not running?', False

        return ' | \n'.join(messages), not has_at_least_one_failure

    def find_domains_to_check(self) -> List[Tuple[str, int]]:
        """
        Find containers matching given labels/environment variables

        Containers that lack the expected key or label are skipped (reported only in debug mode).
        An unrecognized ``param_type`` yields no domains.

        :return: List of (domain, port) tuples, the port always being DEFAULT_PORT
        """

        domains = []

        for container in self.client.containers.list():
            try:
                raw_values = self._find_raw_values(container)

            except KeyError as err:
                if self.is_debug_mode():
                    print('KeyError: {} for container {}'.format(str(err), container.id))

                continue

            for raw_value in raw_values:
                domains.extend(self._parse_domains(raw_value, container))

        return domains

    def _find_raw_values(self, container) -> List[str]:
        """
        Collect the raw (still comma-separated) values of the configured parameter for one container

        :raises KeyError: When the container attributes lack an expected key or the label is not set
        """

        if self.param_type == ParamTypes.ENVIRONMENT.value:
            # "Env" may be None instead of a list (presumably when nothing is defined) - treat it as empty.
            # partition() also copes with an entry that has no '=' at all: its value is then ''.
            entries = (entry.partition('=') for entry in container.attrs['Config']['Env'] or [])

            return [value for name, _, value in entries if name == self.param_name]

        if self.param_type == ParamTypes.LABEL.value:
            # Same None-guard as for "Env"; a missing label still ends up as a KeyError
            labels = container.attrs['Config']['Labels'] or {}

            return [labels[self.param_name]]

        return []

    def _parse_domains(self, raw_value: str, container) -> List[Tuple[str, int]]:
        """
        Split a comma-separated value into (domain, port) tuples, warning about empty entries
        """

        parsed = []

        for domain in raw_value.split(','):
            domain = self.purify_domain_name(domain)

            if not domain:
                print(f' Warning: Empty domain in {container} ')
                continue

            parsed.append((domain, self.DEFAULT_PORT))

        return parsed

    @staticmethod
    def is_debug_mode() -> bool:
        """
        Tell whether the DEBUG environment variable is set to 'true' (case-insensitive)
        """

        return os.getenv('DEBUG', 'false').lower() == 'true'

    @staticmethod
    def purify_domain_name(domain: str) -> str:
        """
        Strip quotes and blank characters - who knows what docker daemon returns, and what were defined in containers
        """

        # Whitespace is stripped both before and after the quotes, so that values
        # such as ' "example.org" ' and '" example.org "' are both cleaned up
        return domain.strip().strip('"\'').strip()


if __name__ == '__main__':
    # The whole configuration comes from environment variables; the second
    # argument of every getenv() call is the value used when the variable is unset
    configured_type = os.getenv('PARAM_TYPE', 'environment')
    configured_name = os.getenv('PARAM_NAME', 'LETSENCRYPT_HOST')
    days_before = int(os.getenv('ALERT_DAYS_BEFORE', '3'))

    check = TlsDockerNetworkCheck(
        param_type=configured_type,
        param_name=configured_name,
        alert_days_before=days_before
    )

    report, succeeded = check.main()
    print(report)

    # Exit status 0 when main() reported success, 1 otherwise
    if succeeded:
        sys.exit(0)

    sys.exit(1)
