#!/usr/bin/env python
# git-annex external special remote program for Globus data repository
# 
# This is an addition to git-annex's built-in directory special remotes.
# 
# Install in PATH as git-annex-remote-globus
#
# TODO: Add Copyright

import os
import globus_sdk
from annexremote import Master
from annexremote import SpecialRemote
from annexremote import RemoteError
import sys
import webbrowser
import urllib.request
import shutil
import logging
import keyring
from globus_sdk import (NativeAppAuthClient, TransferClient,
                        RefreshTokenAuthorizer)
from globus_sdk.exc import GlobusAPIError
# from fair_research_login import NativeClient
# Placeholder; not referenced in the visible part of this file.
versions = None

# Module-wide logger. With the threshold at ERROR, the logger.info() calls
# made below are discarded unless this level is lowered.
logger = logging.getLogger('globus-remote')
logger.setLevel(level=logging.ERROR)

# Application name; not referenced in the visible part of this file.
APP_NAME = 'git-annex-globus-remote'

class GlobusRemote(SpecialRemote):

    """This is the class of Globus remotes."""

    get_input = getattr(__builtins__, 'raw_input', input)

    def __init__(self, annex):
        """Create a remote bound to *annex* with nothing resolved yet.

        Every piece of connection state starts out as None; the static
        client configuration is kept in ``self.settings``.
        """
        super(GlobusRemote, self).__init__(annex)

        # Connection state, filled in later by initremote()/prepare(),
        # set_endpoint_server() and authenticate().
        self.server = self.uuid = self.fileprefix = None
        self.endpoint = self.transfer_client = None

        # OAuth2 scopes requested by setup() for the native-app login flow.
        requested_scopes = ('openid email profile '
                            'urn:globus:auth:scope:transfer.api.globus.org:all')

        # 'server' is the fallback HTTPS host used by _do_retrive() when no
        # server has been set from the endpoint.
        self.settings = {
            'client_id': '01589ab6-70d1-4e1c-b33d-14b6af4a16be',
            'server': 'https://2a9f4.8443.dn.glob.us/',
            'redirect_uri': 'https://auth.globus.org/v2/web/auth-code',
            'scopes': requested_scopes,
        }

    @staticmethod
    def load_tokens():
        """Load the saved token set from the keyring.

        save_tokens() stores the tokens as the ``str()`` of a dict, so the
        text is parsed back with ``ast.literal_eval``. Unlike ``eval`` it only
        accepts Python literals, so a tampered keyring entry cannot run code.

        Returns:
            The stored tokens, or None when the keyring has no (or an empty)
            entry; setup() treats a falsy result as "no tokens yet".

        Raises:
            ValueError, SyntaxError: if the stored text is not a Python literal.
        """
        import ast  # local import: only this method needs it

        stored = keyring.get_password("globus-remote", "auth-tokens")
        if not stored:
            return None
        return ast.literal_eval(stored)

    @staticmethod
    def save_tokens(tokens):
        """Store *tokens* in the keyring as their string representation."""
        serialized = str(tokens)
        keyring.set_password("globus-remote", "auth-tokens", serialized)

    def update_tokens_on_refresh(self, token_response):
        """Replace the tokens kept in the keyring with refreshed ones.

        authenticate() hands this method to RefreshTokenAuthorizer as its
        ``on_refresh`` callback.
        """
        # remove the stale entry first, then store the new token set
        keyring.delete_password("globus-remote", "auth-tokens")
        refreshed = token_response.by_resource_server
        self.save_tokens(refreshed)

    def do_native_app_authentication(self, client_id, redirect_uri, requested_scopes=None):
        """Run an interactive Native App login and return the resulting tokens.

        The authorize URL is always printed, because no browser is opened when
        the session looks like an SSH login (SSH_TTY / SSH_CONNECTION set) and
        the user would otherwise have no way to obtain the code.

        Args:
            client_id: Globus client id of this application.
            redirect_uri: redirect URI registered for the client.
            requested_scopes: scopes to request, passed through to the SDK.

        Returns:
            The token response's ``by_resource_server`` mapping.
        """
        client = globus_sdk.NativeAppAuthClient(client_id=client_id)
        # pass refresh_tokens=True to request refresh tokens
        client.oauth2_start_flow(requested_scopes=requested_scopes,
                                 redirect_uri=redirect_uri,
                                 refresh_tokens=True)

        url = client.oauth2_get_authorize_url()
        print('Please log in to Globus by visiting this URL:\n{0}\n'.format(url))

        # only try to launch a browser on what looks like a local session
        if not os.environ.get('SSH_TTY', os.environ.get('SSH_CONNECTION')):
            webbrowser.open(url, new=1)

        auth_code = input('Enter the resulting authorization code here: ').strip()

        token_response = client.oauth2_exchange_code_for_tokens(auth_code)

        # tokens organized by resource server name
        return token_response.by_resource_server

    def setup(self):
        """Interactive setup: obtain tokens and build the transfer client.

        Tries the tokens saved in the keyring first. When none can be loaded,
        runs the Native App login flow and saves the new tokens. If saving
        fails the process exits with a non-zero status. Finally calls
        authenticate() with the tokens.
        """
        tokens = None
        try:
            # if we already have tokens, load and use them
            tokens = self.load_tokens()
        except Exception as e:
            # lazy %-formatting: passing the exception as a bare extra
            # argument without a placeholder is a logging format error
            logger.info('Tokens not available, try to authenticate: %s', e)

        if not tokens:
            # no usable tokens: start the Native App authentication process
            tokens = self.do_native_app_authentication(self.settings['client_id'],
                                                       self.settings['redirect_uri'],
                                                       self.settings['scopes'])
            try:
                self.save_tokens(tokens)
            except Exception as e:
                logger.error('Exception while saving tokens with keyring: %s', e)
                # a failure must not be reported as exit status 0
                sys.exit(1)
        self.authenticate(tokens)

    def authenticate(self, tokens=None):
        """Build the Globus transfer client and store it on the instance.

        Uses *tokens* when given, otherwise the tokens saved in the keyring.
        The result is assigned to ``self.transfer_client``; nothing is
        returned.
        """
        # fall back to the keyring when the caller supplied no tokens
        token_set = tokens if tokens else self.load_tokens()
        transfer = token_set['transfer.api.globus.org']

        native_client = NativeAppAuthClient(client_id=self.settings['client_id'])

        # update_tokens_on_refresh is the authorizer's on_refresh callback,
        # so refreshed tokens are written back to the keyring
        self.transfer_client = TransferClient(
            authorizer=RefreshTokenAuthorizer(
                transfer['refresh_token'],
                native_client,
                access_token=transfer['access_token'],
                expires_at=transfer['expires_at_seconds'],
                on_refresh=self.update_tokens_on_refresh))

    def get_endpoint_id(self, dir_name):
        """Resolve an endpoint name to its id via a full-text endpoint search.

        Authenticates first if no transfer client exists yet.

        Args:
            dir_name: endpoint name to search for.

        Returns:
            The id of the single endpoint matching *dir_name*.

        Raises:
            SystemExit: with a non-zero status when the search yields no
                endpoint or more than one.
        """
        if not self.transfer_client:
            self.authenticate()

        matches = [ep['id'] for ep in
                   self.transfer_client.endpoint_search(filter_fulltext=dir_name,
                                                        num_results=None)]
        if len(matches) == 1:
            # there is a unique id associated with the endpoint name
            return matches[0]

        if matches:
            logger.error('The endpoint {0} is not unique: {1} were found'.format(dir_name, len(matches)))
        else:
            logger.error('The endpoint {0} does not exist'.format(dir_name))
        # a failed lookup must not end the process with exit status 0
        sys.exit(1)

    def set_endpoint_server(self, endpoint_id):
        """Look up the endpoint and remember its HTTPS server.

        Authenticates first if no transfer client exists yet, then stores the
        endpoint's ``https_server`` value in ``self.server``.

        Args:
            endpoint_id: id of the Globus endpoint to query.

        Raises:
            SystemExit: when the API answers 401.
            GlobusAPIError: for any other API failure (re-raised unchanged).
        """
        if not self.transfer_client:
            self.authenticate()
        try:
            endpoint_doc = self.transfer_client.get_endpoint(endpoint_id)
            self.server = endpoint_doc['https_server']
        except GlobusAPIError as ex:
            logger.error(ex)
            if ex.http_status == 401:
                # tokens are kept in the keyring by save_tokens(), not in a file
                sys.exit("Refresh token has expired. Please delete the 'auth-tokens' "
                         "entry of the 'globus-remote' keyring service and setup again.")
            raise

# *********************************************************************************************************************

    def initremote(self):
        """Initialize the remote from its git-annex configuration.

        Reads ``uuid``, ``fileprefix`` and ``endpoint`` from git-annex. When
        only the endpoint name is configured, its id is looked up and used as
        the uuid.

        Raises:
            RemoteError: when neither uuid nor endpoint is configured, or when
                any step of the initialization fails.
        """
        try:
            # query uuid, fileprefix and endpoint name from git annex
            self.uuid = self.annex.getconfig('uuid')
            self.fileprefix = self.annex.getconfig('fileprefix')
            self.endpoint = self.annex.getconfig('endpoint')

            if not self.uuid and not self.endpoint:
                raise RemoteError("Either directory name or uuid must be given.")

            if not self.uuid:
                self.uuid = self.get_endpoint_id(self.endpoint)

        except RemoteError:
            # already a proper failure report; pass it on untouched
            raise
        except Exception as e:
            # annex.error() was being called with several positional arguments;
            # raising RemoteError reports the failure with a single message
            raise RemoteError(
                "Failed to initialise the remote: {0}. "
                "Run 'git-annex-remote-globus setup' to authenticate".format(e)) from e

    def prepare(self):
        """Prepare the remote for use by git-annex.

        Reads ``uuid``, ``fileprefix`` and ``endpoint`` from git-annex,
        resolves the uuid from the endpoint name when only the name is
        configured, and looks up the endpoint's HTTPS server unless one is
        already set.

        Raises:
            RemoteError: when neither uuid nor endpoint is configured, or when
                any step of the preparation fails.
        """
        try:
            self.uuid = self.annex.getconfig('uuid')
            self.fileprefix = self.annex.getconfig('fileprefix')
            self.endpoint = self.annex.getconfig('endpoint')

            if not self.uuid and not self.endpoint:
                raise RemoteError("Either directory name or uuid must be given.")

            if not self.uuid:
                self.uuid = self.get_endpoint_id(self.endpoint)

            if not self.server:
                self.set_endpoint_server(self.uuid)

        except RemoteError:
            # already a proper failure report; pass it on untouched
            raise
        except Exception as e:
            # annex.error() was being called with several positional arguments;
            # raising RemoteError reports the failure with a single message
            raise RemoteError(
                "Failed to prepare the remote: {0}. "
                "Run 'git-annex-remote-globus setup' to authenticate".format(e)) from e

# *********************************************************************************************************************
    @classmethod
    def key_size(cls, key):
        """get the file size from a given key"""
        return str(key).split('-')[1].split('s')[1]

    def _get_remote_location(self, url):
        """Constructs remote location with the globus url"""
        path = str(url).split(str(self.endpoint).lower())[1]
        # construct a remote directory path
        return '/~' + path

    def _get_size(self, location):
        """gets the file size of the specified remote location"""
        # get remote file location directory
        dir_path = os.path.dirname(location)
        # and the file name to check for
        file_name = location.split('/')[-1]
        # access remote using ls operation to check if file exists
        for file in self.transfer_client.operation_ls(self.uuid, path=dir_path, num_results=None):
            # if it finds the file
            if file['name'] == file_name:
                # return file size in reply
                return file['size']

    def _check_size(self, key, url):
        """Checks file size has not changed for a given file"""
        # generates path to check for precence in git-annex branch
        # get remote location
        path = self._get_remote_location(url)
        # construct a remote directory path
        globus_size = self._get_size(path)
        key_size = GlobusRemote.key_size(key)
        return int(globus_size) == int(key_size)

    def checkpresent(self, key):
        """Indicates whether a key has been verified to be present in a remote location"""
        # make globus generate the desired file name to check for key presence
        globus_urls = self.annex.geturls(key, prefix='globus://')
        # there is multiple or missing url
        if len(globus_urls) != 1:
            self.annex.error("Could not find the globus url for the specified key"
                             "number of urls found: {0}".format(len(globus_urls)))
            return False
        else:
            return self._check_size(key, globus_urls[0])

# *********************************************************************************************************************
        
    def transfer_store(self, key, filename):
        """Refuse to store *key*: this remote implements no upload.

        Returning normally would be reported to git-annex as a successful
        store although nothing was written, so the failure is signalled
        explicitly.

        Raises:
            RemoteError: always.
        """
        raise RemoteError("Storing content on the Globus remote is not supported")

    def remove(self, key):
        """Refuse to remove *key*: this remote implements no deletion.

        Returning normally would be reported to git-annex as a successful
        removal although the content is still on the remote, so the failure
        is signalled explicitly.

        Raises:
            RemoteError: always.
        """
        raise RemoteError("Removing content from the Globus remote is not supported")

    def transfer_retrieve(self, key, filename):
        """Requests transfer of a key. The filename if where to store the download"""
        globus_urls = self.annex.geturls(key, prefix='globus://')
        # there is multiple or missing url
        if len(globus_urls) != 1:
            self.annex.error("Could not find the globus url for the specified key"
                             "number of urls found: {0}".format(len(globus_urls)))
        else:
            # do transfer in the given filename location and return if success
            if self._do_retrive(globus_urls[0], filename):
                return key

    def _do_retrive(self, globus_url, filename):
        """Download the file behind *globus_url* over HTTPS into *filename*.

        The download url is the server (the configured fallback when none is
        set) followed by everything after the first occurrence of the
        lower-cased endpoint name in *globus_url*.

        Returns:
            True when the download completed; False when the url could not be
            built or the download failed (an error is reported to git-annex
            in both cases).
        """
        if not self.server:
            self.server = self.settings['server']
        try:
            # maxsplit=1 keeps paths that repeat the endpoint name intact
            remote_path = str(globus_url).split(str(self.endpoint).lower(), 1)[1]
        except (IndexError, ValueError) as e:
            # IndexError: endpoint name absent from the url;
            # ValueError: empty endpoint name used as separator
            logger.error('Download URL error. Cannot retrieve file from server: %s (%s)',
                         e, globus_url)
            self.annex.error("problem occurred while downloading: {0}".format(e))
            return False
        download_url = self.server + remote_path

        try:
            logger.info("Downloading {0} into {1}".format(globus_url, filename))
            # TODO: add progress bar
            with urllib.request.urlopen(str(download_url).replace(" ", "%20")) \
                    as response, open(os.path.normpath(filename), 'wb') as out_file:
                logger.info("Download in process...")
                shutil.copyfileobj(response, out_file)
        except Exception as e:
            self.annex.error("problem occurred while downloading: {0}".format(e))
            return False
        logger.info("Successfully downloaded {0} into {1}".format(globus_url, filename))
        return True

    def _is_valid(self, url):
        """Checks that the endpoint name is given in the url"""
        split_by_endpoint = str(url).split(str(self.endpoint).lower())
        if len(split_by_endpoint) == 1:
            # if the endpoint splitting value is not contained
            self.annex.error("Unsupported url")
            return False
        return True

    def claimurl(self, url):
        """Check whether it is possible to download a url given the specified protocol prefix"""
        if self._is_valid(url):
            prefix = str(url).split(str(self.endpoint).lower())[0]
            if prefix == 'globus://':
                # accept url
                return url
            else:
                self.annex.error("Unsupported prefix")

    def checkurl(self, url):
        """Check if the url's content can currently be downloaded (without downloading it)"""
        # initialize response
        reply = []
        dict_reply = dict()
        # get remote location
        path = self._get_remote_location(url)
        # get file size in remote location
        dict_reply['size'] = self._get_size(path)
        # return it as a response
        reply.append(dict_reply)
        return reply


    # Export methods
    # def transferexport_store(self, key, local_file, remote_file):
    #     location = '/'.join((self.directory, remote_file))
    #     self._do_store(key, local_file, location)
    #
    # def transferexport_retrieve(self, key, local_file, remote_file):
    #     location = '/'.join((self.directory, remote_file))
    #     self._do_retrieve(key, location, local_file)
    #
    # def checkpresentexport(self, key, remote_file):
    #     location = '/'.join((self.directory, remote_file))
    #     return self._do_checkpresent(key, location)
    #
    # def removeexport(self, key, remote_file):
    #     location = '/'.join((self.directory, remote_file))
    #     self._do_remove(key, location)
    #
    # def removeexportdirectory(self, remote_directory):
    #     location = '/'.join((self.directory, remote_directory))
    #     try:
    #         os.rmdir(location)
    #     except OSError as e:
    #         if e.errno != errno.ENOENT:
    #             raise RemoteError(e)
    #
    # def renameexport(self, key, filename, new_filename):
    #     oldlocation = '/'.join((self.directory, filename))
    #     newlocation = '/'.join((self.directory, new_filename))
    #     try:
    #         os.rename(oldlocation, newlocation)
    #     except OSError as e:
    #         raise RemoteError(e)
    #
    # def _mkdir(self, directory):
    #     try:
    #         os.makedirs(directory)
    #     except OSError as e:
    #         if e.errno != errno.EEXIST:
    #             raise RemoteError("Failed to write to {}".format(directory))
    #
    # def _calclocation(self, key):
    #     return "{dir}/{hash}{key}".format(
    #                     dir=self.directory,
    #                     hash=self.annex.dirhash(key),
    #                     key=key)
    #
    # def _info(self, message):
    #     try:
    #         self.annex.info(message)
    #     except ProtocolError:
    #         print(message)
    #
    # def _do_store(self, key, filename, location):
    #     self._mkdir(os.path.dirname(location))
    #     templocation = '/'.join((self.directory,
    #                             'tmp',
    #                             key))
    #     self._mkdir(os.path.dirname(templocation))
    #     try:
    #         copyfile(filename, templocation)
    #         os.rename(templocation, location)
    #     except OSError as e:
    #         raise RemoteError(e)
    #     try:
    #         os.rmdir(os.path.dirname(templocation))
    #     except OSError:
    #         self._info("Could not remove tempdir (Not empty)")
    #
    # def _do_retrieve(self, key, location, filename):
    #     try:
    #         copyfile(location, filename)
    #     except OSError as e:
    #         raise RemoteError(e)
    #
    # def _do_checkpresent(self, key, location):
    #     if not os.path.exists(self.directory):
    #         raise RemoteError("this remote is not currently available")
    #     return os.path.isfile(location)
    #
    # def _do_remove(self, key, location):
    #     if not os.path.exists(self.directory):
    #         raise RemoteError("this remote is not currently available")
    #     try:
    #         os.remove(location)
    #     except OSError as e:
    #         # It's not a failure to remove a file that is not present.
    #         if e.errno != errno.ENOENT:
    #             raise RemoteError(e)


def main():
    """Entry point: serve git-annex, or run the interactive setup.

    Without arguments the program runs as a special remote: the real stdout
    is handed to Master, and ``sys.stdout`` is pointed at stderr so that
    print() calls elsewhere do not write to the stream given to Master.
    With ``setup`` as first argument the interactive authentication is run
    against a Master that writes to the null device.

    Returns:
        0 on success, 2 when the first argument is not recognised
        (previously such an invocation silently did nothing).
    """
    args = sys.argv[1:]

    if not args:
        output = sys.stdout
        sys.stdout = sys.stderr
        master = Master(output)
        remote = GlobusRemote(master)
        master.LinkRemote(remote)
        master.Listen()
        return 0

    if args[0] == 'setup':
        with open(os.devnull, 'w') as devnull:
            master = Master(devnull)
            remote = GlobusRemote(master)
            remote.setup()
        return 0

    sys.stderr.write("usage: {0} [setup]\n".format(os.path.basename(sys.argv[0])))
    return 2


if __name__ == "__main__":
    # propagate main()'s return value as the process exit status
    sys.exit(main())
