'''

Welcome To PYMULTICRYPT

Python module for MULTICRYPT, a new and secure encryption
algorithm made by me. It is a hybrid encryption system (it
combines both symmetric and asymmetric encryption) and is
designed specifically for end-to-end encryption
(for more details on the algorithm, 
read the readme.md on https://github.com/AbhayTr/PyMultiCrypt).

Usage:

Constructor: End2End()

@params

public_key (Optional): Public key to use if you want to reuse an existing key pair (Default: "").
private_key (Optional): Private key to use if you want to reuse an existing key pair (Default: "").
                        Both public_key and private_key must be given; if either one is empty, the key pair
                        is read from key_path (or newly generated if that is not possible).
save (Optional): Should be True/False. Specifies whether the keys have to be stored in a file or not (Default: True).
key_path (Optional): Path and name of the file from which keys are loaded and to which they are stored if save = True
                     (Default: "key_pair.key", relative to the current working directory).
new (Optional): Should be True/False. Specifies whether any existing key pair should be ignored and a new key pair generated (Default: False).

Methods:

1. keys(): Returns Private Key and Public Key in the form of dictionary of the format {"public": %YOUR_PUBLIC_KEY%, "private": %YOUR_PRIVATE_KEY%}.

2. encrypt(): Encrypts the message using MULTICRYPT algorithm.

@params

message (Required): Message to encrypt.
public_key (Required): Public Key of the recipient of the message (for the asymmetric encryption part).

3. decrypt(): Decrypts the encrypted message using MULTICRYPT algorithm.

@params

message (Required): Encrypted message to decrypt.
private_key (Optional): Your private key, required to decrypt any message that was encrypted with the public key
                        linked to that private key (Default: the key that was passed to, loaded by, or
                        generated by the End2End() constructor).

Useful for transmitting data securely between 2 devices on a network.

© Abhay Tripathi

'''

import random
import binascii

class End2End:
    """MULTICRYPT hybrid encryptor/decryptor.

    Each message is masked with a fresh random 256-bit number (the
    "symmetric" part). The mask itself is encrypted with unpadded RSA (the
    "asymmetric" part). Keys and ciphertexts are stored in a "compressed"
    textual form: every decimal digit is replaced by a letter (see
    ``schema``), the two numbers of a key are joined by ``"X"``, and the two
    halves of a ciphertext are joined by ``"K"``.
    """

    # Operating-system CSPRNG. Key material must not come from the default
    # Mersenne Twister generator of ``random``, whose output is predictable.
    _rng = random.SystemRandom()

    def __init__(self, public_key = "", private_key = "", save = True, key_path = "key_pair.key", new = False):
        """Adopt, load, or generate a key pair.

        If both ``public_key`` and ``private_key`` are non-empty, they are
        used as-is and written to ``key_path`` when ``save`` is true.
        Otherwise the pair is read from ``key_path`` (first line: private
        key, second line: public key), unless ``new`` is true. If that file
        is missing, unreadable or incomplete, a fresh pair is generated and
        written to ``key_path`` when ``save`` is true.
        """
        # Small primes used for cheap trial division before Miller-Rabin.
        self.primes_list = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
                            73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151,
                            157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233,
                            239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317,
                            331, 337, 347, 349]
        # Digit <-> letter substitution used by compress_number/deflate_number.
        self.schema = {1: "a", 2: "b", 3: "c", 4: "d", 5: "e", 6: "f", 7: "g", 8: "h", 9: "i", 0: "j"}
        self.reverse_schema = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 0}
        if public_key == "" or private_key == "":
            stored_keys = None if new else self._load_keys(key_path)
            if stored_keys is not None:
                self.private_key, self.public_key = stored_keys
            else:
                fresh_keys = self.get_keys()
                self.private_key = fresh_keys["private"]
                self.public_key = fresh_keys["public"]
                if save:
                    self._save_keys(key_path)
        else:
            self.public_key = public_key
            self.private_key = private_key
            if save:
                self._save_keys(key_path)

    def _load_keys(self, key_path):
        """Read ``(private_key, public_key)`` from ``key_path``.

        Returns None when the file cannot be opened or decoded, or does not
        contain two non-empty lines, so the caller can generate a new pair.
        """
        try:
            with open(key_path, "r") as key_file:
                lines = key_file.readlines()
        except (OSError, ValueError, TypeError):
            # Missing/unreadable file, undecodable text, or an invalid path
            # value (e.g. None): treat as "no stored keys".
            return None
        if len(lines) < 2:
            return None
        # readlines() keeps the trailing newline; strip it so keys() returns
        # exactly what was saved.
        private_key = lines[0].strip()
        public_key = lines[1].strip()
        if not private_key or not public_key:
            return None
        return private_key, public_key

    def _save_keys(self, key_path):
        """Write the private key, a newline, then the public key to ``key_path``."""
        with open(key_path, "w") as key_file:
            key_file.write(self.private_key + "\n")
            key_file.write(self.public_key)

    def n_bits_prime(self, n):
        """Return a random n-bit candidate with no factor in ``primes_list``.

        The result is drawn from ``[2**(n-1) + 1, 2**n - 1)`` using the OS
        CSPRNG. It is only trial-divided and not proven prime; use
        :meth:`check_prime_strength` for a probabilistic primality test.
        """
        low = 2 ** (n - 1) + 1
        high = 2 ** n - 1
        while True:
            candidate = self._rng.randrange(low, high)
            # Reject candidates divisible by a small prime (other than the
            # candidate being that prime itself).
            if all(candidate % divisor != 0 or divisor * divisor > candidate
                   for divisor in self.primes_list):
                return candidate

    def check_prime_strength(self, prime_number):
        """Miller-Rabin test with 20 random bases.

        Returns False if ``prime_number`` is certainly composite (or < 2),
        True if it is probably prime.
        """
        if prime_number < 4:
            return prime_number in (2, 3)
        if prime_number % 2 == 0:
            return False
        # Write prime_number - 1 as 2**max_divisions_two * ec with ec odd.
        max_divisions_two = 0
        ec = prime_number - 1
        while ec % 2 == 0:
            ec >>= 1
            max_divisions_two += 1
        minus_one = prime_number - 1

        def trial_composite(round_tester):
            # True when round_tester witnesses that prime_number is composite.
            value = pow(round_tester, ec, prime_number)
            if value == 1 or value == minus_one:
                return False
            for _ in range(max_divisions_two - 1):
                # Square instead of recomputing pow(a, 2**i * ec, p).
                value = value * value % prime_number
                if value == minus_one:
                    return False
            return True

        for _ in range(20):
            round_tester = self._rng.randrange(2, prime_number)
            if trial_composite(round_tester):
                return False
        return True

    def is_coprime(self, number1, number2):
        """Return True when gcd(number1, number2) == 1 (Euclid's algorithm)."""
        while number2 != 0:
            number1, number2 = number2, number1 % number2
        return number1 == 1

    def mod_inverse(self, number1, number2):
        """Return the inverse of ``number1`` modulo ``number2``.

        Uses the iterative extended Euclidean algorithm (no recursion-depth
        limit). The result is only meaningful when the two numbers are coprime.
        """
        old_remainder, remainder = number1, number2
        old_coefficient, coefficient = 1, 0
        while remainder != 0:
            quotient = old_remainder // remainder
            old_remainder, remainder = remainder, old_remainder - quotient * remainder
            old_coefficient, coefficient = coefficient, old_coefficient - quotient * coefficient
        inv = old_coefficient
        # The Bezout coefficient may be negative; shift it into range.
        if inv < 1:
            inv += number2
        return inv

    def compress_string(self, string):
        """Encode ``string`` as UTF-8 and read the bytes as a big-endian integer.

        The empty string maps to 0.
        """
        return int.from_bytes(string.encode("utf-8"), "big")

    def deflate_string(self, number):
        """Inverse of :meth:`compress_string`.

        Unlike a hex round-trip, this keeps leading bytes below 0x10 (e.g. a
        message starting with a newline). Leading NUL bytes cannot be
        recovered. 0 decodes to the empty string.

        Raises ValueError for negative numbers (e.g. tampered ciphertext).
        """
        if number < 0:
            raise ValueError("cannot decode a negative number")
        raw = number.to_bytes((number.bit_length() + 7) // 8, "big")
        return raw.decode("utf-8")

    def compress_number(self, number):
        """Replace each digit of the string ``number`` by its ``schema`` letter.

        Non-digit characters (such as the "X"/"K" separators) pass through.
        """
        pieces = []
        for digit in number:
            try:
                pieces.append(self.schema[int(digit)])
            except (ValueError, KeyError):
                pieces.append(digit)
        return "".join(pieces)

    def deflate_number(self, number_string):
        """Inverse of :meth:`compress_number`: map ``schema`` letters back to digits."""
        return "".join(str(self.reverse_schema.get(character, character))
                       for character in number_string)

    def get_keys(self):
        """Generate a new RSA key pair in compressed form.

        Returns ``{"private": ..., "public": ...}``. The public key encodes
        ``"<n>X<e>"`` and the private key encodes ``"<d>X<n>"``, where n is
        the product of two distinct 1024-bit probable primes.
        """
        primes = []
        while len(primes) < 2:
            candidate = self.n_bits_prime(1024)
            # p == q would make n a perfect square and break RSA.
            if candidate not in primes and self.check_prime_strength(candidate):
                primes.append(candidate)
        prime_key_1, prime_key_2 = primes
        public_key_number = prime_key_1 * prime_key_2
        phi_n = (prime_key_1 - 1) * (prime_key_2 - 1)
        # Smallest exponent coprime to phi(n). Kept (rather than switching to
        # 65537) so keys stay usable by older versions that compute m ** e
        # without modular exponentiation.
        e = 0
        for number in range(2, phi_n):
            if self.is_coprime(phi_n, number):
                e = number
                break
        public_key = str(public_key_number) + "X" + str(e)
        private_key = str(self.mod_inverse(e, phi_n)) + "X" + str(public_key_number)
        return {"private": self.compress_number(private_key), "public": self.compress_number(public_key)}

    def keys(self):
        """Return ``{"public": public_key, "private": private_key}``."""
        return {"public": self.public_key, "private": self.private_key}

    def rsa_encrypt(self, message, public_key):
        """Encrypt ``message`` with unpadded RSA under a compressed ``"nXe"`` key.

        Each character's code point is written in decimal, joined with
        "300", and the result is treated as one integer. This encoding is
        only unambiguous for characters whose code points never produce
        "300" (it is used here for decimal-digit strings).
        """
        public_key = self.deflate_number(public_key)
        separator_position = public_key.index("X")
        e = int(public_key[separator_position + 1:])
        public_key_number = int(public_key[:separator_position])
        encoded_message = "300".join(str(ord(character)) for character in message)
        # Modular exponentiation instead of the much slower m ** e % n.
        return str(pow(int(encoded_message), e, public_key_number))

    def rsa_decrypt(self, message, private_key):
        """Invert :meth:`rsa_encrypt` using a compressed ``"dXn"`` private key."""
        private_key = self.deflate_number(private_key)
        separator_position = private_key.index("X")
        public_key_number = int(private_key[separator_position + 1:])
        private_key_number = int(private_key[:separator_position])
        decoded = str(pow(int(message), private_key_number, public_key_number))
        return "".join(chr(int(code)) for code in decoded.split("300"))

    def encrypt(self, message, public_key):
        """Encrypt ``message`` for the holder of ``public_key``.

        Returns ``"<compressed masked message>K<compressed RSA-encrypted mask>"``.

        NOTE(review): the "symmetric" step is integer addition of a mask
        below 2**256. Only about the last 32 bytes of the UTF-8 message are
        hidden; earlier bytes can be read from the masked number (up to a
        carry). The RSA step also uses no padding. Fixing either requires a
        new wire format, so both are left as-is.
        """
        key = self.n_bits_prime(256)
        masked_message = self.compress_string(message) + key
        encrypted_key = self.rsa_encrypt(str(key), public_key)
        return self.compress_number(str(masked_message)) + "K" + self.compress_number(encrypted_key)

    def decrypt(self, message, private_key = ""):
        """Decrypt a ciphertext produced by :meth:`encrypt`.

        ``private_key`` defaults to this instance's private key. Raises
        ValueError if ``message`` has no "K" separator.
        """
        if private_key == "":
            private_key = self.private_key
        separator_position = message.index("K")
        masked_message = self.deflate_number(message[:separator_position])
        encrypted_key = self.deflate_number(message[separator_position + 1:])
        key = self.rsa_decrypt(encrypted_key, private_key)
        return self.deflate_string(int(masked_message) - int(key))
