#!/usr/bin/env python3

import dns.message
import requests
import base64
import json
import argparse
import sys
import time

# DNS-over-HTTPS resolvers selectable by name with --dns; any other --dns
# value is used as a literal URL. Key order is the order listed in --help.
url = {
    "libredns": "https://doh.libredns.gr/dns-query",
    "libredns-ads": "https://doh.libredns.gr/ads",
    "google": "https://dns.google/dns-query",
    "cloudflare": "https://cloudflare-dns.com/dns-query",
    "quad9": "https://dns.quad9.net/dns-query",
    "cleanbrowsing": "https://doh.cleanbrowsing.org/doh/family-filter/?ct",
    "cira": "https://private.canadianshield.cira.ca/dns-query",
    "cira-protect": "https://protected.canadianshield.cira.ca/dns-query",
    "cira-family": "https://family.canadianshield.cira.ca/dns-query",
}

# Record types accepted on the command line (compared after upper-casing).
# Must stay a list, not a tuple: it is passed to argparse as a metavar, and
# argparse interprets a tuple metavar as one name per consumed argument.
RR = "A AAAA CNAME MX NS SOA SPF SRV TXT CAA".split()

# Seconds to wait for the DoH server. Without a timeout, requests.get()
# can block indefinitely on an unresponsive server.
HTTP_TIMEOUT = 10


def build_parser():
    """Return the argparse parser for the doh-cli command line.

    Uses the module-level ``url`` table for the --dns help text and the
    ``RR`` list as the record-type metavar.
    """
    parser = argparse.ArgumentParser(
        "doh-cli",
        description="a simple DNS over HTTPS client",
        epilog="eg. doh-cli libredns.gr")
    parser.add_argument("domain", help="The domain name to resolve")
    parser.add_argument(
        "RR", help="Resource Record type (default: %(default)s)",
        metavar=RR, default="A", nargs="?")
    parser.add_argument(
        "--debug", help="show the entire response", action="store_true")
    parser.add_argument(
        "--verbose", help="show the entire request", action="store_true")
    parser.add_argument(
        "--output", help="Display DNS response in plain|json format",
        choices=["plain", "json"], default="plain")
    parser.add_argument(
        "--dns",
        help="Choose DNS server: {} or provide your own url".format(", ".join(url)),
        metavar=list(url), default="libredns")
    parser.add_argument("--time", help="show Query time", action="store_true")
    return parser


def parse_cli(parser, argv=None):
    """Parse *argv* (default ``sys.argv[1:]``) and normalise the record type.

    The two positionals may be given in either order ("example.com MX" or
    "MX example.com"): if the second one contains a dot, the two are swapped
    so that it becomes the domain. The record type is upper-cased; if it is
    not listed in ``RR``, usage is printed and the process exits with
    status 1.
    """
    args = parser.parse_args(argv)
    if "." in args.RR:
        args.domain, args.RR = args.RR, args.domain
    args.RR = args.RR.upper()
    if args.RR not in RR:
        parser.print_usage()
        print("{0}: unsupported record type: {1}".format(parser.prog, args.RR),
              file=sys.stderr)
        sys.exit(1)
    return args


def encode_query(wire):
    """Return the RFC 8484 ``dns`` GET parameter for DNS wire-format bytes.

    RFC 8484 specifies base64url (``-``/``_`` in place of ``+``/``/``) with
    the trailing ``=`` padding removed. The standard alphabet is not valid
    base64url and may be rejected by a conforming server.
    """
    return base64.urlsafe_b64encode(wire).decode("ascii").rstrip("=")


def plain_answers(rrset_text, rr):
    """Return one display line per record in *rrset_text*.

    *rrset_text* is an RRset rendered as text, one record per line, each
    assumed to look like "name TTL IN TYPE rdata". For lines of the queried
    type *rr*, only the text after "IN <rr> " is kept. Lines of another type
    (e.g. a CNAME returned for an A query) do not contain that delimiter and
    are returned whole.
    """
    delimiter = "IN " + rr + " "
    return [line.split(delimiter)[-1] for line in rrset_text.splitlines()]


def json_answers(rrset_text):
    """Return one dict per record line of *rrset_text* for ``--output json``.

    Whitespace-separated fields map to Query (field 0), TTL (field 1),
    RR (field 3) and Answer (all remaining fields). Field 2, the class, is
    dropped.
    """
    records = []
    for line in rrset_text.splitlines():
        fields = line.split()
        records.append({
            "Query": fields[0],
            "TTL": fields[1],
            "RR": fields[3],
            "Answer": fields[4:],
        })
    return records


def main(argv=None):
    """Resolve one name over DNS-over-HTTPS and print the answer.

    *argv* defaults to ``sys.argv[1:]``. Exits with status 1 on a network
    error, an HTTP status >= 400, or an unsupported record type.
    """
    parser = build_parser()
    args = parse_cli(parser, argv)

    message = dns.message.make_query(args.domain, args.RR)
    dns_req = encode_query(message.to_wire())

    # A known resolver name, otherwise treat --dns as a literal endpoint URL.
    endpoint = url.get(args.dns, args.dns)

    # perf_counter is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments.
    start = time.perf_counter()
    try:
        r = requests.get(
            endpoint,
            params={"dns": dns_req},
            headers={
                # RFC 8484 asks DoH clients to advertise the accepted type.
                "Accept": "application/dns-message",
                "Content-type": "application/dns-message",
            },
            timeout=HTTP_TIMEOUT)
    except requests.exceptions.RequestException as e:
        print(e, file=sys.stderr)
        sys.exit(1)
    elapsed_ms = (time.perf_counter() - start) * 1000

    if not r.ok:
        # stderr, so --output json never emits non-JSON on stdout.
        print("HTTP status code: {0} ".format(r.status_code), file=sys.stderr)
        sys.exit(1)

    response = dns.message.from_wire(r.content)

    # Optional extra output as (label, text), in the order it is emitted.
    extras = []
    if args.debug:
        extras.append(("Debug", response.to_text()))
    if args.verbose:
        # r.url is the exact URL requested, query string included. This also
        # works when --dns is a custom URL rather than a table name.
        extras.append(("Verbose", r.url))
    if args.time:
        extras.append(("Query Time", format(round(elapsed_ms, 3))))

    # Extras are emitted once, after all answers, and also when the answer
    # section is empty, so --debug/--time still report something.
    if args.output == "plain":
        for rrset in response.answer:
            for line in plain_answers(rrset.to_text(), args.RR):
                print(line)
        for label, text in extras:
            print("{0}: {1}".format(label, text))
    else:
        records = [
            record
            for rrset in response.answer
            for record in json_answers(rrset.to_text())
        ]
        records.extend({label: text} for label, text in extras)
        print(json.dumps(records))


if __name__ == "__main__":
    main()