import base64
import logging
import os
import posixpath
import requests
import uuid
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor

from mlflow_databricks_artifacts.azure.client import (
    put_adls_file_creation,
    patch_adls_file_upload,
    patch_adls_flush,
    put_block,
    put_block_list,
)

from mlflow.entities import FileInfo
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import (
    INVALID_PARAMETER_VALUE,
    INTERNAL_ERROR,
)

from mlflow_databricks_artifacts.protos.patched_databricks_artifacts_pb2 import (
    DatabricksMlflowArtifactsService,
    GetCredentialsForWrite,
    GetCredentialsForRead,
    ArtifactCredentialType,
)
from mlflow.tracking import get_tracking_uri
from mlflow.protos.service_pb2 import MlflowService, GetRun, ListArtifacts
from mlflow.store.artifact.artifact_repo import ArtifactRepository
from mlflow_databricks_artifacts.utils import chunk_list
from mlflow.utils.databricks_utils import get_databricks_host_creds
from mlflow_databricks_artifacts.utils.file_utils import (
    download_file_using_http_uri,
    relative_path_to_artifact_path,
    yield_file_in_chunks,
)
from mlflow_databricks_artifacts.utils.proto_json_utils import message_to_json
from mlflow_databricks_artifacts.utils import rest_utils
from mlflow_databricks_artifacts.utils.rest_utils import (
    call_endpoint,
    extract_api_info_for_service,
    _REST_API_PATH_PREFIX,
    augmented_raise_for_status,
)
from mlflow_databricks_artifacts.utils.uri import (
    extract_and_normalize_path,
    get_databricks_profile_uri_from_artifact_uri,
    is_databricks_acled_artifacts_uri,
    is_valid_dbfs_uri,
    remove_databricks_profile_info_from_artifact_uri,
)

_logger = logging.getLogger(__name__)

# Upper bound on the size of a single uploaded block: stage_block allows at most 100 MB.
_AZURE_MAX_BLOCK_CHUNK_SIZE = 100 * 1000 * 1000
# Chunk size handed to the HTTP download helper.
_DOWNLOAD_CHUNK_SIZE = 100 * 1000 * 1000
# Upper bound on the number of artifact paths sent in one credentials request.
_MAX_CREDENTIALS_REQUEST_SIZE = 2 * 1000

# Per-service REST API info, looked up as `_SERVICE_AND_METHOD_TO_INFO[service][api]`.
_SERVICE_AND_METHOD_TO_INFO = {
    rest_service: extract_api_info_for_service(rest_service, _REST_API_PATH_PREFIX)
    for rest_service in (MlflowService, DatabricksMlflowArtifactsService)
}

# Parallelism limits applied while uploading/downloading artifacts.
# Hard cap on the number of worker threads, regardless of CPU count.
_NUM_MAX_THREADS = 8
# Number of worker threads allotted to each CPU.
_NUM_MAX_THREADS_PER_CPU = 2
# Sanity checks: the per-CPU allotment must be positive and must fit under the hard cap.
assert _NUM_MAX_THREADS_PER_CPU <= _NUM_MAX_THREADS
assert 0 < _NUM_MAX_THREADS_PER_CPU
# CPU count assumed when os.cpu_count() cannot determine the real one; chosen so that the
# resulting thread count equals the hard cap.
_NUM_DEFAULT_CPUS = _NUM_MAX_THREADS // _NUM_MAX_THREADS_PER_CPU

class DatabricksArtifactRepository(ArtifactRepository):
    """
    Performs storage operations on artifacts in the access-controlled
    `dbfs:/databricks/mlflow-tracking` location.

    Signed access URIs for S3 / Azure Blob Storage are fetched from the MLflow service and used to
    read and write files from/to this location.

    The artifact_uri is expected to be of the form
    dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/
    """

    def __init__(self, artifact_uri):
        """
        Validate `artifact_uri`, set up the upload thread pool and resolve the location of
        this repository relative to the artifact root of the owning MLflow Run.

        :param artifact_uri: DBFS URI of the form ``dbfs:/<path>`` or
                             ``dbfs://profile@databricks/<path>`` pointing into the
                             access-controlled ``databricks/mlflow-tracking`` location.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` if the URI is not a valid
                                 DBFS URI or does not have the expected path prefix.
        """
        if not is_valid_dbfs_uri(artifact_uri):
            invalid_dbfs_uri_message = (
                "DBFS URI must be of the form dbfs:/<path> or "
                "dbfs://profile@databricks/<path>"
            )
            raise MlflowException(
                message=invalid_dbfs_uri_message,
                error_code=INVALID_PARAMETER_VALUE,
            )
        if not is_databricks_acled_artifacts_uri(artifact_uri):
            unexpected_prefix_message = (
                "Artifact URI incorrect. Expected path prefix to be"
                " databricks/mlflow-tracking/path/to/artifact/.."
            )
            raise MlflowException(
                message=unexpected_prefix_message,
                error_code=INVALID_PARAMETER_VALUE,
            )

        # Artifact operations work on a plain dbfs:/ path, so the Databricks profile info
        # is removed from the URI handed to the base class.
        profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
        super().__init__(profile_free_uri)
        self.thread_pool = ThreadPoolExecutor(max_workers=self.max_workers)

        # Prefer the profile embedded in the URI; otherwise use the current tracking URI.
        profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
        if not profile_uri:
            profile_uri = get_tracking_uri()
        self.databricks_profile_uri = profile_uri

        self.run_id = self._extract_run_id(self.artifact_uri)

        # Look up the artifact root of the run and express this repository's root relative
        # to it. Every operation on this repository is performed relative to that location.
        repo_root_path = extract_and_normalize_path(artifact_uri)
        run_root_path = extract_and_normalize_path(
            self._get_run_artifact_root(self.run_id)
        )
        relative_root = posixpath.relpath(path=repo_root_path, start=run_root_path)
        # ListArtifacts expects an empty string, not ".", when both roots coincide.
        if run_root_path == repo_root_path:
            relative_root = ""
        self.run_relative_artifact_repo_root_path = relative_root

    @property
    def max_workers(self) -> int:
        """
        Number of worker threads to use: a fixed allotment per CPU, capped at a global
        maximum. A default CPU count is assumed when ``os.cpu_count()`` yields nothing.
        """
        cpu_count = os.cpu_count()
        if not cpu_count:
            cpu_count = _NUM_DEFAULT_CPUS
        uncapped_threads = cpu_count * _NUM_MAX_THREADS_PER_CPU
        if uncapped_threads < _NUM_MAX_THREADS:
            return uncapped_threads
        return _NUM_MAX_THREADS

    @staticmethod
    def _extract_run_id(artifact_uri):
        """
        Pull the run ID out of an artifact URI.

        The URI is expected to look like
        dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
        which, after path extraction and normalization, becomes
        databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>

        The run ID is therefore the 4th "/"-separated component (index 3).

        :param artifact_uri: The artifact URI to inspect.
        :return: run_id extracted from the artifact_uri
        """
        path_components = extract_and_normalize_path(artifact_uri).split("/")
        run_id_index = 3
        return path_components[run_id_index]

    def _call_endpoint(self, service, api, json_body):
        """
        Call the REST endpoint registered for `api` on `service` using the host credentials
        resolved from this repository's Databricks profile URI.

        :param service: Service class used as the first key into the API info table.
        :param api: Request message class; its ``Response`` is instantiated for the reply.
        :param json_body: JSON-serialized request body.
        :return: Whatever ``call_endpoint`` returns for the request.
        """
        host_creds = get_databricks_host_creds(self.databricks_profile_uri)
        api_info = _SERVICE_AND_METHOD_TO_INFO[service][api]
        endpoint, method = api_info
        return call_endpoint(host_creds, endpoint, method, json_body, api.Response())

    def _get_run_artifact_root(self, run_id):
        """
        :param run_id: ID of the MLflow Run to look up.
        :return: The ``artifact_uri`` reported in the run info of the GetRun response.
        """
        request_json = message_to_json(GetRun(run_id=run_id))
        get_run_response = self._call_endpoint(MlflowService, GetRun, request_json)
        run_info = get_run_response.run.info
        return run_info.artifact_uri

    def _get_credential_infos(self, request_message_class, run_id, paths):
        """
        Request artifact credentials for the run-relative artifact `paths` of the MLflow Run
        identified by `run_id`. Whether read or write credentials are requested is decided
        by `request_message_class`. The paths are sent in batches, and each batch is paged
        through until the service stops returning a page token or returns an empty page.

        :param request_message_class: Request message class selecting the kind of access.
        :param run_id: ID of the MLflow Run owning the artifacts.
        :param paths: Run-relative artifact paths to obtain credentials for.
        :return: A list of `ArtifactCredentialInfo` objects providing the requested kind of
                 access to the specified run-relative artifact `paths`.
        """
        collected_infos = []

        for path_batch in chunk_list(paths, _MAX_CREDENTIALS_REQUEST_SIZE):
            next_token = None
            more_pages = True
            while more_pages:
                request = request_message_class(
                    run_id=run_id, path=path_batch, page_token=next_token
                )
                response = self._call_endpoint(
                    DatabricksMlflowArtifactsService,
                    request_message_class,
                    message_to_json(request),
                )
                collected_infos.extend(response.credential_infos)
                next_token = response.next_page_token
                # Keep paging only while a token is present and the page was non-empty.
                more_pages = bool(next_token) and len(response.credential_infos) != 0

        return collected_infos

    def _get_write_credential_infos(self, run_id, paths):
        """
        Fetch write credentials via `_get_credential_infos`.

        :param run_id: ID of the MLflow Run owning the artifacts.
        :param paths: Run-relative artifact paths to obtain write credentials for.
        :return: A list of `ArtifactCredentialInfo` objects providing write access to the
                 specified run-relative artifact `paths` within the run given by `run_id`.
        """
        return self._get_credential_infos(
            request_message_class=GetCredentialsForWrite,
            run_id=run_id,
            paths=paths,
        )

    def _get_read_credential_infos(self, run_id, paths):
        """
        Fetch read credentials via `_get_credential_infos`.

        :param run_id: ID of the MLflow Run owning the artifacts.
        :param paths: Run-relative artifact paths to obtain read credentials for.
        :return: A list of `ArtifactCredentialInfo` objects providing read access to the
                 specified run-relative artifact `paths` within the run given by `run_id`.
        """
        return self._get_credential_infos(
            request_message_class=GetCredentialsForRead,
            run_id=run_id,
            paths=paths,
        )

    def _extract_headers_from_credentials(self, headers):
        """
        :return: A python dictionary of http headers converted from the protobuf credentials
        """
        return {header.name: header.value for header in headers}

    def _azure_upload_file(self, credentials, local_file, artifact_path):
        """
        Uploads a file to a given Azure storage location.

        The file is read through a chunking generator with 100 MB being the size limit for
        each chunk; this limit is imposed by the stage_block API in azure-storage-blob. Each
        chunk is staged as a block, and the list of block IDs is committed at the end.

        If a request is rejected with 401/403 (possibly because the upload outlived the
        validity of the credentials), fresh write credentials are requested and the request
        is retried once. The refreshed signed URI is then kept for all subsequent requests,
        so that later blocks and the final commit do not each fail against the expired URI
        before succeeding.

        :param credentials: Credential info providing ``signed_uri`` and ``headers``.
        :param local_file: Path of the local file to upload.
        :param artifact_path: Run-relative artifact path, used when refreshing credentials.
        :raises MlflowException: wrapping any error raised during the upload.
        """
        auth_failure_codes = (401, 403)
        try:
            headers = self._extract_headers_from_credentials(credentials.headers)
            # Most recently issued signed URI; replaced whenever credentials are refreshed.
            signed_uri = credentials.signed_uri
            uploading_block_list = []
            for chunk in yield_file_in_chunks(local_file, _AZURE_MAX_BLOCK_CHUNK_SIZE):
                # Base64-encode a UUID, producing a UTF8-encoded bytestring. Then, decode
                # the bytestring for compliance with Azure Blob Storage API requests
                block_id = base64.b64encode(uuid.uuid4().hex.encode()).decode("utf-8")
                try:
                    put_block(signed_uri, block_id, chunk, headers=headers)
                except requests.HTTPError as e:
                    if e.response.status_code not in auth_failure_codes:
                        raise
                    _logger.info(
                        "Failed to authorize request, possibly due to credential expiration."
                        " Refreshing credentials and trying again..."
                    )
                    signed_uri = self._get_write_credential_infos(
                        run_id=self.run_id, paths=[artifact_path]
                    )[0].signed_uri
                    put_block(signed_uri, block_id, chunk, headers=headers)
                uploading_block_list.append(block_id)
            # The credentials may also expire between the last staged block and the commit.
            try:
                put_block_list(signed_uri, uploading_block_list, headers=headers)
            except requests.HTTPError as e:
                if e.response.status_code not in auth_failure_codes:
                    raise
                _logger.info(
                    "Failed to authorize request, possibly due to credential expiration."
                    " Refreshing credentials and trying again..."
                )
                signed_uri = self._get_write_credential_infos(
                    run_id=self.run_id, paths=[artifact_path]
                )[0].signed_uri
                put_block_list(signed_uri, uploading_block_list, headers=headers)
        except Exception as err:
            raise MlflowException(err) from err

    def _retryable_adls_function(self, func, artifact_path, **kwargs):
        """
        Invoke ``func(**kwargs)``; if it fails with an HTTP 403, request fresh write
        credentials for `artifact_path` and invoke it once more with ``sas_url`` replaced by
        the new signed URI. Any other HTTP error is re-raised.

        :param func: Callable performing the ADLS operation.
        :param artifact_path: Run-relative artifact path, used when refreshing credentials.
        :param kwargs: Keyword arguments forwarded to `func`.
        """
        try:
            func(**kwargs)
        except requests.HTTPError as e:
            if e.response.status_code not in [403]:
                raise e
            _logger.info(
                "Failed to authorize ADLS operation, possibly due "
                "to credential expiration. Refreshing credentials and trying again..."
            )
            refreshed = self._get_write_credential_infos(
                run_id=self.run_id, paths=[artifact_path]
            )[0]
            kwargs.update(sas_url=refreshed.signed_uri)
            func(**kwargs)

    def _azure_adls_gen2_upload_file(self, credentials, local_file, artifact_path):
        """
        Uploads a file to a given Azure storage location using the ADLS gen2 API.

        The file is first created, then its contents are appended chunk by chunk. When the
        first chunk is smaller than the maximum chunk size the upload is flagged as single
        part and no separate flush is issued; otherwise a final flush is sent at the total
        number of bytes written.

        :param credentials: Credential info providing ``signed_uri`` and ``headers``.
        :param local_file: Path of the local file to upload.
        :param artifact_path: Run-relative artifact path, used when refreshing credentials.
        :raises MlflowException: wrapping any error raised during the upload.
        """
        try:
            headers = self._extract_headers_from_credentials(credentials.headers)
            sas_url = credentials.signed_uri

            # Step 1: create the (empty) file.
            self._retryable_adls_function(
                func=put_adls_file_creation,
                artifact_path=artifact_path,
                sas_url=sas_url,
                headers=headers,
            )

            # Step 2: append the contents, tracking how many bytes have been written so far.
            bytes_written = 0
            single_part = False
            on_first_chunk = True
            for chunk in yield_file_in_chunks(local_file, _AZURE_MAX_BLOCK_CHUNK_SIZE):
                chunk_size = len(chunk)
                if on_first_chunk and chunk_size < _AZURE_MAX_BLOCK_CHUNK_SIZE:
                    single_part = True
                on_first_chunk = False
                self._retryable_adls_function(
                    func=patch_adls_file_upload,
                    artifact_path=artifact_path,
                    sas_url=sas_url,
                    data=chunk,
                    position=bytes_written,
                    headers=headers,
                    is_single=single_part,
                )
                bytes_written += chunk_size

            # Step 3: flush, unless the upload was flagged as single part.
            if single_part:
                return
            self._retryable_adls_function(
                func=patch_adls_flush,
                artifact_path=artifact_path,
                sas_url=sas_url,
                position=bytes_written,
                headers=headers,
            )
        except Exception as err:
            raise MlflowException(err)

    def _signed_url_upload_file(self, credentials, local_file):
        """
        Upload `local_file` with a single HTTP PUT to the signed URI in `credentials`.

        :param credentials: Credential info providing ``signed_uri`` and ``headers``.
        :param local_file: Path of the local file to upload.
        :raises MlflowException: wrapping any error raised during the upload.
        """
        try:
            request_headers = self._extract_headers_from_credentials(
                credentials.headers
            )
            target_uri = credentials.signed_uri
            if os.stat(local_file).st_size > 0:
                with open(local_file, "rb") as payload:
                    with rest_utils.cloud_storage_http_request(
                        "put", target_uri, data=payload, headers=request_headers
                    ) as response:
                        augmented_raise_for_status(response)
            else:
                # Putting an empty file in a request by reading file bytes gives 501 error,
                # so an empty string is sent as the body instead.
                with rest_utils.cloud_storage_http_request(
                    "put", target_uri, data="", headers=request_headers
                ) as response:
                    augmented_raise_for_status(response)
        except Exception as err:
            raise MlflowException(err)

    def _upload_to_cloud(
        self, cloud_credential_info, src_file_path, dst_run_relative_artifact_path
    ):
        """
        Upload the local file `src_file_path` to the run-relative
        `dst_run_relative_artifact_path`, choosing the upload routine from the type of the
        supplied `cloud_credential_info`.

        :raises MlflowException: with ``INTERNAL_ERROR`` if the credential type is not one
                                 of the supported ones.
        """
        credential_type = cloud_credential_info.type

        if credential_type == ArtifactCredentialType.AZURE_SAS_URI:
            self._azure_upload_file(
                cloud_credential_info, src_file_path, dst_run_relative_artifact_path
            )
            return

        if credential_type == ArtifactCredentialType.AZURE_ADLS_GEN2_SAS_URI:
            self._azure_adls_gen2_upload_file(
                cloud_credential_info, src_file_path, dst_run_relative_artifact_path
            )
            return

        signed_url_types = [
            ArtifactCredentialType.AWS_PRESIGNED_URL,
            ArtifactCredentialType.GCP_SIGNED_URL,
        ]
        if credential_type in signed_url_types:
            self._signed_url_upload_file(cloud_credential_info, src_file_path)
            return

        raise MlflowException(
            message="Cloud provider not supported.", error_code=INTERNAL_ERROR
        )

    def _download_from_cloud(self, cloud_credential_info, dst_local_file_path):
        """
        Fetch the file addressed by `cloud_credential_info` over HTTP and store it at
        `dst_local_file_path`.

        :raises MlflowException: with ``INTERNAL_ERROR`` if the credential type is not
                                 supported, or wrapping any error raised by the download.
        """
        credential_type = cloud_credential_info.type
        supported_types = [
            ArtifactCredentialType.AZURE_SAS_URI,
            ArtifactCredentialType.AZURE_ADLS_GEN2_SAS_URI,
            ArtifactCredentialType.AWS_PRESIGNED_URL,
            ArtifactCredentialType.GCP_SIGNED_URL,
        ]
        if credential_type not in supported_types:
            raise MlflowException(
                message="Cloud provider not supported.", error_code=INTERNAL_ERROR
            )
        try:
            signed_uri = cloud_credential_info.signed_uri
            request_headers = self._extract_headers_from_credentials(
                cloud_credential_info.headers
            )
            download_file_using_http_uri(
                signed_uri, dst_local_file_path, _DOWNLOAD_CHUNK_SIZE, request_headers
            )
        except Exception as err:
            raise MlflowException(err)

    def _get_run_relative_artifact_path_for_upload(
        self, src_file_path, dst_artifact_dir
    ):
        """
        Obtain the run-relative destination artifact path for uploading the file specified by
        `src_file_path` to the artifact directory specified by `dst_artifact_dir` within the
        MLflow Run associated with the artifact repository.

        :param src_file_path: The path to the source file on the local filesystem.
        :param dst_artifact_dir: The destination artifact directory, specified as a POSIX-style
                                 path relative to the artifact repository's root URI (note that
                                 this is not equivalent to the associated MLflow Run's artifact
                                 root location).
        :return: A POSIX-style artifact path to be used as the destination for the file upload.
                 This path is specified relative to the root of the MLflow Run associated with
                 the artifact repository.
        """
        basename = os.path.basename(src_file_path)
        dst_artifact_dir = dst_artifact_dir or ""
        dst_artifact_dir = posixpath.join(dst_artifact_dir, basename)
        if len(dst_artifact_dir) > 0:
            run_relative_artifact_path = posixpath.join(
                self.run_relative_artifact_repo_root_path, dst_artifact_dir
            )
        else:
            run_relative_artifact_path = self.run_relative_artifact_repo_root_path
        return run_relative_artifact_path

    def log_artifact(self, local_file, artifact_path=None):
        """
        Upload a single local file into the artifact directory `artifact_path`.

        :param local_file: Path of the local file to upload.
        :param artifact_path: Destination directory relative to this repository's root, or
                              ``None`` for the root itself.
        """
        destination = self._get_run_relative_artifact_path_for_upload(
            src_file_path=local_file,
            dst_artifact_dir=artifact_path,
        )
        credential_infos = self._get_write_credential_infos(
            run_id=self.run_id, paths=[destination]
        )
        # Only one path was requested, so the first entry is the credential to use.
        self._upload_to_cloud(
            cloud_credential_info=credential_infos[0],
            src_file_path=local_file,
            dst_run_relative_artifact_path=destination,
        )

    def log_artifacts(self, local_dir, artifact_path=None):
        """
        Parallelized implementation of `log_artifacts` for Databricks.

        Walks `local_dir`, requests write credentials for every file found, and uploads
        the files concurrently on the repository's thread pool. The call returns only after
        every upload has finished.

        :param local_dir: Local directory whose files (including those in subdirectories)
                          are uploaded.
        :param artifact_path: Destination directory relative to this repository's root, or
                              ``None`` for the root itself.
        :raises MlflowException: if the number of credentials returned by the service does
                                 not match the number of files to upload, or if one or more
                                 uploads fail (the message lists each failure by source
                                 file path).
        """
        StagedArtifactUpload = namedtuple(
            "StagedArtifactUpload",
            [
                # Local filesystem path of the source file to upload
                "src_file_path",
                # Run-relative artifact path specifying the upload destination
                "dst_run_relative_artifact_path",
            ],
        )

        artifact_path = artifact_path or ""

        staged_uploads = []
        for (dirpath, _, filenames) in os.walk(local_dir):
            artifact_subdir = artifact_path
            if dirpath != local_dir:
                # Mirror the local subdirectory structure below `artifact_path`.
                rel_path = os.path.relpath(dirpath, local_dir)
                rel_path = relative_path_to_artifact_path(rel_path)
                artifact_subdir = posixpath.join(artifact_path, rel_path)
            for name in filenames:
                file_path = os.path.join(dirpath, name)
                dst_run_relative_artifact_path = (
                    self._get_run_relative_artifact_path_for_upload(
                        src_file_path=file_path,
                        dst_artifact_dir=artifact_subdir,
                    )
                )
                staged_uploads.append(
                    StagedArtifactUpload(
                        src_file_path=file_path,
                        dst_run_relative_artifact_path=dst_run_relative_artifact_path,
                    )
                )

        write_credential_infos = self._get_write_credential_infos(
            run_id=self.run_id,
            paths=[
                staged_upload.dst_run_relative_artifact_path
                for staged_upload in staged_uploads
            ],
        )

        # Uploads and credentials are paired positionally below. zip() would silently drop
        # files if fewer credentials came back than were requested, so fail loudly instead.
        if len(write_credential_infos) != len(staged_uploads):
            raise MlflowException(
                message=(
                    "Expected write credentials for {expected} artifact(s) but received"
                    " {actual}.".format(
                        expected=len(staged_uploads),
                        actual=len(write_credential_infos),
                    )
                ),
                error_code=INTERNAL_ERROR,
            )

        inflight_uploads = {}
        for staged_upload, write_credential_info in zip(
            staged_uploads, write_credential_infos
        ):
            upload_future = self.thread_pool.submit(
                self._upload_to_cloud,
                cloud_credential_info=write_credential_info,
                src_file_path=staged_upload.src_file_path,
                dst_run_relative_artifact_path=staged_upload.dst_run_relative_artifact_path,
            )
            inflight_uploads[staged_upload.src_file_path] = upload_future

        # Join futures to ensure that all artifacts have been uploaded prior to returning
        failed_uploads = {}
        for (src_file_path, upload_future) in inflight_uploads.items():
            try:
                upload_future.result()
            except Exception as e:
                # Collect every failure so that one bad file does not hide the others.
                failed_uploads[src_file_path] = repr(e)

        if len(failed_uploads) > 0:
            raise MlflowException(
                message=(
                    "The following failures occurred while uploading one or more artifacts"
                    " to {artifact_root}: {failures}".format(
                        artifact_root=self.artifact_uri,
                        failures=failed_uploads,
                    )
                )
            )

    def list_artifacts(self, path=None):
        """
        List the artifacts directly under `path`, paging through ListArtifacts responses.

        :param path: Directory to list, relative to this repository's root; ``None`` or an
                     empty value lists the root.
        :return: A list of `FileInfo` objects whose paths are relative to this repository's
                 root. An empty list is returned when `path` refers to a file.
        """
        run_relative_path = self.run_relative_artifact_repo_root_path
        if path:
            run_relative_path = posixpath.join(run_relative_path, path)

        file_infos = []
        next_token = None
        while True:
            request = ListArtifacts(
                run_id=self.run_id, path=run_relative_path, page_token=next_token
            )
            response = self._call_endpoint(
                MlflowService, ListArtifacts, message_to_json(request)
            )
            entries = response.files
            # If `path` is a file, ListArtifacts returns a single list element with the
            # same name as `path`. The list_artifacts API expects an empty list in that
            # case.
            if len(entries) == 1:
                sole_entry = entries[0]
                if sole_entry.path == run_relative_path and not sole_entry.is_dir:
                    return []
            for entry in entries:
                repo_relative_path = posixpath.relpath(
                    path=entry.path,
                    start=self.run_relative_artifact_repo_root_path,
                )
                # Directories carry no size.
                size = entry.file_size if not entry.is_dir else None
                file_infos.append(FileInfo(repo_relative_path, entry.is_dir, size))
            next_token = response.next_page_token
            if len(entries) == 0 or not next_token:
                break
        return file_infos

    def _download_file(self, remote_file_path, local_path):
        """
        Download the artifact at `remote_file_path` to `local_path`.

        :param remote_file_path: Artifact path relative to this repository's root.
        :param local_path: Local filesystem path to write the downloaded file to.
        :raises MlflowException: with ``INTERNAL_ERROR`` if the service does not return
                                 exactly one read credential for the requested path.
        """
        run_relative_remote_file_path = posixpath.join(
            self.run_relative_artifact_repo_root_path, remote_file_path
        )
        read_credentials = self._get_read_credential_infos(
            run_id=self.run_id, paths=[run_relative_remote_file_path]
        )
        # Read credentials for only one file were requested, so exactly one value is
        # expected in the response. This is checked explicitly rather than with `assert`,
        # because assertions are removed when Python runs with -O.
        if len(read_credentials) != 1:
            raise MlflowException(
                message=(
                    "Expected exactly one read credential for artifact path"
                    " '{path}' but received {count}.".format(
                        path=run_relative_remote_file_path,
                        count=len(read_credentials),
                    )
                ),
                error_code=INTERNAL_ERROR,
            )
        self._download_from_cloud(
            cloud_credential_info=read_credentials[0], dst_local_file_path=local_path
        )

    def delete_artifacts(self, artifact_path=None):
        """
        Deleting artifacts is not supported by this repository.

        :param artifact_path: Ignored.
        :raises MlflowException: always.
        """
        not_implemented_message = "Not implemented yet"
        raise MlflowException(not_implemented_message)
