#!/usr/bin/env python

import os
import os.path
import yaml
import socket
import logging
import logging.handlers
import argparse
import dns.exception
from sys import exit
from time import sleep
from zonecheck import ZoneCheckLite
from dns.resolver import Resolver, NXDOMAIN, NoAnswer, NoNameservers

def get_args():
    '''Parse the command line and return the resulting namespace.

    Arguments are read from sys.argv.  The verbosity counter starts at 1,
    so each -v on the command line raises it from there (a single -v
    yields 2).
    '''
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--serial-lag', default=2, type=int,
            help='alert if the serial is behind by this value or more')
    parser.add_argument('--log', default='/var/log/zone_status.log',
            help="location of error file, or 'stdout' to not log to a file")
    parser.add_argument('--puppet-facts', action='store_true',
            help='write the zone_status_errors fact to a puppet facts file')
    parser.add_argument('--puppet-facts-dir',
            default='guess',
            help="directory to write zone_status.txt to; 'guess' tries the "
                 "known facter directories")
    # the value is opened and parsed as YAML, it is not a list of zones
    parser.add_argument('--config', default='/usr/local/etc/zone_check.conf',
            help='path to the YAML configuration file')
    parser.add_argument('-v', '--verbose', action='count', default=1,
            help='increase log verbosity (may be repeated)')
    return parser.parse_args()

def set_log_level(args_level):
    '''Configure the root logger from the -v counter.

    1 selects ERROR, 2 WARN, 3 INFO and anything above 3 DEBUG; every
    other value falls back to CRITICAL.
    '''
    named_levels = {
        1: logging.ERROR,
        2: logging.WARN,
        3: logging.INFO,
    }
    if args_level in named_levels:
        chosen = named_levels[args_level]
    elif args_level > 3:
        chosen = logging.DEBUG
    else:
        chosen = logging.CRITICAL
    logging.basicConfig(level=chosen)

def get_log_file(file_name):
    '''Return the 'zonecheck' logger, attaching a rotating file handler.

    When file_name is the literal 'stdout' the logger is returned
    untouched.  Otherwise propagation to parent loggers is switched off
    and records go to file_name, rotated at 100000000 bytes with 5
    backups kept.
    '''
    zone_logger = logging.getLogger('zonecheck')
    if file_name == 'stdout':
        return zone_logger

    zone_logger.propagate = False
    rotating = logging.handlers.RotatingFileHandler(
            file_name, maxBytes=100000000, backupCount=5)
    rotating.setFormatter(
            logging.Formatter('%(asctime)s:%(levelname)s:%(message)s'))
    zone_logger.addHandler(rotating)
    return zone_logger

def check_zone(masters, zone, addr, args, logger):
    '''Run a ZoneCheckLite check and return the errors it collected.

    masters, zone and addr are handed to ZoneCheckLite as-is; args is
    accepted but not used here.  logger receives a debug line describing
    the check before it runs.
    '''
    message = '{} &  {} with {}'.format(masters, zone, addr)
    logger.debug(message)
    checker = ZoneCheckLite(addr, masters, zone)
    checker.check()
    return checker.errors

def get_facts_file(facts_dir, logger):
    '''Open zone_status.txt for writing inside a puppet facts directory.

    facts_dir is either a directory path or the literal 'guess', in which
    case the first existing directory from a fixed list of facter
    locations is used.  The caller owns the returned file object and is
    responsible for closing it.

    If no usable directory is found an error is logged through logger and
    the process exits with status 1.
    '''
    if facts_dir == 'guess':
        candidates = [
                '/etc/puppetlabs/facter/facts.d',
                '/var/puppet/facts']
        failure = 'unable to guess facts dir'
    else:
        candidates = [facts_dir]
        failure = 'invalid facts_dir: {}'.format(facts_dir)

    for path in candidates:
        # isdir rather than exists: a regular file with that name cannot
        # hold the facts file
        if os.path.isdir(path):
            return open(os.path.join(path, 'zone_status.txt'), 'w')

    logger.error(failure)
    exit(1)

def _load_config(path):
    '''Parse the YAML file at path and return the resulting object.'''
    with open(path, 'r') as stream:
        # NOTE(review): FullLoader is not as strict as yaml.safe_load;
        # consider safe_load if the config file is not fully trusted
        loader = getattr(yaml, 'FullLoader', None)
        if loader is None:
            # older PyYAML releases (before 5.1) have no FullLoader
            return yaml.load(stream)
        return yaml.load(stream, Loader=loader)


def main():
    '''Check every configured zone and report the errors found.

    Reads the command line, sets up logging and loads the YAML config.
    Every zone of every zone set that defines 'masters' is checked against
    each entry of config['ip_addresses'].  The collected errors are logged
    as YAML, and with --puppet-facts a zone_status_errors=true/false fact
    is written to the facts file.
    '''
    args = get_args()
    set_log_level(args.verbose)
    logger = get_log_file(args.log)
    config = _load_config(args.config)

    errors = {}
    facter_error = False
    master_soa_error = False
    for cfg in config['zones'].values():
        if 'masters' not in cfg:
            continue
        for zone in cfg['zones']:
            for addr in config['ip_addresses']:
                check_errors = check_zone(
                        cfg['masters'], zone, addr, args, logger)
                if not any(check_errors.values()):
                    continue
                errors.setdefault(zone, {})[addr] = check_errors
                if any(check_errors['general']):
                    facter_error = True
                if check_errors['master_soa']:
                    master_soa_error = True

    logger.error('errors:{}'.format(yaml.dump(errors)))

    if args.puppet_facts:
        # this is a safety so we don't disable all nodes if there is
        # a problem with the distribution layer
        if master_soa_error:
            facter_error = False
        facts = 'zone_status_errors={}\n'.format(str(facter_error).lower())
        logger.debug('puppet facts = {}'.format(facts))
        # the with block closes the file even if the write fails
        with get_facts_file(args.puppet_facts_dir, logger) as facts_file:
            facts_file.write(facts)

# run the checks only when executed as a script, not when imported
if __name__ == '__main__':
    main()
