from __future__ import annotations

from enum import Enum

from refinery.lib import chunks
from refinery.lib.crypto import (
    BlockCipher,
    BlockCipherFactory,
    BufferType,
    CipherInterface,
    CipherMode,
    rotl32,
)
from refinery.lib.types import Param
from refinery.units.crypto.cipher import Arg, StandardBlockCipherUnit


class SBOX(tuple, Enum):
    """
    Substitution tables for GOST. Every member is a flat tuple of 128 nibble values, laid
    out as eight consecutive rows of 16 entries each. Rows 2*i and 2*i + 1 substitute the
    low and high nibble of byte i, respectively (see `expand`).
    """

    # The SBOX used by the Central Bank of Russia.
    CBR = (
        0x4, 0xA, 0x9, 0x2, 0xD, 0x8, 0x0, 0xE, 0x6, 0xB, 0x1, 0xC, 0x7, 0xF, 0x5, 0x3,
        0xE, 0xB, 0x4, 0xC, 0x6, 0xD, 0xF, 0xA, 0x2, 0x3, 0x8, 0x1, 0x0, 0x7, 0x5, 0x9,
        0x5, 0x8, 0x1, 0xD, 0xA, 0x3, 0x4, 0x2, 0xE, 0xF, 0xC, 0x7, 0x6, 0x0, 0x9, 0xB,
        0x7, 0xD, 0xA, 0x1, 0x0, 0x8, 0x9, 0xF, 0xE, 0x4, 0x6, 0xC, 0xB, 0x2, 0x5, 0x3,
        0x6, 0xC, 0x7, 0x1, 0x5, 0xF, 0xD, 0x8, 0x4, 0xA, 0x9, 0xE, 0x0, 0x3, 0xB, 0x2,
        0x4, 0xB, 0xA, 0x0, 0x7, 0x2, 0x1, 0xD, 0x3, 0x6, 0x8, 0x5, 0x9, 0xC, 0xF, 0xE,
        0xD, 0xB, 0x4, 0x1, 0x3, 0xF, 0x5, 0x9, 0x0, 0xA, 0xE, 0x7, 0x6, 0x8, 0x2, 0xC,
        0x1, 0xF, 0xD, 0x0, 0x5, 0x7, 0xA, 0x4, 0x9, 0x2, 0x3, 0xE, 0x6, 0xB, 0x8, 0xC,
    )
    # The SBOX from R 34.12.2015.
    R34 = (
        0xC, 0x4, 0x6, 0x2, 0xA, 0x5, 0xB, 0x9, 0xE, 0x8, 0xD, 0x7, 0x0, 0x3, 0xF, 0x1,
        0x6, 0x8, 0x2, 0x3, 0x9, 0xA, 0x5, 0xC, 0x1, 0xE, 0x4, 0x7, 0xB, 0xD, 0x0, 0xF,
        0xB, 0x3, 0x5, 0x8, 0x2, 0xF, 0xA, 0xD, 0xE, 0x1, 0x7, 0x4, 0xC, 0x9, 0x6, 0x0,
        0xC, 0x8, 0x2, 0x1, 0xD, 0x4, 0xF, 0x6, 0x7, 0x0, 0xA, 0x5, 0x3, 0xE, 0x9, 0xB,
        0x7, 0xF, 0x5, 0xA, 0x8, 0x1, 0x6, 0xD, 0x0, 0x9, 0x3, 0xE, 0xB, 0x4, 0x2, 0xC,
        0x5, 0xD, 0xF, 0x6, 0x9, 0x2, 0xC, 0xA, 0xB, 0x7, 0x8, 0x1, 0x4, 0x3, 0xE, 0x0,
        0x8, 0xE, 0x2, 0x5, 0x6, 0x9, 0x1, 0xC, 0xF, 0x4, 0xB, 0x0, 0xD, 0xA, 0x3, 0x7,
        0x1, 0x7, 0xE, 0xD, 0x0, 0x5, 0x8, 0x3, 0x4, 0xF, 0xA, 0x6, 0x9, 0xC, 0xB, 0x2,
    )

    def expand(self):
        """
        Build four 256-entry lookup tables from this SBOX.

        Table k maps a byte value b to a 32-bit word:
        - The low nibble of b is substituted through row 2*k.
        - The high nibble of b is substituted through row 2*k + 1.
        - The resulting byte is rotated left by (8*k + 11) % 32 bits.

        This equals placing the byte at bits 8*k..8*k+7 and then rotating left by 11. XOR
        of the four table lookups therefore yields the substituted and rotated word.
        Returns a list of four lists of integers.
        """
        table = self.value
        expanded = []
        for k in range(4):
            low_row = table[32 * k:32 * k + 16]
            high_row = table[32 * k + 16:32 * k + 32]
            shift = (8 * k + 11) % 32
            expanded.append([
                rotl32(low_row[b & 0xF] | high_row[b >> 4] << 4, shift)
                for b in range(0x100)
            ])
        return expanded


class GOST(BlockCipher):
    """
    The GOST block cipher on 8-byte blocks with a 32-byte key.

    Each block is split into two 32-bit halves and run through 32 Feistel rounds. The
    round function adds a subkey modulo 2**32 and passes every byte of the sum through
    the lookup tables produced by `SBOX.expand`. The results of the four lookups are
    combined with XOR.
    """

    _key_data: list[int]

    block_size = 8
    key_size = frozenset({32})

    def __init__(self, key: BufferType, mode: CipherMode | None, swap: bool = False, sbox: SBOX = SBOX.R34):
        # The key setter reads self.swap, so it has to be assigned before the base
        # initializer runs (which presumably assigns the key).
        self.swap = swap
        T1, T2, T3, T4 = sbox.expand()

        def F(A: int, K: int, swap: bool = False):
            # Round function:
            # - add the subkey modulo 2**32;
            # - look up each byte of the sum (least significant first) in its table;
            # - fold the four lookups together with XOR.
            # The swap argument is accepted but ignored.
            b1, b2, b3, b4 = ((A + K) & 0xFFFFFFFF).to_bytes(4, 'little')
            return T1[b1] ^ T2[b2] ^ T3[b3] ^ T4[b4]

        self.F = F
        super().__init__(key, mode)

    def _feistel(self, block, subkeys) -> BufferType:
        """
        Apply one Feistel round per entry of `subkeys` to `block`. The halves are
        exchanged after every round. The final exchange is undone when packing, so the
        last round effectively does not swap.
        """
        A, B = chunks.unpack(block, 4, self.swap)
        F = self.F
        for k in subkeys:
            A, B = B ^ F(A, k), A
        return chunks.pack((B, A), 4, self.swap)

    def block_decrypt(self, block) -> BufferType:
        # Decryption schedule: key words 0..7 once, then 7..0 three times.
        K = self._key_data
        subkeys = [K[i] for i in range(8)]
        subkeys.extend(K[i % 8] for i in reversed(range(24)))
        return self._feistel(block, subkeys)

    def block_encrypt(self, block) -> BufferType:
        # Encryption schedule: key words 0..7 three times, then 7..0 once.
        K = self._key_data
        subkeys = [K[i % 8] for i in range(24)]
        subkeys.extend(K[i] for i in reversed(range(8)))
        return self._feistel(block, subkeys)

    @property
    def key(self):
        # The key as the list of 4-byte words produced by the setter.
        return self._key_data

    @key.setter
    def key(self, key: bytes):
        # Split the key into 4-byte words; self.swap selects the byte order used by chunks.unpack.
        self._key_data = chunks.unpack(key, 4, self.swap)


class gost(StandardBlockCipherUnit, cipher=BlockCipherFactory(GOST)):
    """
    GOST encryption and decryption.
    """
    # NOTE: the parameter annotations below (Param/Arg) and their help strings are read at
    # runtime to define the unit's command-line options; they are behavior, not comments.
    def __init__(
        self, key, iv=B'', padding=None, mode=None, raw=False,
        swap: Param[bool, Arg.Switch('-s', help='Decode blocks as big endian rather than little endian.')] = False,
        sbox: Param[str, Arg.Option('-x', choices=SBOX, help=(
            'Choose an SBOX. The default is {default}, which corresponds to the R-34.12.2015 standard. '
            'The other option is CBR, which is the SBOX used by the Central Bank of Russia.'
        ))] = SBOX.R34, **more
    ):
        # Normalize the sbox argument to an SBOX member via Arg.AsOption. It presumably
        # accepts either a member or its name; TODO confirm against Arg.AsOption.
        sbox = Arg.AsOption(sbox, SBOX)
        super().__init__(key, iv=iv, padding=padding, mode=mode, raw=raw, swap=swap, sbox=sbox, **more)

    def _new_cipher(self, **optionals) -> CipherInterface:
        # Pass the unit's swap and sbox arguments on as extra keyword arguments, together
        # with any optionals supplied by the caller. Presumably these reach
        # GOST.__init__ through BlockCipherFactory; verify against the base unit.
        return super()._new_cipher(
            swap=self.args.swap,
            sbox=self.args.sbox,
            **optionals
        )
