#!/usr/bin/python3
import subprocess
import pyasn1.codec.ber.decoder

def decode_privkey(octetstring):
	"""Return the integer encoded by a big-endian sequence of byte values.

	The last element of `octetstring` is the least significant byte. An
	empty sequence yields 0. Despite the name, this is also used to decode
	the public key coordinates.
	"""
	total = 0
	shift = 0
	# Walk from the least significant byte upwards, moving each byte into its position.
	for value in reversed(octetstring):
		total += value << shift
		shift += 8
	return total

def tobyte(bits):
	"""Return the integer formed by a sequence of bits, most significant bit first.

	The last element of `bits` carries weight 1, the one before it weight 2,
	and so on. An empty sequence yields 0.
	"""
	result = 0
	bitno = 0
	# Start at the least significant (last) bit and work towards the front.
	for value in reversed(bits):
		result += value << bitno
		bitno += 1
	return result

def decode_pubkey(bitstring):
	"""Decode an uncompressed elliptic curve point from a sequence of bits.

	The bits are packed into bytes (most significant bit first). The first
	byte must be 0x04, the prefix SEC 1 uses for uncompressed points; the
	remaining bytes are split into two halves of equal length, each read as
	a big-endian integer.

	Returns the tuple (x, y).

	Raises AssertionError if the bit count is not a multiple of 8, if the
	data is empty or does not start with 0x04, or if the coordinate bytes
	cannot be split evenly.
	"""
	assert (len(bitstring) % 8) == 0, "public key bit length is not a multiple of 8"
	bytestring = bytes(tobyte(bitstring[i : i + 8]) for i in range(0, len(bitstring), 8))
	# Compare a one-byte slice instead of indexing, so that empty input fails
	# this check like any other malformed key rather than raising IndexError.
	assert bytestring[ : 1] == b"\x04", "public key is not an uncompressed point (missing 0x04 prefix)"
	xy = bytestring[1:]
	assert (len(xy) % 2) == 0, "public key coordinates have odd total length"
	half = len(xy) // 2
	# x occupies the first half of the payload, y the second.
	(x, y) = (decode_privkey(xy[ : half]), decode_privkey(xy[half : ]))
	return (x, y)

# Ask OpenSSL for its named curves and keep the names of those whose
# description line mentions a prime field.
curve_listing = subprocess.check_output([ "openssl", "ecparam", "-list_curves" ]).decode("utf-8")
available_curves = sorted(
	entry.split(":")[0].strip()
	for entry in curve_listing.split("\n")
	if "prime field" in entry
)

# Print one unittest method per curve; each method checks ten freshly
# generated keypairs (private scalar times G equals the public point).
for curvename in available_curves:
	for sample in range(10):
		keygen_output = subprocess.check_output([ "openssl", "ecparam", "-genkey", "-name", curvename, "-outform", "der" ])

		# The first decode consumes the curve parameters; the bytes left over
		# are decoded as the private key structure and must be used up fully.
		(asn1_ecparams, der_privkey) = pyasn1.codec.ber.decoder.decode(keygen_output)
		(asn1_privkey, tail) = pyasn1.codec.ber.decoder.decode(der_privkey)
		assert len(tail) == 0

		# Fields of the key structure by position: [1] private key octets,
		# [2] curve OID, [3] public key bits.
		privkey = decode_privkey(asn1_privkey[1])
		curve_oid = asn1_privkey[2]
		(pubkey_x, pubkey_y) = decode_pubkey(asn1_privkey[3])

		if sample == 0:
			# The method header is emitted once per curve, using the OID of the first key.
			method_suffix = curvename.replace("-", "_")
			print("	def test_%s(self):" % (method_suffix))
			print("		curve_entry = getcurveentry(\"%s\")" % (curvename))
			print("		self.assertEqual(curve_entry.oid, \"%s\")" % (curve_oid))
			print("		curve = curve_entry()")
		print("		self.assertEqual(0x%x * curve.G, AffineCurvePoint(0x%x, 0x%x, curve))" % (privkey, pubkey_x, pubkey_y))
	print()

