from . import modNum as md
import random
from math import floor, ceil, log2
from . import EllipticCurve as elip


class ElGamal:
    """Elliptic-curve ElGamal key holder and decryptor.

    An instance holds the curve prime, a public generator point (called the
    "public key" in this module), a randomly chosen private exponent ``k`` and
    ``h = public_key ** k``. Encryption only needs the published prime, public
    key and h; ``decrypt`` needs the private key, which is why it lives on the
    class.

    Instance attributes use double-underscore names, so Python name-mangles
    them. This discourages, but does not strictly prevent, access from outside
    the class.
    """

    # Shared curve object, queried for the curve parameters
    # (prime, a, b, w and the initial point).
    EC = elip.EllipticCurve(0, 0)

    def __init__(self):
        self.__prime = ElGamal.EC.get_prime()
        tup = ElGamal.EC.get_init_point()
        self.__public_key = elip.EllipticCurve(tup[0], tup[1])
        # The private key is secret material. Draw it from the OS CSPRNG
        # (random.SystemRandom), not the predictable default Mersenne Twister.
        self.__private_key = random.SystemRandom().randrange(1, self.__prime)
        self.__h = self.__public_key ** self.__private_key

    def decrypt(self, encrypted_messages):
        """Decrypt a sequence of ``(c0, c1)`` point pairs back into a string.

        Args:
            encrypted_messages: iterable of ``(c0, c1)`` tuples, one per text
                chunk. It must have been produced with this instance's prime,
                public key and h; other ciphertext will not decode correctly.

        Returns:
            The decoded chunks concatenated in order.
        """

        def _decode(point):
            # Encoding placed the chunk integer m at some x in [w*m, w*(m+1)),
            # so integer division by w recovers m.
            # Trailing "\x00" characters in a chunk do not survive this step,
            # because the byte length is derived from the integer's bit length.
            x = point.get_x().value // ElGamal.EC.get_w()
            length = (x.bit_length() + 7) // 8
            return x.to_bytes(length, byteorder='little').decode()

        pieces = []
        for encrypted_message in encrypted_messages:
            c0 = encrypted_message[0]
            c1 = encrypted_message[1]
            # m = (c0^-1)^k * c1 = (g^s)^-k * (h^s * m), with h = g^k
            m = elip.EllipticCurve.group_inv(c0) ** self.__private_key * c1
            pieces.append(_decode(m))
        return "".join(pieces)

    # DO NOT USE. THIS IS ONLY FOR TESTING. THE VALUE USED FOR S, IN G^S AND H^S, IS NOT RANDOM.
    # USE THE MODULE-LEVEL encrypt FUNCTION INSTEAD.
    def test_encrypt_DO_NOT_USE(self, prime, public_key_g, public_key_h, plain_text, special_key):
        """Deterministic encryption for tests ONLY; never use it for real data.

        The exponent ``s`` in ``g^s`` and ``h^s`` is fixed rather than random,
        so the ciphertext is predictable. ``special_key`` must equal 1234 as an
        extra deterrent. Otherwise an empty string is returned.
        """
        test_encrypt = ""
        if special_key == 1234:
            test_encrypt = ElGamal.__test_encrypt(prime, public_key_g, public_key_h, plain_text)
        return test_encrypt

    # DO NOT USE. THIS IS ONLY FOR TESTING. THE VALUE USED FOR S, IN G^S AND H^S, IS NOT RANDOM.
    @staticmethod
    def __test_encrypt(prime, public_key_g, public_key_h, plain_text):
        """Encrypt ``plain_text`` with a fixed exponent ``s = prime // (2 * w)``.

        Returns:
            List of ``(c0, c1)`` tuples, one per chunk.

        Raises:
            ValueError: if a character cannot fit in a chunk for this curve, or
                no curve point exists in a chunk's search window.
        """

        def _encode(chunk):
            """Map a text chunk to a curve point whose x // w is the chunk integer."""
            message_as_int = int.from_bytes(chunk.encode(), byteorder='little')
            w = ElGamal.EC.get_w()
            curve_prime = ElGamal.EC.get_prime()
            a = ElGamal.EC.get_a()
            b = ElGamal.EC.get_b()
            # Search only x in [w*m, w*(m+1)). decrypt recovers m as x // w, so
            # the window must be exactly w wide; do not widen it.
            # Several candidates may be needed, because only about half of the
            # nonzero elements of a prime field are squares. For example, mod 7,
            # 2 = 3**2, but 3 has no square root. Heuristically each candidate
            # succeeds with probability ~1/2, so with w = 256 the chance that
            # the whole window fails is about 2**-256 (~1e-77).
            for candidate in range(w * message_as_int, w * (message_as_int + 1)):
                x = md.modNum(candidate, curve_prime)
                y = md.modNum.sqrt(x ** 3 + a * x + b)
                if y.value != 0:
                    return elip.EllipticCurve(x.value, y.value)
            raise ValueError(
                f"w is too small at size {w}, must increase it. Note the odds "
                "of this are extraordinarily rare. ~10^-78 level rare."
            )

        def _split(message):
            """Split message into chunks of at most max_bytes UTF-8 bytes each."""
            # Equals floor(log2(p / w)) // 8 - 1, computed exactly in integers.
            # It guarantees w * (m + 1) < p for any chunk integer m, so x never
            # wraps modulo p. Chunks are measured in encoded bytes, not
            # characters, so multi-byte characters cannot overflow the bound.
            max_bytes = ((ElGamal.EC.get_prime() // ElGamal.EC.get_w()).bit_length() - 1) // 8 - 1
            chunks = []
            current = []
            current_bytes = 0
            for ch in message:
                ch_bytes = len(ch.encode())
                if ch_bytes > max_bytes:
                    raise ValueError(
                        f"character {ch!r} needs {ch_bytes} bytes, but chunks "
                        f"hold at most {max_bytes} bytes for this curve"
                    )
                if current_bytes + ch_bytes > max_bytes:
                    chunks.append("".join(current))
                    current, current_bytes = [], 0
                current.append(ch)
                current_bytes += ch_bytes
            if current:
                chunks.append("".join(current))
            return chunks

        w = ElGamal.EC.get_w()
        # Fixed, non-random exponent: acceptable ONLY for reproducible tests.
        s = prime // (2 * w)
        encrypted_messages = []
        for message in _split(plain_text):
            m = _encode(message)
            c0 = public_key_g ** s  # g^s
            c1 = (public_key_h ** s) * m  # h^s * m, with h^s = (g^k)^s
            encrypted_messages.append((c0, c1))
        return encrypted_messages

    def __str__(self):
        """Describe the public parameters; never includes the private key."""
        return f"Field using {self.__prime}, public key {self.__public_key}, and h {self.__h}"

    def __repr__(self):
        """Same public description as ``__str__``."""
        return self.__str__()

    def get_group_prime(self):
        """Return the prime defining the curve's field."""
        return self.__prime

    def get_public_key(self):
        """Return the public generator point g."""
        return self.__public_key

    def get_h(self):
        """Return h = g ** private_key, the published half of the key pair."""
        return self.__h


# Encrypts a plain text message using the published prime, g, and h via the El Gamal Elliptic Curve scheme
# message is broken into pieces whose integer representation is always less than the prime used to define the group

def encrypt(prime, public_key_g, public_key_h, plain_text):
    """Encrypt ``plain_text`` with elliptic-curve ElGamal.

    The text is split into chunks whose encoded integer keeps the curve-point
    search window below the curve prime. Each chunk is mapped to a curve point
    ``m`` and encrypted as ``(g^s, h^s * m)`` with a fresh random ``s``.

    Args:
        prime: prime that bounds the random exponent ``s``.
        public_key_g: published generator point ``g``.
        public_key_h: published point ``h = g^k``.
        plain_text: message to encrypt.

    Returns:
        List of ``(c0, c1)`` tuples, one per chunk, suitable for
        ``ElGamal.decrypt``.

    Raises:
        ValueError: if a character cannot fit in a chunk for this curve, or no
            curve point exists in a chunk's search window (vanishingly rare).
    """

    def _encode(chunk):
        """Map a text chunk to a curve point whose x // w is the chunk integer."""
        message_as_int = int.from_bytes(chunk.encode(), byteorder='little')
        w = ElGamal.EC.get_w()
        curve_prime = ElGamal.EC.get_prime()
        a = ElGamal.EC.get_a()
        b = ElGamal.EC.get_b()
        # Search only x in [w*m, w*(m+1)). The decoder recovers m as x // w, so
        # the window must be exactly w wide; do not widen it.
        # Several candidates may be needed, because only about half of the
        # nonzero elements of a prime field are squares. For example, mod 7,
        # 2 = 3**2, but 3 has no square root. Heuristically each candidate
        # succeeds with probability ~1/2, so with w = 256 the chance that the
        # whole window fails is about 2**-256 (~1e-77).
        for candidate in range(w * message_as_int, w * (message_as_int + 1)):
            x = md.modNum(candidate, curve_prime)
            y = md.modNum.sqrt(x ** 3 + a * x + b)
            if y.value != 0:
                return elip.EllipticCurve(x.value, y.value)
        raise ValueError(
            f"w is too small at size {w}, must increase it, Note the odds of "
            "this are extraordinarily rare. ~10^-78 level rare. Odds of winning "
            "the lottery ~3*10^-9 for comparison"
        )

    def _split(message):
        """Split message into chunks of at most max_bytes UTF-8 bytes each."""
        # Equals floor(log2(p / w)) // 8 - 1, computed exactly in integers.
        # It guarantees w * (m + 1) < p for any chunk integer m, so x never
        # wraps modulo p. Chunks are measured in encoded bytes, not characters,
        # so multi-byte characters cannot overflow the bound. For ASCII text
        # this yields the same chunks as fixed-width character slicing.
        max_bytes = ((ElGamal.EC.get_prime() // ElGamal.EC.get_w()).bit_length() - 1) // 8 - 1
        chunks = []
        current = []
        current_bytes = 0
        for ch in message:
            ch_bytes = len(ch.encode())
            if ch_bytes > max_bytes:
                raise ValueError(
                    f"character {ch!r} needs {ch_bytes} bytes, but chunks "
                    f"hold at most {max_bytes} bytes for this curve"
                )
            if current_bytes + ch_bytes > max_bytes:
                chunks.append("".join(current))
                current, current_bytes = [], 0
            current.append(ch)
            current_bytes += ch_bytes
        if current:
            chunks.append("".join(current))
        return chunks

    # The ephemeral exponent must be unpredictable, so use the OS CSPRNG.
    rng = random.SystemRandom()
    w = ElGamal.EC.get_w()
    encrypted_messages = []
    for message in _split(plain_text):
        m = _encode(message)
        # A fresh random s per chunk is required for security. Reusing s would
        # let anyone relate chunks, since c1_a * inv(c1_b) = m_a * inv(m_b).
        # s itself is not needed for decryption.
        s = rng.randrange(prime // (2 * w), prime // w)
        c0 = public_key_g ** s  # g^s
        c1 = (public_key_h ** s) * m  # h^s * m, with h^s = (g^k)^s
        encrypted_messages.append((c0, c1))

    return encrypted_messages

