#!/usr/bin/env python
"""Helpers for BIND DNS: zone signing, service reloads, serial numbers and zone files."""

import datetime
import pathlib
import secrets
import sys

from waflibs import domain, log, shell

# Default BIND directory (zone files are read/written relative to it) and the
# file holding the last serial number written by write_serial_number().
default_bind_dir = "/etc/bind"
default_serial_file = f"{default_bind_dir}/serial"

# ANSI escape sequences used to colour terminal/log output.
INFOGREEN = "\033[32m"
INFORED = "\033[31m"
ENDCOLOR = "\033[0m"
# Warning header inserted into generated zone files, either directly or via
# the ``{banner}`` template field (";" starts a comment in zone files).
BANNER = """\
; @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
;         THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
; @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
"""

# Module-level logger used by the helpers below; generate_bind_file() also
# accepts a caller-supplied logger instead.
logger = log.logger().get_logger()


def restart_bind(sudo=False):
    """Reload the bind9 service by running ``service bind9 reload``.

    :param sudo: run the command through ``sudo``

    The command's stdout and stderr are logged at debug level; the third
    value returned by ``shell.command`` is discarded.
    """
    base_cmd = ["service", "bind9", "reload"]
    cmd = ["sudo", *base_cmd] if sudo else base_cmd
    logger.debug("command to run: {}".format(" ".join(cmd)))

    out, err, _ = shell.command(cmd)
    logger.debug("stdout: {}".format(out))
    logger.debug("stderr: {}".format(err))


# Shorter alias for restart_bind.
restart = restart_bind


def sign_zone(zone, bind_dir=default_bind_dir, sudo=False):
    """Sign the zone file ``db.<zone>`` in ``bind_dir`` with dnssec-signzone.

    The signed zone is written to ``signed/db.<zone>`` (relative to
    ``bind_dir``). The command is run with ``-S``, NSEC3 with a freshly
    generated random salt (``-3``), and ``-N INCREMENT`` for the SOA serial.

    A zone whose last label is ``home`` (e.g. ``example.com.home``) is signed
    with that label removed from the origin (``example.com``). The file names
    always use the full ``zone``.

    :param zone: zone name, used for the file names and (minus any trailing
        ``.home`` label) as the zone origin
    :param bind_dir: working directory for the command
    :param sudo: prefix the command with ``sudo``
    """
    # 8 random bytes -> 16 hex characters of NSEC3 salt, new on every signing.
    salt = secrets.token_hex(8)

    # Strip only a whole trailing ".home" label. A bare endswith("home") also
    # matched names like "myhome" / "a.myhome" and produced a wrong or empty
    # origin. Named ``origin`` so the imported ``domain`` module isn't shadowed.
    home_suffix = ".home"
    origin = zone[: -len(home_suffix)] if zone.endswith(home_suffix) else zone

    command_arr = [
        "dnssec-signzone",
        "-S",
        "-3",
        salt,
        "-o",
        origin,
        "-N",
        "INCREMENT",
        "-f",
        f"signed/db.{zone}",
        f"db.{zone}",
    ]
    if sudo:
        command_arr.insert(0, "sudo")
    logger.debug("command to run: {}".format(" ".join(command_arr)))

    stdout, stderr, _ = shell.command(command_arr, cwd=bind_dir)
    logger.debug("stdout: {}".format(stdout))
    logger.debug("stderr: {}".format(stderr))


def sign_zones(
    zones,
    bind_dir=default_bind_dir,
    sudo=False,
):
    """Run :func:`sign_zone` on each entry of ``zones``, one after another.

    ``bind_dir`` and ``sudo`` are forwarded unchanged to every call.
    """
    shared = {"bind_dir": bind_dir, "sudo": sudo}
    for zone_name in zones:
        sign_zone(zone_name, **shared)


def sign_zones_and_restart(
    zones,
    bind_dir=default_bind_dir,
    sudo=False,
):
    """Sign every zone in ``zones``, then reload the bind9 service.

    ``sudo`` applies to both steps; ``bind_dir`` is only used for signing.
    """
    sign_zones(zones, sudo=sudo, bind_dir=bind_dir)
    restart_bind(sudo)


# Shorter alias for sign_zones_and_restart.
sign_and_restart = sign_zones_and_restart


def generate_serial_number(
    serial_file=default_serial_file,
    serial_number=None,
):
    """Return the zone serial number to use, as an int.

    If ``serial_number`` is given it is returned (converted to int) as-is.
    Otherwise the serial is computed from today's date in ``YYYYMMDDnn`` form
    and the last serial stored in ``serial_file``. The result is the larger of
    today's base serial (``YYYYMMDD00``) and the stored serial plus one.

    This matches the previous behaviour for normal same-day and older
    serials. It also guarantees the serial never goes backwards, e.g. after
    more than 99 bumps in one day rolled the stored value into the next date
    prefix.

    :param serial_file: path of the file holding the last written serial; a
        missing file or non-integer contents fall back to today's base serial
    :param serial_number: explicit serial to use instead of computing one
    :returns: the serial number as an int
    """
    if serial_number is not None:
        return int(serial_number)

    date_as_str = datetime.datetime.now().strftime("%Y%m%d")
    default_serial_number = int(date_as_str + "00")

    try:
        with open(serial_file) as f:  # ``with`` so the handle is closed
            s_num = f.read().strip()
    except FileNotFoundError as e:
        logger.debug("serial file not found... ignoring")
        logger.debug(f"error: {e}")
        return default_serial_number

    try:
        previous = int(s_num)
    except ValueError:
        logger.debug(f"invalid serial {s_num!r} in {serial_file}... ignoring")
        return default_serial_number

    # A DNS zone serial must only ever increase, so never drop below the
    # previously written value.
    return max(default_serial_number, previous + 1)


def write_serial_number(serial_number, serial_file):
    """Persist ``serial_number`` to ``serial_file`` as a single line.

    Writing is best-effort. If the file cannot be written (missing
    permissions, missing parent directory, ...), the error is logged as a
    warning and not raised.

    :param serial_number: serial to store (written via ``str()``)
    :param serial_file: destination path; overwritten if it exists
    """
    try:
        with open(serial_file, "w") as f:
            f.write(f"{serial_number}\n")
    except OSError as e:
        # An unpersisted serial is reused on the next run (see
        # generate_serial_number), so surface the failure rather than
        # hiding it at debug level. OSError also covers PermissionError.
        logger.warning(f"error writing to {serial_file}: {e}")


write_serial = write_serial_number


# Alternative names for generate_serial_number.
gen_serial_number = gen_serial_num = gen_serial = generate_serial = (
    generate_serial_number
)


def _format_zone_record(record, logger):
    """Return the zone-file line for one record mapping, or None if unsupported."""
    record_name = record["name"]
    record_type = record["type"]
    content = record["content"]

    logger.debug(
        f"record name: {domain.idna(record_name)}, record type: {record_type}"
    )
    # ``proxied`` is only logged, so don't require it to be present.
    logger.debug(f"record content: {content}, proxied: {record.get('proxied')}")

    # Exact comparisons: the old ``in ("MX")`` was a substring test on a str.
    if record_type == "MX":
        return f"{record_name} IN {record_type} {record['priority']} {content}"
    if record_type == "SRV":
        return f"{record_name} IN {record_type} {record['data']['priority']} {content}"
    if record_type in ("SPF", "TXT"):
        return f'{record_name} IN {record_type} "{content}"'
    if record_type in ("NS", "CNAME", "A", "AAAA"):
        return f"{record_name} IN {record_type} {content}"

    logger.fatal(f"unsupported dns record type {record_type}")
    return None


def _read_zone_template(template_file_name, domain_template_file_name=None):
    """Return the template text, with the optional per-domain template appended."""
    with open(template_file_name) as template_file:
        tmpl = template_file.read() + "\n"
    if domain_template_file_name:
        with open(domain_template_file_name) as domain_template_file:
            tmpl += domain_template_file.read() + "\n"
    return tmpl


def generate_bind_file(
    all_records,
    domain_name,
    zone_file_name=None,
    template_file_name=None,
    domain_template_file_name=None,
    serial_number=None,
    bind_dir=default_bind_dir,
    serial_file=default_serial_file,
    logger=logger,
    dry_run=False,
    nameserver="ns.home.waf.hk",
):
    """Generate a BIND zone file for ``domain_name`` from ``all_records``.

    Each record is a mapping with ``name``, ``type`` and ``content`` keys.
    MX records also need ``priority``; SRV records need ``data["priority"]``.
    Supported types are MX, SRV, SPF, TXT, NS, CNAME, A and AAAA. Other types
    are reported via ``logger.fatal`` and skipped, but still counted in the
    trailing "total records found" comment.

    If ``template_file_name`` is given, the template is followed by the
    optional ``domain_template_file_name`` and filled with the ``str.format``
    fields ``banner``, ``zone``, ``records``, ``serial`` and ``nameserver``.
    Otherwise the output is the banner followed by the records.

    When ``serial_number`` is falsy, one is computed from ``serial_file``.
    Unless ``dry_run`` is set, the serial is written back to ``serial_file``
    and the zone is written to ``<bind_dir>/<zone_file_name>``.
    ``zone_file_name`` defaults to ``db.<domain_name>``.

    :param nameserver: value for the ``{nameserver}`` template field
    :returns: what ``write_zone_file`` returns: the zone file path after a
        successful write, otherwise ``None`` (dry run or write failure)
    """
    if not serial_number:
        serial_number = generate_serial(
            serial_file=serial_file,
        )
    # A dry run must not have side effects, including bumping the serial.
    if not dry_run:
        write_serial(serial_number, serial_file)

    # Progress goes to stderr. This replaces ``text.eprint``, whose module
    # was never imported (NameError).
    print(
        f"{INFOGREEN}Generating zone records for {domain_name}...{ENDCOLOR} ",
        end="",
        file=sys.stderr,
        flush=True,
    )

    logger.debug("all records: {}".format(all_records))
    count = 0
    lines = []
    for record in all_records:
        count += 1
        line = _format_zone_record(record, logger)
        if line is not None:
            lines.append(f"{line}\n")
    records = "".join(lines) + f"\n; {count} total records found"

    print(f"{INFOGREEN}done.{ENDCOLOR}", file=sys.stderr)

    if zone_file_name is None:
        zone_file_name = f"db.{domain_name}"
    zone_file = pathlib.Path(bind_dir, zone_file_name)
    logger.debug(f"zone file: {zone_file}")

    if template_file_name:
        tmpl = _read_zone_template(template_file_name, domain_template_file_name)
        logger.debug(f"template: {tmpl}")

        contents = tmpl.format(
            banner=BANNER,
            zone=domain_name,
            records=records,
            serial=serial_number,
            nameserver=nameserver,
        )
    else:
        contents = f"{BANNER}\n{records}"

    logger.debug(f"bind file: {contents}")

    return write_zone_file(zone_file, contents, dry_run)


def write_zone_file(zone_file, contents, dry_run=False):
    """Write the rendered zone ``contents`` to ``zone_file``.

    With ``dry_run`` nothing is written. The target path is logged at info
    level and the contents at debug level.

    :returns: ``zone_file`` after a successful write. Returns ``None`` for a
        dry run, or when writing raised ``IOError``; the error is logged with
        ``logger.fatal`` and not re-raised.
    """
    if not dry_run:
        logger.debug(f"{INFOGREEN}Writing to file {zone_file}...{ENDCOLOR}")
        try:
            with open(zone_file, "w") as out:
                out.write(contents)
                logger.debug(
                    f"{INFOGREEN}finished writing to file {zone_file}{ENDCOLOR}"
                )
        except IOError as ioe:
            logger.debug(f"{INFORED}ERROR!{ENDCOLOR}")
            logger.fatal(ioe)
            return None
        else:
            return zone_file

    logger.info(f"{INFOGREEN}would write dns zone file to {zone_file}{ENDCOLOR}")
    logger.debug(f"contents: {contents}")
    return None
