from unittest import TestCase
from joserfc import jwe
from joserfc.jwk import OctKey, RSAKey
from joserfc.registry import HeaderParameter
from joserfc.errors import (
    InvalidKeyTypeError,
    InvalidKeyLengthError,
    DecodeError,
)
from tests.base import load_key


class TestJWEErrors(TestCase):
    """Error paths of the compact JWE ``encrypt_compact`` / ``decrypt_compact`` API."""

    def test_dir_with_invalid_key_type(self):
        """Each key-management algorithm rejects a key of the wrong type."""
        ec_key = load_key("ec-p256-private.pem")
        pbes2_allowed = {"algorithms": ["PBES2-HS256+A128KW", "A128CBC-HS256"]}
        # (alg, key that does not fit alg, extra keyword arguments)
        cases = [
            ("dir", ec_key, {}),
            ("A128KW", ec_key, {}),
            ("ECDH-ES+A128KW", "secret", {}),
            # PBES2 is opted into explicitly via ``algorithms``; the other
            # cases rely on the default algorithm set.
            ("PBES2-HS256+A128KW", ec_key, pbes2_allowed),
        ]
        for alg, wrong_key, extra in cases:
            header = {"alg": alg, "enc": "A128CBC-HS256"}
            with self.assertRaises(InvalidKeyTypeError):
                jwe.encrypt_compact(header, b"i", wrong_key, **extra)

    def test_rsa_with_invalid_key_type(self):
        """RSA-OAEP refuses an EC key."""
        ec_key = load_key("ec-p256-private.pem")
        header = {"alg": "RSA-OAEP", "enc": "A128CBC-HS256"}
        with self.assertRaises(InvalidKeyTypeError):
            jwe.encrypt_compact(header, b"i", ec_key)

    def test_A128KW_unwrap_error(self):
        """Decrypting with a different A128KW key fails to unwrap the CEK."""
        sender_key = OctKey.generate_key(128)
        other_key = OctKey.generate_key(128)
        header = {"alg": "A128KW", "enc": "A128CBC-HS256"}
        token = jwe.encrypt_compact(header, b"i", sender_key)
        with self.assertRaises(DecodeError):
            jwe.decrypt_compact(token, other_key)

    def test_invalid_alg(self):
        """An unknown ``alg`` value is rejected with ValueError."""
        header = {"alg": "INVALID", "enc": "A128CBC-HS256"}
        with self.assertRaises(ValueError):
            jwe.encrypt_compact(header, b"i", "secret")

    def test_invalid_key_length(self):
        """Keys that are too short for the chosen algorithm are rejected."""
        enc = "A128CBC-HS256"
        # The 6-byte "secret" string is expected to be too short for both.
        for alg in ("dir", "A128KW"):
            with self.assertRaises(InvalidKeyLengthError):
                jwe.encrypt_compact({"alg": alg, "enc": enc}, b"i", "secret")

        rsa_header = {"alg": "RSA-OAEP", "enc": enc}
        # A 1024-bit RSA key is expected to be rejected as too small.
        small_rsa = RSAKey.generate_key(1024)
        with self.assertRaises(InvalidKeyLengthError):
            jwe.encrypt_compact(rsa_header, b"i", small_rsa)

    def test_extra_header(self):
        """Unregistered header parameters fail unless allowed by the registry."""
        oct_key = OctKey.generate_key(256)
        header = {"alg": "dir", "enc": "A128CBC-HS256", "custom": "hi"}

        # Default registry: the unknown "custom" parameter is refused.
        with self.assertRaises(ValueError):
            jwe.encrypt_compact(header, b"i", oct_key)

        # Disabling strict header checking lets it through.
        lenient = jwe.JWERegistry(strict_check_header=False)
        jwe.encrypt_compact(header, b"i", oct_key, registry=lenient)

        # Registering "custom" as a string parameter also lets it through.
        extended = jwe.JWERegistry(
            header_registry={"custom": HeaderParameter("Custom", "str")},
        )
        jwe.encrypt_compact(header, b"i", oct_key, registry=extended)

    def test_strict_check_header_with_more_header_registry(self):
        """Strict header checking also applies to ECDH-ES headers."""
        ec_key = load_key("ec-p256-private.pem")
        header = {"alg": "ECDH-ES", "enc": "A128CBC-HS256", "custom": "hi"}

        with self.assertRaises(ValueError):
            jwe.encrypt_compact(header, b"i", ec_key)

        lenient = jwe.JWERegistry(strict_check_header=False)
        jwe.encrypt_compact(header, b"i", ec_key, registry=lenient)
