#!/usr/bin/python3
import sys
import re
import subprocess
import collections
import toyecc
from MultiCommand import MultiCommand

# Record holding the domain parameters of one curve over a prime field, in the
# form in which openssl_get_curve_fp() extracts them from the OpenSSL output.
FPDomainParams = collections.namedtuple("FPDomainParams", (
	"curvename",	# name under which the curve was requested from OpenSSL
	"A",			# value of the "A" field (possibly shifted by -p, see openssl_get_curve_fp)
	"B",			# value of the "B" field (possibly shifted by -p, see openssl_get_curve_fp)
	"p",			# value of the "Prime" field
	"n",			# value of the "Order" field
	"h",			# value of the "Cofactor" field
	"Gx",			# first component of the deserialized generator
	"Gy",			# second component of the deserialized generator
))

def openssl_get_all_fp_curvenames(openssl_binary):
	"""Yield the names of all prime field curves listed by an OpenSSL binary.

	Runs "<openssl_binary> ecparam -list_curves" and inspects its output line
	by line. For every line that ends in "prime field", the text in front of
	the first colon is yielded with surrounding whitespace removed.

	This is a generator: OpenSSL is only invoked once iteration starts.
	"""
	listing = subprocess.check_output([ openssl_binary, "ecparam", "-list_curves" ]).decode("utf-8")
	for entry in listing.split("\n"):
		if not entry.endswith("prime field"):
			# Either not a curve description or a curve over another field type
			continue
		(name, _, _) = entry.partition(":")
		yield name.strip()


def parse_openssl_type(text_value, filtertype):
	"""Convert a textual OpenSSL field value into a Python object.

	text_value is the string to convert, filtertype selects the conversion:

	  bytes: the colons are removed and the remainder is decoded as hex.
	  int:   if the text contains an opening parenthesis (such as
	         "1 (0x1)"), the part in front of it is read as a decimal number.
	         Otherwise colons, spaces and tabs are removed and the remainder
	         is read as a hexadecimal number.

	Any other filtertype raises an Exception.
	"""
	if filtertype == bytes:
		return bytes.fromhex(text_value.replace(":", ""))

	if filtertype == int:
		(decimal_part, paren, _) = text_value.partition("(")
		if paren:
			return int(decimal_part)
		hexdigits = "".join(char for char in text_value if char not in ": \t")
		return int(hexdigits, 16)

	raise Exception("Do not know how to parse type '%s'" % (str(filtertype)))

def parse_openssl_text(binoutput, filters = None):
	"""Parse the textual ("-text") output of the OpenSSL CLI into a dictionary.

	binoutput is the raw, UTF-8 encoded output. Two kinds of entries are
	recognized:

	  "key: value"   A key with its value on the same line.
	  "key:"         A key followed by continuation lines. Every following
	                 line that starts with a space belongs to the value.

	Empty lines are skipped. All values end up as strings; the continuation
	lines of a multi-line value are concatenated and the result is stripped of
	surrounding whitespace.

	filters is an optional dictionary that maps keys to a type. For every key
	of filters that is present in the parsed output, the string value is
	replaced by the result of parse_openssl_type(value, type).

	Returns the dictionary of parsed entries. Raises an Exception for any line
	that is neither an entry nor a continuation of one.
	"""
	# Raw strings: "\s" is not a valid escape sequence in a regular string
	# literal and causes a warning when the file is compiled.
	single_line_re = re.compile(r"(?P<key>[^\s].*):\s*(?P<value>[^\s].*)$")
	multi_line_re = re.compile(r"(?P<key>[^\s].*):\s*$")

	parsed = { }
	# key/value are only set while a multi-line entry is being collected
	(key, value) = (None, None)
	for line in binoutput.decode("utf-8").split("\n"):
		if line == "":
			continue

		if key is not None:
			if line.startswith(" "):
				value.append(line)
				continue
			# First non-indented line ends the current multi-line entry
			parsed[key] = value
			(key, value) = (None, None)

		result = single_line_re.match(line)
		if result:
			parsed[result.group("key")] = result.group("value")
			continue

		result = multi_line_re.match(line)
		if result:
			key = result.group("key")
			value = [ ]
			continue

		raise Exception("Don't know what to do: \"%s\"" % (line))

	if key is not None:
		# Output ended while a multi-line entry was still open
		parsed[key] = value

	# Single-line values are strings, multi-line values are lists of lines;
	# joining turns both into one string.
	parsed = { key: "".join(value).strip() for (key, value) in parsed.items() }
	if filters is not None:
		for (key, filtertype) in filters.items():
			if key in parsed:
				parsed[key] = parse_openssl_type(parsed[key], filtertype)

	return parsed

def openssl_get_curve_fp(openssl_binary, curvename):
	"""Query OpenSSL for the explicit domain parameters of a named curve.

	Runs "<openssl_binary> ecparam -name <curvename> -param_enc explicit -text
	-noout", parses the textual output and returns the parameters as a
	FPDomainParams tuple.

	Raises ValueError if OpenSSL does not report the field type "prime-field"
	for the curve. Errors of the OpenSSL call itself are propagated from
	subprocess.check_output().
	"""
	cmd = [ openssl_binary, "ecparam", "-name", curvename, "-param_enc", "explicit", "-text", "-noout" ]
	output = subprocess.check_output(cmd)
	parsed = parse_openssl_text(output, {
		"Order":					int,
		"A":						int,
		"B":						int,
		"Cofactor":					int,
		"Prime":					int,
		"Generator (uncompressed)":	bytes,
		"Seed":						bytes,
	})

	# Explicit check instead of an assert statement: asserts are removed when
	# Python runs with -O, which would let curves over other fields slip through.
	field_type = parsed.get("Field Type")
	if field_type != "prime-field":
		raise ValueError("Curve '%s' is not defined over a prime field (OpenSSL reports field type '%s')" % (curvename, field_type))

	(Gx, Gy) = toyecc.AffineCurvePoint.deserialize_uncompressed(parsed["Generator (uncompressed)"])

	p = parsed["Prime"]
	A = parsed["A"]
	B = parsed["B"]
	# Coefficients less than 100000 below p are expressed as the equivalent
	# small negative number (e.g. p - 3 becomes -3).
	# NOTE(review): presumably to keep the generated output short -- confirm.
	if A > p - 100000:
		A -= p
	if B > p - 100000:
		B -= p

	return FPDomainParams(curvename = curvename, A = A, B = B, p = p, n = parsed["Order"], h = parsed["Cofactor"], Gx = Gx, Gy = Gy)

def print_param(name, value, maxbits):
	"""Print one '"name": value,' line of a parameter dictionary.

	The representation of value depends on its magnitude:

	  - values below 0x100 are printed in decimal,
	  - values whose bit length is less than half of maxbits are printed as
	    plain hexadecimal,
	  - all other values are printed as hexadecimal, zero-padded to the
	    number of hex digits of maxbits rounded up to whole bytes.
	"""
	quoted = "\"" + name + "\""
	bitlen = value.bit_length()
	if value < 0x100:
		rendered = "%d" % (value)
	elif bitlen < maxbits // 2:
		rendered = "%#x" % (value)
	else:
		# Two characters for the "0x" prefix plus two hex digits per byte
		hexwidth = 2 + 2 * ((maxbits + 7) // 8)
		rendered = "%#0*x" % (hexwidth, value)
	print("	%-4s: %s," % (quoted, rendered))

def print_curve_fp(domainparams):
	"""Print a curve database registration statement for a set of domain parameters.

	The output is a 'cdb.register(_CurveDBEntry("<name>", ShortWeierstrassCurve,
	{ ... }, origin = "OpenSSL"))' statement followed by an empty line. The
	dictionary contains the entries a, b, p, n, h, Gx and Gy, each formatted
	by print_param().
	"""
	# The widest of these five parameters determines the padding of all values
	maxbits = max(getattr(domainparams, field).bit_length() for field in ("A", "B", "p", "n", "h"))

	print("cdb.register(_CurveDBEntry(\"%s\", ShortWeierstrassCurve, {" % (domainparams.curvename))
	# (printed dictionary key, attribute of domainparams)
	entries = (
		("a", "A"),
		("b", "B"),
		("p", "p"),
		("n", "n"),
		("h", "h"),
		("Gx", "Gx"),
		("Gy", "Gy"),
	)
	for (label, field) in entries:
		print_param(label, getattr(domainparams, field), maxbits)
	print("}, origin = \"OpenSSL\"))")
	print()

def action_dumpcurve(cmd, args):
	"""Handle the "dumpcurve" subcommand.

	Dumps the curve named by args.curve or, if no curve was given, every
	prime field curve that the OpenSSL binary args.openssl lists. The cmd
	argument is not used.
	"""
	if args.curve is not None:
		curvenames = [ args.curve ]
	else:
		curvenames = openssl_get_all_fp_curvenames(args.openssl)
	for name in curvenames:
		print_curve_fp(openssl_get_curve_fp(args.openssl, name))

def action_pubkey(cmd, args):
	"""Handle the "pubkey" subcommand.

	Lets OpenSSL print the public key file args.filename (in format
	args.inform) as text and prints two lines of Python code: one that looks
	up the curve by the reported "ASN1 OID" value and one that constructs the
	public key from the decoded point. The cmd argument is not used.
	"""
	openssl_cmd = [ args.openssl, "ec" ]
	openssl_cmd += [ "-inform", args.inform, "-in", args.filename ]
	openssl_cmd += [ "-pubin", "-text", "-noout" ]
	# stderr is captured so that it does not end up in the output of this tool
	text = subprocess.check_output(openssl_cmd, stderr = subprocess.PIPE)

	fields = parse_openssl_text(text, { "pub": bytes })
	(Px, Py) = toyecc.AffineCurvePoint.deserialize_uncompressed(fields["pub"])
	print("curve = toyecc.getcurvebyname(\"%s\")" % (fields["ASN1 OID"]))
	print("pubkey = toyecc.ECPublicKey(toyecc.AffineCurvePoint(0x%x, 0x%x, curve))" % (Px, Py))

def action_privkey(cmd, args):
	"""Handle the "privkey" subcommand.

	Lets OpenSSL print the private key file args.filename (in format
	args.inform) as text and prints Python code that recreates the key:

	  - the curve, either looked up by the reported "ASN1 OID" value or, if
	    that is absent, built from the explicit domain parameters,
	  - the private key itself,
	  - an assertion that the derived public key equals the point stored in
	    the file.

	The cmd argument is not used.
	"""
	openssl_cmd = [ args.openssl, "ec", "-inform", args.inform, "-in", args.filename, "-text", "-noout" ]
	# stderr is captured so that it does not end up in the output of this tool
	text = subprocess.check_output(openssl_cmd, stderr = subprocess.PIPE)

	conversions = {
		"priv": int,
		"pub": bytes,
		"Order": int,
		"A": int,
		"B": int,
		"Cofactor": int,
		"Prime": int,
		"Generator (uncompressed)": bytes,
		"Seed": bytes,
	}
	fields = parse_openssl_text(text, conversions)
	(Px, Py) = toyecc.AffineCurvePoint.deserialize_uncompressed(fields["pub"])

	if "ASN1 OID" not in fields:
		# No curve name available: emit the explicit domain parameters
		(Gx, Gy) = toyecc.AffineCurvePoint.deserialize_uncompressed(fields["Generator (uncompressed)"])
		print("curve = toyecc.ShortWeierstrassCurve(")
		# (line template, field of the OpenSSL output)
		param_lines = (
			("	a = 0x%x,", "A"),
			("	b = 0x%x,", "B"),
			("	p = 0x%x,", "Prime"),
			("	n = 0x%x,", "Order"),
			("	h = %d,", "Cofactor"),
		)
		for (template, fieldname) in param_lines:
			print(template % (fields[fieldname]))
		print("	Gx = 0x%x," % (Gx))
		print("	Gy = 0x%x" % (Gy))
		print(")")
	else:
		print("curve = toyecc.getcurvebyname(\"%s\")" % (fields["ASN1 OID"]))

	print("privkey = toyecc.ECPrivateKey(0x%x, curve)" % (fields["priv"]))
	print("assert(privkey.pubkey.point == toyecc.AffineCurvePoint(0x%x, 0x%x, curve))" % (Px, Py))

# Multi-command dispatcher of this script. The subcommands "dumpcurve",
# "pubkey" and "privkey" are registered with it further down in this file.
mc = MultiCommand(help = "openssl_bridge: Bridging the gap between OpenSSL and toyecc in order to provide test and debugging data that can be used within toyecc.")

def genparser(parser):
	"""Add the command-line options of the "dumpcurve" subcommand to parser."""
	parser.add_argument(
		"--openssl",
		metavar = "filename",
		type = str,
		default = "openssl",
		help = "Points to the OpenSSL CLI binary which shall be used. Defaults to %(default)s.",
	)
	parser.add_argument(
		"-c", "--curve",
		metavar = "name",
		type = str,
		help = "Curve which should be dumped. If not specified, will dump all available curves.",
	)
# Register the "dumpcurve" subcommand with its description, the genparser()
# defined directly above and action_dumpcurve as its action. This has to
# happen before the name genparser is rebound for the next subcommand.
mc.register("dumpcurve", "Dump one or all OpenSSL curves in a format that can be used by toyecc", genparser, action = action_dumpcurve)

def genparser(parser):
	"""Add the command-line options of the "pubkey" subcommand to parser."""
	parser.add_argument(
		"--openssl",
		metavar = "filename",
		type = str,
		default = "openssl",
		help = "Points to the OpenSSL CLI binary which shall be used. Defaults to %(default)s.",
	)
	parser.add_argument(
		"--inform",
		choices = [ "pem", "der" ],
		default = "pem",
		type = str,
		help = "Format in which the public key is supplied in. Allowed arguments are %(choices)s. Defaults to %(default)s.",
	)
	parser.add_argument(
		"-f", "--filename",
		metavar = "filename",
		required = True,
		type = str,
		help = "File name that should be imported. Mandatory argument.",
	)
# Register the "pubkey" subcommand with its description, the genparser()
# defined directly above and action_pubkey as its action. This has to happen
# before the name genparser is rebound for the next subcommand.
mc.register("pubkey", "Parse a public key file", genparser, action = action_pubkey)

def genparser(parser):
	"""Add the command-line options of the "privkey" subcommand to parser."""
	parser.add_argument(
		"--openssl",
		metavar = "filename",
		type = str,
		default = "openssl",
		help = "Points to the OpenSSL CLI binary which shall be used. Defaults to %(default)s.",
	)
	parser.add_argument(
		"--inform",
		choices = [ "pem", "der" ],
		default = "pem",
		type = str,
		help = "Format in which the private key is supplied in. Allowed arguments are %(choices)s. Defaults to %(default)s.",
	)
	parser.add_argument(
		"-f", "--filename",
		metavar = "filename",
		required = True,
		type = str,
		help = "File name that should be imported. Mandatory argument.",
	)
# Register the "privkey" subcommand with its description, the genparser()
# defined directly above and action_privkey as its action.
mc.register("privkey", "Parse a private key file", genparser, action = action_privkey)

# Hand the command-line arguments (without the program name) to the dispatcher.
# Note that this runs at module level, i.e. also when the file is imported.
mc.run(sys.argv[1:])

