﻿# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.

import base64
from collections import OrderedDict
from datetime import datetime
import hashlib
import hmac
from http import HTTPStatus
from time import mktime
import typing
from typing import List, Optional
from wsgiref.handlers import format_date_time

import json
import urllib
import urllib.parse

import adal
import dateutil.parser

from cdm.utilities import StorageUtils
from cdm.utilities.network.cdm_http_client import CdmHttpClient
from cdm.storage.network import NetworkAdapter

from .base import StorageAdapterBase


class ADLSAdapter(NetworkAdapter, StorageAdapterBase):
    """Azure Data Lake Storage Gen2 storage adapter"""

    ADLS_DEFAULT_TIMEOUT = 9000

    def __init__(self, hostname: Optional[str] = None, root: Optional[str] = None, **kwargs) -> None:
        """Create an ADLS Gen2 adapter.

        The adapter may be created empty and configured later through
        ``update_config``; every attribute is therefore initialised here so
        that later reads never fail with ``AttributeError``.

        :param hostname: Host name of the storage account.
        :param root: Root path, i.e. the file-system name optionally followed by folders.
        :param kwargs: Optional credentials, only read when both ``hostname`` and
            ``root`` are given: ``client_id``, ``secret``, ``tenant``,
            ``shared_key`` and ``token_provider``.
        """
        super().__init__()
        super(NetworkAdapter, self).__init__()
        super(StorageAdapterBase, self).__init__()

        # --- internal ---
        self._adapter_paths = {}  # type: Dict[str, str]
        self._root_blob_contrainer = None  # type: Optional[str]
        self._hostname = None  # type: Optional[str]
        self._formatted_hostname = None  # type: Optional[str]
        self._http_authorization = 'Authorization'
        self._http_client = CdmHttpClient()  # type: CdmHttpClient
        self._http_xms_date = 'x-ms-date'
        self._http_xms_version = 'x-ms-version'
        self._resource = "https://storage.azure.com"  # type: Optional[str]
        self._type = 'adls'
        self._root = None
        self._unescaped_root_sub_path = None  # type: Optional[str]
        self._escaped_root_sub_path = None  # type: Optional[str]
        self._file_modified_time_cache = {}  # type: Dict[str, datetime]
        self._tenant = None  # type: Optional[str]
        self._auth_context = None
        self.timeout = self.ADLS_DEFAULT_TIMEOUT  # type: int

        # Credentials default to "not configured" so that the auth selection
        # logic can test them with `is not None` on any adapter instance.
        self.client_id = None  # type: Optional[str]
        self.secret = None  # type: Optional[str]
        self.shared_key = None  # type: Optional[str]
        self.token_provider = None  # type: Optional[TokenProvider]

        if root and hostname:
            self.root = root
            self.hostname = hostname
            self.client_id = kwargs.get('client_id', None)
            self.secret = kwargs.get('secret', None)
            self.shared_key = kwargs.get('shared_key', None)
            self.token_provider = kwargs.get('token_provider', None)

            # --- internal ---
            self._tenant = kwargs.get('tenant', None)
            if self._tenant:
                self._auth_context = adal.AuthenticationContext('https://login.windows.net/' + self._tenant)

    @property
    def hostname(self) -> str:
        return self._hostname

    @hostname.setter
    def hostname(self, value: str):
        self._hostname = value
        self._formatted_hostname = self._format_hostname(self._hostname)

    @property
    def root(self) -> str:
        return self._root

    @root.setter
    def root(self, value: str):
        self._root = self._extract_root_blob_container_and_sub_path(value)

    @property
    def tenant(self) -> str:
        return self._tenant

    def can_read(self) -> bool:
        """Report that this adapter supports reading."""
        return True

    def can_write(self) -> bool:
        """Report that this adapter supports writing."""
        return True

    def clear_cache(self) -> None:
        self._file_modified_time_cache.clear()

    async def compute_last_modified_time_async(self, corpus_path: str) -> Optional[datetime]:
        cachedValue = None
        if self._is_cache_enabled:
             cachedValue = self._file_modified_time_cache.get(corpus_path)

        if cachedValue is not None:
            return cachedValue        
        else:
            adapter_path = self.create_adapter_path(corpus_path)
            request = self._build_request(adapter_path, 'HEAD')

            try:
                cdm_response = await self._http_client._send_async(request, self.wait_time_callback)
                if cdm_response.status_code == HTTPStatus.OK:
                    lastTime = dateutil.parser.parse(typing.cast(str, cdm_response.response_headers['Last-Modified']))
                    if lastTime is not None and self._is_cache_enabled:
                        self._file_modified_time_cache[corpus_path] = lastTime
                    return lastTime
            except Exception:
                pass

            return None

    def create_adapter_path(self, corpus_path: str) -> str:
        """Translate a corpus path into the full ``https://`` adapter path.

        Returns ``None`` when the corpus path cannot be formatted. A path that
        was previously remembered in ``_adapter_paths`` is returned as stored;
        otherwise the path is assembled from host name, escaped root and the
        escaped corpus path.
        """
        # A doubled leading slash is reduced to a single one.
        if corpus_path and corpus_path.startswith('//'):
            corpus_path = corpus_path[1:]

        formatted = self._format_corpus_path(corpus_path)
        if formatted is None:
            return None

        if formatted not in self._adapter_paths:
            return 'https://' + self.hostname + self._get_escaped_root() + self._escape_path(formatted)

        return self._adapter_paths[formatted]

    def create_corpus_path(self, adapter_path: str) -> Optional[str]:
        if adapter_path:
            start_index = len('https://')
            end_index = adapter_path.find('/', start_index + 1)

            if end_index < start_index:
                raise Exception('Unexpected adapter path:', adapter_path)

            hostname = self._format_hostname(adapter_path[start_index:end_index])

            if hostname == self._formatted_hostname and adapter_path[end_index:].startswith(self._get_escaped_root()):
                escaped_corpus_path = adapter_path[end_index + len(self._get_escaped_root()):]
                corpus_path = urllib.parse.unquote(escaped_corpus_path)

                if corpus_path not in self._adapter_paths:
                    self._adapter_paths[corpus_path] = adapter_path

                return corpus_path

        # Signal that we did not recognize path as one for this adapter.
        return None

    async def fetch_all_files_async(self, folder_corpus_path: str) -> List[str]:
        """List the corpus paths of all files below ``folder_corpus_path``.

        Sends a recursive file-system listing request. Directory entries are
        skipped, the root sub path is removed from every returned name, and
        (when caching is enabled) each entry's ``lastModified`` value is stored
        in the modified-time cache. Returns ``None`` when the folder path is
        ``None`` or the response status is not OK.
        """
        if folder_corpus_path is None:
            return None

        url = 'https://{}/{}'.format(self._formatted_hostname, self._root_blob_contrainer)
        escaped_folder_corpus_path = self._escape_path(folder_corpus_path)
        directory = self._escaped_root_sub_path + self._format_corpus_path(escaped_folder_corpus_path)
        # The directory query value carries no leading slash.
        if directory.startswith('/'):
            directory = directory[1:]

        listing_url = '{}?directory={}&recursive=True&resource=filesystem'.format(url, directory)
        request = self._build_request(listing_url, 'GET')
        cdm_response = await self._http_client._send_async(request, self.wait_time_callback)

        if cdm_response.status_code != HTTPStatus.OK:
            return None

        file_paths = []
        listing = json.loads(cdm_response.content)

        for entry in listing['paths']:
            # Directories are reported with the string value 'true'.
            if 'isDirectory' in entry and entry['isDirectory'] == 'true':
                continue

            name = entry['name']  # type: str
            sub_path = self._unescaped_root_sub_path
            if sub_path and name.startswith(sub_path):
                # Drop the root sub path and the separator that follows it.
                name = name[len(sub_path) + 1:]

            filepath = self._format_corpus_path(name)
            file_paths.append(filepath)

            modified_text = entry.get('lastModified')
            if modified_text is not None and self._is_cache_enabled:
                self._file_modified_time_cache[filepath] = dateutil.parser.parse(modified_text)

        return file_paths

    def fetch_config(self) -> str:
        """Serialise the adapter configuration to a JSON string.

        The result has the shape ``{"type": ..., "config": {...}}``. The tenant
        and client id are included only when both are set; the shared key and
        the secret are never written.
        """
        config_object = {
            'hostname': self.hostname,
            'root': self.root
        }

        # Check for clientId auth, we won't write shared key or secrets to JSON.
        has_client_id_auth = self.client_id and self.tenant
        if has_client_id_auth:
            config_object['tenant'] = self.tenant
            config_object['clientId'] = self.client_id

        # Merge in whatever the network layer reports as its configuration.
        network_config = self.fetch_network_config()
        config_object.update(network_config)

        if self.location_hint:
            config_object['locationHint'] = self.location_hint

        return json.dumps({'type': self._type, 'config': config_object})

    async def read_async(self, corpus_path: str) -> str:
        """Read the file at ``corpus_path`` with an authorised GET request."""
        adapter_path = self.create_adapter_path(corpus_path)
        get_request = self._build_request(adapter_path, 'GET')
        content = await super()._read(get_request)
        return content

    def update_config(self, config: str):
        """Apply a JSON configuration string to this adapter.

        ``root`` and ``hostname`` are mandatory. ``tenant``/``clientId``
        (optionally with ``secret``), ``sharedKey`` and ``locationHint`` are
        applied when present, and the network settings are handed to
        ``update_network_config``.

        :param config: JSON document holding the adapter settings.
        :raises Exception: When ``root`` or ``hostname`` is missing or empty.
        """
        configs_json = json.loads(config)

        if configs_json.get('root'):
            self.root = configs_json['root']
        else:
            raise Exception('Root has to be set for ADLS adapter.')

        if configs_json.get('hostname'):
            self.hostname = configs_json['hostname']
        else:
            raise Exception('Hostname has to be set for ADLS adapter.')

        self.update_network_config(config)

        # Check first for clientId/secret auth.
        if configs_json.get('tenant') and configs_json.get('clientId'):
            self._tenant = configs_json['tenant']
            self.client_id = configs_json['clientId']

            # Check for a secret, we don't really care is it there, but it is nice if it is.
            if configs_json.get('secret'):
                self.secret = configs_json['secret']

        # Check then for shared key auth.
        if configs_json.get('sharedKey'):
            self.shared_key = configs_json['sharedKey']

        if configs_json.get('locationHint'):
            self.location_hint = configs_json['locationHint']

        # Tolerate an adapter on which no tenant was ever assigned: a config
        # without tenant must yield "no auth context", not an AttributeError.
        tenant = getattr(self, '_tenant', None)
        self._auth_context = adal.AuthenticationContext('https://login.windows.net/' + tenant) if tenant else None

    async def write_async(self, corpus_path: str, data: str) -> None:
        """Write ``data`` to the file at ``corpus_path``.

        Three requests are sent in order: create the file (PUT), append the
        content at position 0 (PATCH), then flush (PATCH) at the position
        that follows the appended content.

        :param corpus_path: Corpus path of the file to write.
        :param data: Text content of the file.
        """
        url = self.create_adapter_path(corpus_path)

        request = self._build_request(url + '?resource=file', 'PUT')
        await self._http_client._send_async(request, self.wait_time_callback)

        request = self._build_request(url + '?action=append&position=0', 'PATCH', data, 'application/json; charset=utf-8')
        await self._http_client._send_async(request, self.wait_time_callback)

        # The flush position is a byte offset, so count encoded bytes rather
        # than characters; the two differ for any non-ASCII content.
        # NOTE(review): assumes the HTTP client transmits `data` as UTF-8, as
        # the content type above declares - confirm against CdmHttpClient.
        payload = data if isinstance(data, (bytes, bytearray)) else data.encode('utf-8')
        flush_position = len(payload)

        request = self._build_request(url + '?action=flush&position=' + str(flush_position), 'PATCH')
        await self._http_client._send_async(request, self.wait_time_callback)

    def _apply_shared_key(self, shared_key: str, url: str, method: str, content: Optional[str] = None, content_type: Optional[str] = None):
        headers = OrderedDict()
        headers[self._http_xms_date] = format_date_time(mktime(datetime.now().timetuple()))
        headers[self._http_xms_version] = '2018-06-17'

        content_length = 0

        if content is not None:
            content_length = len(content)

        uri = urllib.parse.urlparse(url)
        builder = []
        builder.append(method)  # Verb.
        builder.append('\n')  # Verb.
        builder.append('\n')  # Content-Encoding.
        builder.append('\n')  # Content-Language.
        builder.append(str(content_length) + '\n' if content_length else '\n')  # Content length.
        builder.append('\n')  # Content-md5.
        builder.append(content_type + '\n' if content_type else '\n')  # Content-type.
        builder.append('\n')  # Date.
        builder.append('\n')  # If-modified-since.
        builder.append('\n')  # If-match.
        builder.append('\n')  # If-none-match.
        builder.append('\n')  # If-unmodified-since.
        builder.append('\n')  # Range.

        for key, value in headers.items():
            builder.append('{0}:{1}\n'.format(key, value))

        # append canonicalized resource.
        account_name = uri.netloc.split('.')[0]
        builder.append('/')
        builder.append(account_name)
        builder.append(uri.path)

        # append canonicalized queries.
        if uri.query:
            query_parameters = uri.query.split('&')  # type: List[str]

            for parameter in query_parameters:
                key_value_pair = parameter.split('=')
                builder.append('\n{}:{}'.format(key_value_pair[0], urllib.parse.unquote(key_value_pair[1])))

        # Hash the payload.
        data_to_hash = ''.join(builder).rstrip()
        shared_key_bytes = self._try_from_base64_string(shared_key)
        if not shared_key_bytes:
            raise Exception('Couldn\'t encode the shared key.')

        message = base64.b64encode(hmac.new(shared_key_bytes, msg=data_to_hash.encode('utf-8'), digestmod=hashlib.sha256).digest()).decode('utf-8')
        signed_string = 'SharedKey {}:{}'.format(account_name, message)

        headers[self._http_authorization] = signed_string

        return headers

    def _build_request(self, url: str, method: str = 'GET', content: Optional[str] = None, content_type: Optional[str] = None):
        if self.shared_key is not None:
            request = self._set_up_cdm_request(url, self._apply_shared_key(self.shared_key, url, method, content, content_type), method)
        elif self.tenant is not None and self.client_id is not None and self.secret is not None:
            token = self._generate_bearer_token()
            headers = {'Authorization': token['tokenType'] + ' ' + token['accessToken']}
            request = self._set_up_cdm_request(url, headers, method)
        elif self.token_provider is not None:
            headers = {'Authorization': self.token_provider.get_token()}
            request = self._set_up_cdm_request(url, headers, method)
        else:
            raise Exception('Adls adapter is not configured with any auth method')

        if content is not None:
            request.content = content
            request.content_type = content_type

        return request

    def _escape_path(self, unescaped_path: str):
        return urllib.parse.quote(unescaped_path).replace('%2F', '/')

    def _extract_root_blob_container_and_sub_path(self, root: str) -> None:
        # No root value was set
        if not root:
            self._root_blob_contrainer = ''
            self._update_root_sub_path('')
            return ''

        # Remove leading and trailing /
        prep_root = root[1:] if root[0] == '/' else root
        prep_root = prep_root[0: len(prep_root) - 1] if prep_root[len(prep_root) - 1] == '/' else prep_root

        # Root contains only the file-system name, e.g. "fs-name"
        if prep_root.find('/') == -1:
            self._root_blob_contrainer = prep_root
            self._update_root_sub_path('')
            return '/{}'.format(self._root_blob_contrainer)

        # Root contains file-system name and folder, e.g. "fs-name/folder/folder..."
        prep_root_array = prep_root.split('/')
        self._root_blob_contrainer = prep_root_array[0]
        self._update_root_sub_path('/'.join(prep_root_array[1:])) 
        return '/{}/{}'.format(self._root_blob_contrainer, self._unescaped_root_sub_path)

    def _format_corpus_path(self, corpus_path: str) -> str:
        """Drop the namespace from a corpus path and ensure a leading slash.

        Returns ``None`` when the namespace split yields nothing.
        """
        path_tuple = StorageUtils.split_namespace_path(corpus_path)
        if not path_tuple:
            return None

        # The second element of the split is the path without its namespace.
        path_without_namespace = path_tuple[1]
        if not path_without_namespace or path_without_namespace[0] == '/':
            return path_without_namespace

        return '/' + path_without_namespace

    def _format_hostname(self, hostname: str) -> str:
        hostname = hostname.replace('.blob.', '.dfs.')
        port = ':443'
        if port in hostname:
            hostname = hostname[0:-len(port)]
        return hostname

    def _generate_bearer_token(self):
        """Acquire a token for the storage resource with the client id and secret."""
        # In-memory token cache is handled by adal by default.
        auth_context = self._auth_context
        token = auth_context.acquire_token_with_client_credentials(self._resource, self.client_id, self.secret)
        return token

    def _get_escaped_root(self):
        return '/' + self._root_blob_contrainer + '/' + self._escaped_root_sub_path if self._escaped_root_sub_path else '/' + self._root_blob_contrainer

    def _try_from_base64_string(self, content: str) -> bool:
        try:
            return base64.b64decode(content)
        except Exception:
            return None
    
    def _update_root_sub_path(self, value: str):
        self._unescaped_root_sub_path = value
        self._escaped_root_sub_path = self._escape_path(value)
