#!/usr/bin/env python3

import abc
import argparse
import json
import logging
import logging.config
import logging.handlers
import os
import socket
import sys

import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
import dns.resolver
import dns.tsig
import dns.tsigkeyring
import dns.update
import netifaces


class Platform(metaclass=abc.ABCMeta):
    """
    Abstract description of one operating system's conventions.

    Concrete subclasses say how the OS is identified in os.name, where the
    program's files live by default, how to pick permanent IPv6 addresses
    out of an interface's list, and what one-off setup the OS needs.
    """

    @abc.abstractmethod
    def getName(self):
        """Give the value that os.name takes on this platform."""

    @abc.abstractmethod
    def getPermanentIPv6Addresses(self, addresses):
        """
        Select the permanent addresses from one interface's IPv6 addresses.

        A NIC usually carries one permanent IPv6 address alongside several
        temporary ones created by SLAAC privacy extensions (RFC 4941). Only
        the permanent address should be registered; the temporary ones must
        be left out.

        This has to be decided per platform: netifaces attaches no metadata
        saying whether an address is permanent or temporary, and the two
        kinds look alike. Individual operating systems do, however, seem to
        list permanent addresses consistently either before or after the
        temporary ones.
        """

    @abc.abstractmethod
    def getDefaultConfigFilename(self):
        """Give the path of the configuration file used when none is specified."""

    @abc.abstractmethod
    def getDefaultCacheFilename(self):
        """Give the path of the cache file used when the config just enables caching."""

    @abc.abstractmethod
    def platformSpecificSetup(self):
        """
        Perform whatever initialization this particular platform needs before
        the program proper starts.
        """


class POSIXPlatform(Platform):
    """Conventions for POSIX systems (os.name == "posix"), notably Linux."""

    def getName(self):
        return "posix"

    def getPermanentIPv6Addresses(self, addresses):
        # Linux lists the permanent address after any temporary ones, so
        # only the final entry is kept (or nothing, for an empty list).
        if not addresses:
            return []
        return [addresses[-1]]

    def getDefaultConfigFilename(self):
        return "/etc/pydyndns.conf"

    def getDefaultCacheFilename(self):
        return "/run/pydyndns.cache"

    def platformSpecificSetup(self):
        # No extra setup is needed on POSIX.
        return None


class WindowsPlatform(Platform):
    """Conventions for Windows (os.name == "nt")."""

    def getName(self):
        return "nt"

    def getPermanentIPv6Addresses(self, addresses):
        # Windows returns permanent addresses first.
        if len(addresses) == 0:
            return []
        return [addresses[0]]

    def getDefaultConfigFilename(self):
        # The configuration lives next to this script by default.
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), "pydyndns.conf")

    def getDefaultCacheFilename(self):
        """
        Return %LOCALAPPDATA%\\Temp\\pydyndns.cache.

        If LOCALAPPDATA is unset or empty, fall back to the conventional
        location under the user's home directory (~\\AppData\\Local).
        """
        localAppData = os.environ.get("LOCALAPPDATA")
        # An empty value would otherwise yield a path relative to the CWD.
        if not localAppData:
            # os.path.expanduser (not the non-existent os.path.expand)
            # resolves "~" to the user's profile directory.
            localAppData = os.path.join(os.path.expanduser("~"), "AppData", "Local")
        return os.path.join(localAppData, "Temp", "pydyndns.cache")

    def platformSpecificSetup(self):
        # Python's NTEventLogHandler class unconditionally tries to add the
        # event source to the Windows registry. This fails when running as a
        # low-privileged account. If somebody had already added the event
        # source then logging events using that source could work, but
        # NTEventLogHandler doesn't bother trying that, it just fails in
        # construction if the registration fails. Work around this by
        # swallowing pywin32 errors during source registration.
        try:
            import pywintypes
            import win32evtlogutil
        except ImportError:
            # Presumably the Win32 extensions are not installed, in which
            # case there is nothing to patch.
            return
        oldAddSourceToRegistry = win32evtlogutil.AddSourceToRegistry

        def replacement(*args, **kwargs):
            # Accept any call signature so callers other than
            # NTEventLogHandler (which passes three positional arguments)
            # keep working.
            try:
                oldAddSourceToRegistry(*args, **kwargs)
            except pywintypes.error:
                pass

        win32evtlogutil.AddSourceToRegistry = replacement


class UnknownPlatform(Platform):
    """Fallback used when os.name matches none of the known platforms."""

    def getName(self):
        return "unknown"

    def getPermanentIPv6Addresses(self, addresses):
        # The ordering convention on this platform is not known, so the
        # list is handed back untouched.
        return addresses

    def getDefaultConfigFilename(self):
        # A bare filename, i.e. relative to the current working directory.
        return "pydyndns.conf"

    def getDefaultCacheFilename(self):
        # A bare filename, i.e. relative to the current working directory.
        return "pydyndns.cache"

    def platformSpecificSetup(self):
        # No setup is known to be needed.
        return None


class Family(metaclass=abc.ABCMeta):
    """
    Abstract description of one address family (e.g. IPv4 or IPv6).

    Subclasses name the family, map it onto netifaces' address keys, add
    its records to a DNS update and decide which addresses to register.
    """

    @abc.abstractmethod
    def getName(self):
        """Give the name under which this family's addresses are cached."""

    @abc.abstractmethod
    def getNetIFacesConstant(self):
        """Give the numeric key netifaces uses for this family in its output."""

    @abc.abstractmethod
    def addAddressToUpdate(self, update, hostPart, ttl, address):
        """Append a record for one address of this family to a DNS update."""

    @abc.abstractmethod
    def filterAddressList(self, addresses):
        """
        Keep only the addresses worth registering, dropping loopback,
        link-local, temporary and other special-purpose addresses.
        """


class IPv4(Family):
    """The IPv4 address family (A records)."""

    def getName(self):
        return "ipv4"

    def getNetIFacesConstant(self):
        return netifaces.AF_INET

    def addAddressToUpdate(self, update, hostPart, ttl, address):
        update.add(hostPart, ttl, dns.rdtypes.IN.A.A(dns.rdataclass.IN, dns.rdatatype.A, address))

    def filterAddressList(self, addresses):
        # For IPv4 most NICs have only one address. It's not clear that there
        # are any specific rules about how multiple addresses ought to be
        # handled. Just include all of them that are acceptable.
        return [x for x in addresses if self.includeAddress(x)]

    def includeAddress(self, address):
        """
        Return whether a dotted-quad IPv4 address should be registered.

        Rejects "this network" (0/8), loopback (127/8), link-local
        (169.254/16), multicast (224/4) and reserved/broadcast (240/4)
        addresses, none of which identify the host to remote clients.
        """
        parts = [int(part) for part in address.split(".")]
        if parts[0] == 0:
            return False # "This network" address, including 0.0.0.0
        elif parts[0] == 127:
            return False # Loopback address
        elif parts[0] == 169 and parts[1] == 254:
            return False # Link-local (APIPA) address
        elif parts[0] >= 224:
            return False # Multicast (224-239) or reserved/broadcast (240-255) address
        return True


class IPv6(Family):
    """
    The IPv6 address family (AAAA records).

    platform -- the Platform used to pick permanent addresses out of each
                interface's address list
    config -- the "ipv6" section of the configuration file; its optional
              "teredo" key says whether Teredo addresses may be registered
              (treated as false when absent)
    """

    def __init__(self, platform, config):
        self._platform = platform
        self._config = config

    def getName(self):
        return "ipv6"

    def getNetIFacesConstant(self):
        return netifaces.AF_INET6

    def addAddressToUpdate(self, update, hostPart, ttl, address):
        update.add(hostPart, ttl, dns.rdtypes.IN.AAAA.AAAA(dns.rdataclass.IN, dns.rdatatype.AAAA, address))

    def filterAddressList(self, addresses):
        # Drop special-purpose addresses first, then let the platform pick the
        # permanent address(es) out of what remains.
        return self._platform.getPermanentIPv6Addresses([x for x in addresses if self.includeAddress(x)])

    def includeAddress(self, address):
        """Return whether a textual IPv6 address should be registered."""
        # Strip any "%zone" suffix some platforms append to scoped addresses.
        words = address.split("%", 1)[0].split(":")
        # Only the first two 16-bit groups matter. An empty group comes from
        # "::" compression and is zero, e.g. "::1".split(":") == ["", "", "1"].
        firstWord = int(words[0] or "0", 16)
        secondWord = int(words[1] or "0", 16)
        if firstWord == 0x0000:
            return False # Unspecified, loopback, or IPv4-mapped address
        elif firstWord == 0x0100:
            return False # Discard address
        elif (firstWord & 0xFE00) == 0xFC00:
            return False # Unique local address (fc00::/7)
        elif (firstWord & 0xFFC0) == 0xFE80:
            return False # Link-local address (the whole of fe80::/10)
        elif (firstWord & 0xFF00) == 0xFF00:
            return False # Multicast address (ff00::/8)
        elif firstWord == 0x2001 and secondWord == 0x0000 and not self._config.get("teredo", False):
            return False # Teredo address (2001:0::/32)
        return True


def _chooseCacheFile(platform, config):
    """
    Return the cache file path selected by config["cache"], or None.

    A string is used verbatim as the path; a value equal to True selects the
    platform's default; anything else disables caching.
    """
    cacheSetting = config["cache"]
    if isinstance(cacheSetting, str):
        return cacheSetting
    # "== True" rather than "is True" so that 1 keeps selecting the default.
    if cacheSetting == True:  # noqa: E712
        return platform.getDefaultCacheFilename()
    return None


def _loadCache(cacheFile, logger):
    """
    Return the dict stored in cacheFile, or None if there is no usable cache.

    A missing file is normal (first run, or just wiped). An unreadable,
    malformed or non-object file is logged and ignored, so it merely forces an
    update instead of aborting the run.
    """
    if cacheFile is None:
        return None
    try:
        with open(cacheFile, "r", encoding="utf-8") as fp:
            cache = json.load(fp)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as exp:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("Ignoring unreadable cache file {}: {}".format(cacheFile, exp))
        return None
    if not isinstance(cache, dict):
        logger.warning("Ignoring malformed cache file {}.".format(cacheFile))
        return None
    return cache


def _saveCache(cacheFile, hostname, addresses):
    """Record the hostname and addresses that were just sent in cacheFile."""
    with open(cacheFile, "w", encoding="utf-8") as fp:
        json.dump({"hostname": hostname, "addresses": addresses}, fp, ensure_ascii=False, allow_nan=False)


def _findAddresses(families, interfaces):
    """
    Return {family name: sorted address list} for the given interfaces.

    Each interface's addresses are filtered separately through the family's
    filterAddressList, so per-list rules (such as picking a permanent IPv6
    address) apply to each interface on its own.
    """
    addresses = {family.getName(): [] for family in families}
    for interface in interfaces:
        ifAddresses = netifaces.ifaddresses(interface)
        for family in families:
            entries = ifAddresses.get(family.getNetIFacesConstant(), [])
            addresses[family.getName()] += family.filterAddressList([entry["addr"] for entry in entries])
    for addressList in addresses.values():
        addressList.sort()
    return addresses


def _applyTsig(update, tsigConfig, logger):
    """
    Sign update with the TSIG key described by tsigConfig.

    tsigConfig must contain "algorithm" (one of the names below), "keyname"
    and "key", the latter two passed to dns.tsigkeyring.from_text.

    Raises RuntimeError if the algorithm name is not recognized.
    """
    knownAlgorithms = {
        "hmac-md5": dns.tsig.HMAC_MD5,
        "hmac-sha1": dns.tsig.HMAC_SHA1,
        "hmac-sha224": dns.tsig.HMAC_SHA224,
        "hmac-sha256": dns.tsig.HMAC_SHA256,
        "hmac-sha384": dns.tsig.HMAC_SHA384,
        "hmac-sha512": dns.tsig.HMAC_SHA512,
    }
    algorithmName = tsigConfig["algorithm"]
    if algorithmName not in knownAlgorithms:
        raise RuntimeError("TSIG algorithm {} not recognized.".format(algorithmName))
    tsigAlgorithm = knownAlgorithms[algorithmName]
    tsigRing = dns.tsigkeyring.from_text({tsigConfig["keyname"]: tsigConfig["key"]})
    update.use_tsig(keyring=tsigRing, algorithm=tsigAlgorithm)
    logger.debug("Update will be authenticated with TSIG {}.".format(tsigAlgorithm))


def _sendUpdate(update, server):
    """
    Send update over TCP to port 53 of server and check the answer.

    Raises RuntimeError if the server replies with anything but NOERROR.
    """
    # create_connection resolves server to one or more addresses and tries
    # each in turn until one connects; the with-block closes the socket.
    with socket.create_connection((server, 53)) as sock:
        # dns.query.tcp requires a caller-supplied socket to be connected and
        # non-blocking; "where" is ignored when sock is given.
        sock.setblocking(False)
        resp = dns.query.tcp(update, where=None, sock=sock)
    if resp.rcode() != dns.rcode.NOERROR:
        raise RuntimeError("Update failed with rcode {}.".format(resp.rcode()))


def run(platform, args, config, logger):
    """
    Run the program.

    platform -- an instance of a subclass of Platform
    args -- an argparse namespace containing parsed command-line arguments
    config -- a dict containing the parsed configuration file
    logger -- a logger to log messages to

    Raises RuntimeError (and whatever dnspython/netifaces raise) on failure.
    """
    # Decide which families to use.
    families = []
    if config["ipv4"]:
        families.append(IPv4())
    if config["ipv6"]["enable"]:
        families.append(IPv6(platform, config["ipv6"]))
    if not families:
        logger.error("No address families are enabled.")
        return

    # Grab the TTL from the config file.
    ttl = int(config["ttl"])

    # Decide which cache file to use, if any.
    cacheFile = _chooseCacheFile(platform, config)
    if cacheFile is None:
        logger.debug("Using no cache file.")
    else:
        logger.debug("Using cache file {}.".format(cacheFile))

    # Wipe the cache file if in force mode. Doing this, rather than just
    # unconditionally updating right now, means that if this update fails, the
    # cache file will not be written and the next update will also be
    # unconditional, which is a more useful behaviour for force (you can run in
    # force mode once and be sure that at least one update will happen
    # successfully before we stop trying).
    if args.force and cacheFile is not None:
        logger.debug("Wiping cache due to --force.")
        try:
            os.remove(cacheFile)
        except FileNotFoundError:
            pass
        except OSError as exp:
            # This run still updates (the cache is ignored below), but if the
            # update fails the stale cache may elide the next run's update.
            logger.warning("Could not wipe cache file {}: {}".format(cacheFile, exp))

    # Under --force the cache is never consulted, even if wiping it failed.
    cache = None if args.force else _loadCache(cacheFile, logger)

    # Rip apart my hostname.
    fqdn = dns.name.from_text(socket.getfqdn())
    zone = fqdn.parent()
    if zone == dns.name.root:
        raise RuntimeError("Hostname {} has no domain part, so there is no zone to update.".format(fqdn))
    hostPart = fqdn.relativize(zone)

    # Find my addresses (all interfaces unless some were named).
    addresses = _findAddresses(families, args.interface or netifaces.interfaces())

    # Skip all network work if the cache says nothing has changed.
    cache = cache or {}
    if fqdn.to_text() == cache.get("hostname") and addresses == cache.get("addresses"):
        logger.info("Eliding DNS record update for {} to {} as cache says addresses have not changed.".format(fqdn, addresses))
        return

    # Find which nameserver we should talk to using an SOA query.
    soa = dns.resolver.resolve(zone, dns.rdatatype.SOA, search=True)
    if len(soa.rrset) != 1:
        raise RuntimeError("Got {} SOA records for zone {}, expected 1.".format(len(soa.rrset), zone))
    server = soa.rrset[0].mname.to_text(omit_final_dot=True)
    logger.debug("Using nameserver {}.".format(server))

    # Build the update: replace every record at hostPart with the current
    # addresses.
    logger.info("Updating DNS record for {} to {}.".format(fqdn, addresses))
    update = dns.update.Update(zone)
    update.delete(hostPart)
    for family in families:
        for address in addresses[family.getName()]:
            family.addAddressToUpdate(update, hostPart, ttl, address)
    if "tsig" in config:
        _applyTsig(update, config["tsig"], logger)
    else:
        logger.debug("Update will be unauthenticated.")

    # Issue the update.
    _sendUpdate(update, server)

    # Update the cache to remember that we did this.
    if cacheFile is not None:
        _saveCache(cacheFile, fqdn.to_text(), addresses)


def main():
    """
    Program entry point.

    Pick the platform, parse the command line, load the JSON configuration,
    configure logging and call run(). Any exception other than
    KeyboardInterrupt/SystemExit raised by run() is logged to the "pydyndns"
    logger instead of being propagated.
    """
    # Choose a platform; the last candidate whose name matches os.name wins.
    platform = UnknownPlatform()
    for candidate in (POSIXPlatform(), WindowsPlatform()):
        if candidate.getName() == os.name:
            platform = candidate
    platform.platformSpecificSetup()

    # Parse command-line arguments.
    defaultConfig = platform.getDefaultConfigFilename()
    parser = argparse.ArgumentParser(description="Dynamically update DNS records.")
    parser.add_argument("-c", "--config", default=defaultConfig, type=str, help="which configuration file to read (default: {})".format(defaultConfig), metavar="FILE")
    parser.add_argument("-f", "--force", action="store_true", help="update even if cache says unnecessary")
    parser.add_argument("interface", nargs="*", help="the name of an interface whose address(es) to register (default: all interfaces)")
    args = parser.parse_args()

    # Load configuration file. JSON is UTF-8 (RFC 8259) regardless of the
    # locale; "utf-8-sig" additionally tolerates the BOM some Windows editors
    # write.
    with open(args.config, "r", encoding="utf-8-sig") as configFile:
        config = json.load(configFile)

    # Configure logging.
    if "logging" in config:
        logging.config.dictConfig(config["logging"])
    else:
        logging.basicConfig(level=logging.DEBUG)
        logging.warning("No logging section in config file.")
    logging.captureWarnings(True)

    # Run the program. KeyboardInterrupt and SystemExit are not Exception
    # subclasses, so they still propagate.
    logger = logging.getLogger("pydyndns")
    try:
        run(platform, args, config, logger)
    except Exception:
        logger.exception("Unhandled exception")
