# encoding: utf-8

# This file is part of CycloneDX Python Lib
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OWASP Foundation. All Rights Reserved.

import re
from enum import Enum
from typing import List, Optional, Tuple, Union
from urllib.parse import ParseResult, urlparse

"""
These classes model the data permitted by the CycloneDX extension schema for Vulnerabilities (version 1.0).

.. note::
    See the CycloneDX Schema extension definition https://cyclonedx.org/ext/vulnerability/.
"""


class VulnerabilitySourceType(Enum):
    """
    Enumeration of the permissible scoring source types (methods) for a Vulnerability.

    .. note::
        See `scoreSourceType` in https://github.com/CycloneDX/specification/blob/master/schema/ext/vulnerability-1.0.xsd
    """
    CVSS_V2 = 'CVSSv2'
    CVSS_V3 = 'CVSSv3'
    OWASP = 'OWASP Risk'
    OPEN_FAIR = 'Open FAIR'
    OTHER = 'Other'

    @staticmethod
    def get_from_vector(vector: str) -> 'VulnerabilitySourceType':
        """
        Infer the scoring scheme from the prefix of an attack vector string.

        Vector strings frequently carry their scheme as a prefix: for instance
        __CVSS:3.0/AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N__ denotes the vector
        __AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N__ under the __CVSS 3__ scheme.

        Args:
            vector: The (possibly prefixed) vector string.

        Returns:
            The matching `VulnerabilitySourceType`; `VulnerabilitySourceType.OTHER` when no known prefix is found.
        """
        # Checked in order; the first prefix that matches wins.
        known_prefixes = (
            ('CVSS:3.', VulnerabilitySourceType.CVSS_V3),
            ('CVSS:2.', VulnerabilitySourceType.CVSS_V2),
            ('OWASP', VulnerabilitySourceType.OWASP),
        )
        for prefix, source_type in known_prefixes:
            if vector.startswith(prefix):
                return source_type
        return VulnerabilitySourceType.OTHER

    def get_localised_vector(self, vector: str) -> str:
        """
        Strip this scheme's prefix (if present) from `vector`, leaving only the bare vector.

        .. Note::
            Only CVSS 3.x, CVSS 2.x and OWASP prefixes are recognised; every other source type returns the input
            unchanged.

        Args:
            vector: The (possibly prefixed) vector string.

        Returns:
            `vector` with this scheme's prefix (and one following `/`, if any) removed, as a `str`.
        """
        if self is VulnerabilitySourceType.CVSS_V3:
            prefix, pattern = 'CVSS:3.', r'^CVSS:3\.\d/?'
        elif self is VulnerabilitySourceType.CVSS_V2:
            prefix, pattern = 'CVSS:2.', r'^CVSS:2\.\d/?'
        elif self is VulnerabilitySourceType.OWASP:
            prefix, pattern = 'OWASP', r'^OWASP/?'
        else:
            return vector

        # A vector that does not carry this scheme's prefix is returned untouched.
        if not vector.startswith(prefix):
            return vector
        return re.sub(pattern, '', vector)


class VulnerabilitySeverity(Enum):
    """
    Enum object that defines the permissible severities for a Vulnerability.

    .. note::
        See `severityType` in https://github.com/CycloneDX/specification/blob/master/schema/ext/vulnerability-1.0.xsd
    """
    NONE = 'None'
    LOW = 'Low'
    MEDIUM = 'Medium'
    HIGH = 'High'
    CRITICAL = 'Critical'
    UNKNOWN = 'Unknown'

    @staticmethod
    def get_from_cvss_scores(scores: Union[Tuple[float, ...], List[float], float, None]) -> 'VulnerabilitySeverity':
        """
        Derives the Severity of a Vulnerability from its declared CVSS scores.

        The highest supplied score is mapped as follows: ``>= 9.0`` Critical, ``>= 7.0`` High, ``>= 4.0`` Medium,
        ``> 0.0`` Low, anything else None.

        Args:
            scores: A single CVSS score, or a `tuple`/`list` of CVSS scores. The CVSS scoring system allows for up to
                three separate scores.

        Returns:
            Always returns an instance of `VulnerabilitySeverity`. `VulnerabilitySeverity.UNKNOWN` is returned when no
            score is available (`None` or an empty collection).
        """
        if scores is None:
            return VulnerabilitySeverity.UNKNOWN

        max_cvss_score: float
        if isinstance(scores, (tuple, list)):
            if not scores:
                # An empty collection carries no score information; treat it like `None` instead of letting
                # `max(())` raise ValueError.
                return VulnerabilitySeverity.UNKNOWN
            max_cvss_score = max(scores)
        else:
            max_cvss_score = float(scores)

        if max_cvss_score >= 9.0:
            return VulnerabilitySeverity.CRITICAL
        elif max_cvss_score >= 7.0:
            return VulnerabilitySeverity.HIGH
        elif max_cvss_score >= 4.0:
            return VulnerabilitySeverity.MEDIUM
        elif max_cvss_score > 0.0:
            return VulnerabilitySeverity.LOW
        else:
            return VulnerabilitySeverity.NONE


class VulnerabilityRating:
    """
    Class that models the `scoreType` complex element in the Vulnerability extension schema.

    .. note::
        See `scoreType` in https://github.com/CycloneDX/specification/blob/master/schema/ext/vulnerability-1.0.xsd
    """

    def __init__(self, score_base: Optional[float] = None, score_impact: Optional[float] = None,
                 score_exploitability: Optional[float] = None, severity: Optional[VulnerabilitySeverity] = None,
                 method: Optional[VulnerabilitySourceType] = None, vector: Optional[str] = None) -> None:
        """
        Args:
            score_base: Base score, or `None` if not declared.
            score_impact: Impact score, or `None` if not declared.
            score_exploitability: Exploitability score, or `None` if not declared.
            severity: Declared severity, or `None`.
            method: Scoring method / source type, or `None`.
            vector: Attack vector string, or `None`. When `method` is given, the vector is passed through
                `method.get_localised_vector` so that a scheme prefix recognised by that method is removed.
        """
        self._score_base = score_base
        self._score_impact = score_impact
        self._score_exploitability = score_exploitability
        self._severity = severity
        self._method = method
        self._vector: Optional[str]
        if vector is None:
            # Keep "no vector" as None; previously this stored the literal string 'None'.
            self._vector = None
        elif self._method and vector:
            self._vector = self._method.get_localised_vector(vector=vector)
        else:
            self._vector = str(vector)

    def get_base_score(self) -> Optional[float]:
        """
        Get the base score of this VulnerabilityRating.

        Returns:
           Declared base score of this VulnerabilityRating as `float`, or `None`.
        """
        return self._score_base

    def get_impact_score(self) -> Optional[float]:
        """
        Get the impact score of this VulnerabilityRating.

        Returns:
           Declared impact score of this VulnerabilityRating as `float`, or `None`.
        """
        return self._score_impact

    def get_exploitability_score(self) -> Optional[float]:
        """
        Get the exploitability score of this VulnerabilityRating.

        Returns:
           Declared exploitability score of this VulnerabilityRating as `float`, or `None`.
        """
        return self._score_exploitability

    def get_severity(self) -> Optional[VulnerabilitySeverity]:
        """
        Get the severity of this VulnerabilityRating.

        Returns:
           Declared severity of this VulnerabilityRating as `VulnerabilitySeverity` or `None`.
        """
        return self._severity

    def get_method(self) -> Optional[VulnerabilitySourceType]:
        """
        Get the source method of this VulnerabilityRating.

        Returns:
           Declared source method of this VulnerabilityRating as `VulnerabilitySourceType` or `None`.
        """
        return self._method

    def get_vector(self) -> Optional[str]:
        """
        Get the attack vector of this VulnerabilityRating.

        Returns:
           The vector as `str` (with the method's scheme prefix removed when a method was supplied), or `None` if
           no vector was supplied.
        """
        return self._vector

    def has_score(self) -> bool:
        """
        Returns:
            `True` if at least one of the base, impact or exploitability scores is not `None`.
        """
        return (None, None, None) != (self._score_base, self._score_impact, self._score_exploitability)


class Vulnerability:
    """
    Models a single vulnerability entry: ``<xs:complexType name="vulnerability">`` in the CycloneDX Vulnerability
    extension schema.

    List-valued arguments that are `None` or empty are replaced with empty lists. A truthy `source_url` is parsed
    with `urllib.parse.urlparse`; otherwise it is stored as `None`.
    """

    def __init__(self, id: str, source_name: Optional[str], source_url: Optional[str],
                 ratings: Optional[List[VulnerabilityRating]], cwes: Optional[List[int]], description: Optional[str],
                 recommendations: Optional[List[str]], advisories: Optional[List[str]]) -> None:
        self._id = id
        self._source_name = source_name
        self._description = description

        self._source_url: Optional[ParseResult] = None
        if source_url:
            self._source_url = urlparse(source_url)

        # `value or []` keeps a non-empty caller list (the same object) and substitutes a fresh list otherwise.
        self._ratings: List[VulnerabilityRating] = ratings or []
        self._cwes: List[int] = cwes or []
        self._recommendations: List[str] = recommendations or []
        self._advisories: List[str] = advisories or []

    def get_id(self) -> str:
        """Return the identifier this vulnerability was created with."""
        return self._id

    def get_source_name(self) -> Optional[str]:
        """Return the source name, or `None` if none was given."""
        return self._source_name

    def get_source_url(self) -> Optional[ParseResult]:
        """Return the parsed source URL, or `None` if none was given."""
        return self._source_url

    def get_ratings(self) -> List[VulnerabilityRating]:
        """Return the ratings; a new empty list when there are none."""
        return self._ratings if self.has_ratings() else []

    def get_cwes(self) -> List[int]:
        """Return the CWE entries; a new empty list when there are none."""
        return self._cwes if self.has_cwes() else []

    def get_description(self) -> Optional[str]:
        """Return the description, or `None` if none was given."""
        return self._description

    def get_recommendations(self) -> List[str]:
        """Return the recommendations; a new empty list when there are none."""
        return self._recommendations if self.has_recommendations() else []

    def get_advisories(self) -> List[str]:
        """Return the advisories; a new empty list when there are none."""
        return self._advisories if self.has_advisories() else []

    def has_ratings(self) -> bool:
        """`True` when at least one rating is present."""
        return bool(self._ratings)

    def has_cwes(self) -> bool:
        """`True` when at least one CWE entry is present."""
        return bool(self._cwes)

    def has_recommendations(self) -> bool:
        """`True` when at least one recommendation is present."""
        return bool(self._recommendations)

    def has_advisories(self) -> bool:
        """`True` when at least one advisory is present."""
        return bool(self._advisories)
