#!/usr/bin/env python3

import sys
import os
import time
import json
import glob
import docker
import threading
import argparse
from argparse import ArgumentTypeError as err
from base64 import b64decode
import watchdog.events
import watchdog.observers
import time
from pathlib import Path
import logging

# Docker label that opts a container in to being restarted; its value is a comma separated
# list of domains (see DockerManager.restartLabeledContainers).
DOCKER_LABLE = "com.github.ravensorb.traefik-certificate-exporter.domain-restart"

###########################################################################################################
###########################################################################################################
# Built-in default settings. The command line (and an optional JSON config file) adjust these at startup.
settings = dict(
    dataPath="./",            # directory that holds the acme json files
    fileSpec="*.json",        # glob pattern of the acme files inside dataPath
    outputPath="./certs",     # where exported certificates are written
    traefikResolverId=None,   # only export this resolver's certificates (None = detect resolvers)
    resolverInPathName=True,  # add the resolver name as an output sub-directory
    flat=False,               # write <domain>.key/.crt/.chain.pem instead of one directory per domain
    dryRun=False,             # log what would be done without writing files
    restartContainers=False,  # restart containers labelled with DOCKER_LABLE for exported domains
    domains=dict(
        include=[],           # when non-empty, only export these domains
        exclude=[],           # when non-empty, never export these domains
    ),
)

###########################################################################################################
class AcmeCertificateExporter:
    """Export the certificates stored in Traefik acme json files as PEM files.

    Three layouts of the acme data are handled:

    * ACME v1: certificates under ``DomainsCertificate.Certs``;
    * ACME v2 with capitalised keys (``Certificates[].Domain.Main``, ``Key``, ``Certificate``);
    * ACME v2 with lower-case keys nested under one node per resolver
      (``<resolver>.Certificates[].domain.main``, ``key``, ``certificate``).

    The ``settings`` dict supplies ``dataPath``/``fileSpec`` (which files to read),
    ``outputPath``/``flat``/``resolverInPathName`` (where to write), ``dryRun``,
    ``traefikResolverId`` and the ``domains`` include/exclude lists.
    """

    def __init__(self, settings : dict):
        self.__settings = settings

    def __exportCertificate(self, data : dict, resolverName : str = None, keys : str = "lowercase") -> list:
        """Export every certificate of one ACME account node.

        :param data: an ACME node containing ``Account`` and its certificates.
        :param resolverName: name of the resolver the node belongs to, if any; used as an
            extra output sub-directory when ``resolverInPathName`` is enabled.
        :param keys: ``"uppercase"`` or ``"lowercase"``; key casing of ACME v2 entries.
        :return: the exported domain names (a leading ``*.`` stripped), in file order.
        """
        names = []

        # Determine the ACME version from the account registration URI.
        try:
            registrationUri = data['Account']['Registration']['uri']
        except (KeyError, TypeError):
            # e.g. a node whose Account is missing/null: nothing can be exported from it.
            logging.warning("No ACME account registration found (resolver: {}). Skipping".format(resolverName))
            return names
        acme_version = 2 if 'acme-v02' in (registrationUri or '') else 1

        # Tolerate a missing or null certificate list (e.g. nothing issued yet).
        if acme_version == 1:
            certs = (data.get('DomainsCertificate') or {}).get('Certs') or []
        else:
            certs = data.get('Certificates') or []

        domainFilter = self.__settings.get("domains") or {}
        include = domainFilter.get("include")
        exclude = domainFilter.get("exclude")

        for c in certs:
            if acme_version == 1:
                name = c['Certificate']['Domain']
                privatekey = c['Certificate']['PrivateKey']
                fullchain = c['Certificate']['Certificate']
                sans = c['Domains']['SANs']
            elif keys == "uppercase":
                name = c['Domain']['Main']
                privatekey = c['Key']
                fullchain = c['Certificate']
                sans = c['Domain'].get('SANs')
            else:
                name = c['domain']['main']
                privatekey = c['key']
                fullchain = c['certificate']
                sans = c['domain'].get('sans')

            # Wildcard certificates are exported under their base domain name.
            if name.startswith("*."):
                name = name[2:]

            if (include and name not in include) or (exclude and name in exclude):
                continue

            # Decode private key and certificate bundle
            privatekey = b64decode(privatekey).decode('utf-8')
            fullchain = b64decode(fullchain).decode('utf-8')

            # The first PEM block is the leaf certificate, everything after it the issuer chain.
            # Searching from index 1 skips the marker that opens the leaf itself.
            start = fullchain.find('-----BEGIN CERTIFICATE-----', 1)
            if start == -1:
                # A lone certificate has no chain (slicing with -1 would corrupt both parts).
                cert, chain = fullchain, ''
            else:
                cert, chain = fullchain[:start], fullchain[start:]

            if not self.__settings["dryRun"]:
                self.__writeCertificateFiles(name, resolverName, privatekey, cert, chain, fullchain)

            logging.info("Extracted certificate for: {} ({})".format(name, ', '.join(sans) if sans else ''))

            names.append(name)

        return names

    def __writeCertificateFiles(self, name : str, resolverName : str, privatekey : str, cert : str, chain : str, fullchain : str):
        """Write the PEM files of one domain below ``outputPath``.

        Flat layout: ``<outputPath>[/<resolver>]/<name>.key|.crt|.chain.pem`` (``.crt`` holds the
        full chain). Otherwise: ``<outputPath>[/<resolver>]/<name>/privkey|cert|chain|fullchain.pem``.
        """
        directory = Path(self.__settings["outputPath"])
        if self.__settings.get("resolverInPathName") and resolverName:
            directory = directory / resolverName

        if self.__settings["flat"]:
            files = {
                str(name) + '.key': privatekey,
                str(name) + '.crt': fullchain,
                str(name) + '.chain.pem': chain,
            }
        else:
            directory = directory / name
            files = {
                'privkey.pem': privatekey,
                'cert.pem': cert,
                'chain.pem': chain,
                'fullchain.pem': fullchain,
            }

        directory.mkdir(parents=True, exist_ok=True)
        for fileName, content in files.items():
            with (directory / fileName).open('w') as f:
                f.write(content)

    def exportCertificatesForFile(self, sourceFile : str) -> list:
        """Export the certificates of one acme json file.

        When ``traefikResolverId`` is set only that resolver node is exported; otherwise,
        if the first top-level node contains an ``Account`` every top-level node is treated
        as a resolver, else the whole file is treated as a single ACME node.

        :param sourceFile: path of the acme json file.
        :return: flat list of exported domain names; empty when the file is empty/invalid,
            the configured resolver is missing, or nothing matched.
        """
        try:
            with open(sourceFile) as f:
                data = json.load(f)
        except ValueError:
            # Empty or partially written file (can happen while it is being written).
            logging.warning("File '{}' does not contain valid json. Skipping file".format(sourceFile))
            return []

        if not isinstance(data, dict):
            logging.warning("File '{}' does not contain an acme json object. Skipping file".format(sourceFile))
            return []

        resolversToProcess = []
        keys = "uppercase"
        resolverId = self.__settings["traefikResolverId"]
        if resolverId:
            if resolverId not in data:
                logging.warning("Specified traefik resolver id '{}' is not found in acme file '{}'. Skipping file".format(resolverId, sourceFile))
                return []
            resolversToProcess.append(resolverId)
            keys = "lowercase"
        elif data:
            # Resolver-keyed files hold one node per resolver, each with its own Account.
            firstName = next(iter(data))
            logging.debug("[DEBUG] Checking node '{}' to see if it is a resolver node".format(firstName))
            firstNode = data[firstName]
            if isinstance(firstNode, dict) and "Account" in firstNode:
                resolversToProcess = list(data)
                keys = "lowercase"

        names = []

        if resolversToProcess:
            logging.info("Resolvers to process: {}".format(resolversToProcess))

            for resolver in resolversToProcess:
                # extend (not append) so callers receive a flat list of domain names
                names.extend(self.__exportCertificate(data[resolver], resolverName=resolver, keys=keys))
        else:
            names = self.__exportCertificate(data, keys=keys)

        return names

    def exportCertificates(self) -> list:
        """Export certificates from every file matching ``dataPath``/``fileSpec``.

        :return: exported domain names without duplicates, in first-seen order.
        """
        processedDomains = []

        for fileName in glob.glob(os.path.join(self.__settings["dataPath"], self.__settings["fileSpec"])):
            for domain in self.exportCertificatesForFile(fileName) or []:
                if domain not in processedDomains:
                    processedDomains.append(domain)

        return processedDomains

###########################################################################################################
class DockerManager:
    """Restart docker containers whose DOCKER_LABLE label names an exported domain."""

    def __init__(self, settings : dict):
        self.__settings = settings

    @staticmethod
    def _labelDomains(value : str) -> set:
        """Split a comma separated label value into a set of trimmed, non-empty domain names."""
        return {domain.strip() for domain in (value or "").split(',') if domain.strip()}

    def restartLabeledContainers(self, domains : list):
        """Restart every labelled container whose label shares a domain with ``domains``.

        Does nothing unless ``restartContainers`` is enabled. With ``dryRun`` the matching
        containers are only logged. Docker errors are logged, never raised.

        :param domains: exported domain names (as returned by the exporter).
        """
        if not self.__settings["restartContainers"]:
            return

        if not domains:
            return

        try:
            exported = set(domains)
            client = docker.from_env()
            containers = client.containers.list(filters = {"label" : DOCKER_LABLE})
        except Exception:
            logging.exception("Failed restarting containers")
            return

        for c in containers:
            if exported.isdisjoint(self._labelDomains(c.labels.get(DOCKER_LABLE))):
                continue

            logging.info("Restarting container: {}".format(c.id))
            # The settings key is "dryRun"; reading "dry" raised a KeyError that aborted every restart.
            if self.__settings.get("dryRun"):
                continue

            # Restart each container independently so one failure does not skip the others.
            try:
                c.restart()
            except Exception:
                logging.exception("Failed restarting container: {}".format(c.id))

###########################################################################################################
class AcmeCertificateFileHandler(watchdog.events.PatternMatchingEventHandler):
    """Re-export certificates when a watched acme file is created or modified.

    Several file-system events can arrive for one change, so the first event arms a
    2 second timer and further events are ignored until that timer's work has finished.
    Containers are restarted afterwards when ``restartContainers`` is enabled.
    """

    def __init__(self, exporter : AcmeCertificateExporter, dockerManager : DockerManager, settings : dict):
        self.__exporter = exporter
        self.__dockerManager = dockerManager
        self.__settings = settings

        # True while a delayed export is scheduled or running; always accessed under self.lock.
        self.isWaiting = False
        self.lock = threading.Lock()
        self.timer = None

        # Only react to files matching the configured file spec
        super().__init__(patterns = [ self.__settings["fileSpec"] ],
                         ignore_directories = True,
                         case_sensitive = False)

    def on_created(self, event):
        logging.debug("Watchdog received created event - %s.", event.src_path)
        self.handleEvent(event)

    def on_modified(self, event):
        logging.debug("Watchdog received modified event - %s.", event.src_path)
        self.handleEvent(event)

    def handleEvent(self, event):
        """Schedule one delayed export for a burst of events on a matching file."""
        if event.is_directory:
            return

        logging.info("Certificates changed found in file: {}".format(event.src_path))

        with self.lock:
            if not self.isWaiting:
                self.isWaiting = True # trigger the work just once (multiple events get fired)
                self.timer = threading.Timer(2, self.doTheWork, args=[event])
                self.timer.start()

    def doTheWork(self, *args, **kwargs):
        '''
        Timer callback: export the file of the event in ``args[0]`` and restart containers.

        This is a workaround to handle multiple events for the same file. ``isWaiting`` is
        always reset (in ``finally``) so an export error cannot stop the watcher for good.
        '''
        logging.debug("[DEBUG] Starting the work")

        try:
            if not args:
                logging.error("No event passed to worker")
                return

            domains = self.__exporter.exportCertificatesForFile(args[0].src_path)

            if self.__settings["restartContainers"]:
                self.__dockerManager.restartLabeledContainers(domains)
        except Exception:
            # This runs on a timer thread: log the failure instead of letting the thread die silently.
            logging.exception("Failed exporting certificates")
            return
        finally:
            with self.lock:
                self.isWaiting = False

        logging.debug('[DEBUG] Finished')

###########################################################################################################
###########################################################################################################

if __name__ == "__main__":
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.DEBUG)

    logging.info("Traefik Certificate Exporter starting....")

    ###########################################################################################################
    parser = argparse.ArgumentParser(description="Extract traefik letsencrypt certificates.")

    parser.add_argument("-c", "--config-file", dest="configFile", default=None, type=str,
                                help="a JSON settings file; options given on the command line override its values (default: %(default)s)")
    parser.add_argument("-d", "--data-path", dest="dataPath", default=settings["dataPath"], type=str, 
                                help="the path that contains the acme json files (default: %(default)s)")
    parser.add_argument("-w", "--watch-for-changes", action="store_true", dest="watch",
                                help="If specified, monitor and watch for changes to acme files")
    parser.add_argument("-fs", "--file-spec", dest="fileSpec", default=settings["fileSpec"], type=str, 
                                help="file that contains the traefik certificates (default: %(default)s)")
    parser.add_argument("-o", "--output-path", dest="outputPath", default=settings["outputPath"], type=str, 
                                help="The folder to export the certificates into (default: %(default)s)")
    parser.add_argument("--traefik-resolver-id", dest="traefikResolverId", default=settings["traefikResolverId"],
                                help="Traefik certificate-resolver-id.")
    parser.add_argument("-f", "--flat", action="store_true", dest="flat",
                                help="If specified, write all certificates into a single folder")
    parser.add_argument("-r", "--restart_container", action="store_true", dest="restartContainer",
                                help="If specified, any container that is labeled with '" + DOCKER_LABLE + "=<DOMAIN>' will be restarted if the domain name of a generated certificate matches the value of the label. Multiple domains can be separated by ','")
    parser.add_argument("--dry-run", action="store_true", dest="dry", 
                                help="Don't write files and do not restart docker containers.")
    parser.add_argument("--include-resolvername-in-outputpath", action="store_true", dest="resolverInPathName", 
                                help="Add the resolver name to the path used to export the certificates.")

    group = parser.add_mutually_exclusive_group()
    group.add_argument("-id", "--include-domains", nargs="*", dest="includeDomains", default=None,
                                help="If specified, only certificates that match domains in this list will be extracted")
    group.add_argument("-xd", "--exclude-domains", nargs="*", dest="excludeDomains", default=None,
                                help="If specified, certificates that match domains in this list will be ignored")
    
    ###########################################################################################################

    args = parser.parse_args()

    # settings key -> argparse destination of the option that sets it
    cliOptions = {
        "dataPath": "dataPath",
        "fileSpec": "fileSpec",
        "outputPath": "outputPath",
        "traefikResolverId": "traefikResolverId",
        "resolverInPathName": "resolverInPathName",
        "flat": "flat",
        "restartContainers": "restartContainer",
        "dryRun": "dry",
    }

    # 1. Baseline: the command line defaults (without a config file this is exactly the old behaviour).
    for key, dest in cliOptions.items():
        settings[key] = parser.get_default(dest)

    # 2. A config file overrides the baseline. It is merged, not swapped in, so keys it does not
    #    mention (including the "domains" sub-keys) keep their defaults.
    if args.configFile:
        if os.path.exists(args.configFile):
            logging.info("Loading config file: {}".format(args.configFile))
            with open(args.configFile) as f:
                fileSettings = json.load(f)
            if not isinstance(fileSettings, dict):
                logging.error("Config file must contain a JSON object. Exiting...")
                sys.exit(-1)
            domainSettings = dict(settings["domains"])
            domainSettings.update(fileSettings.get("domains") or {})
            settings.update(fileSettings)
            settings["domains"] = domainSettings
        else:
            logging.warning("Config file '{}' does not exist. Ignoring it".format(args.configFile))

    # 3. Options given explicitly on the command line win over the config file.
    for key, dest in cliOptions.items():
        value = getattr(args, dest)
        if value != parser.get_default(dest):
            settings[key] = value

    if args.includeDomains:
        settings["domains"]["include"] = args.includeDomains
    if args.excludeDomains:
        settings["domains"]["exclude"] = args.excludeDomains

    # Lets validate the path we are being asked to watch actually exists
    if not os.path.exists(settings["dataPath"]):
        logging.error("Data Path does not exist. Exiting...")
        sys.exit(-1)

    logging.info("Data Path: {}".format(settings["dataPath"]))
    logging.info("File Spec: {}".format(settings["fileSpec"]))
    logging.info("Output Path: {}".format(settings["outputPath"]))

    exporter = AcmeCertificateExporter(settings=settings)
    dockerManager = DockerManager(settings=settings)

    if not args.watch:
        logging.info("Exporting certificates....")
        domainsProcessed = exporter.exportCertificates()
        if domainsProcessed and settings["restartContainers"]:
            dockerManager.restartLabeledContainers(domainsProcessed)
    else:
        logging.info("Watching for changes to files....")
        event_handler = AcmeCertificateFileHandler(exporter=exporter, 
                                                    dockerManager=dockerManager,
                                                    settings=settings)

        observer = watchdog.observers.Observer()
        observer.schedule(event_handler, path=settings["dataPath"], recursive=False)

        observer.start()
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            observer.stop()
        observer.join()

    logging.info("Traefik Certificate Exporter stopping....")
