#!/usr/bin/env python3

import boto3
from botocore.exceptions import ClientError
import argparse
import sys
import shlex
import re
import os
import signal
import configparser
from subprocess import Popen, PIPE

###############################################################################
#
# Constants
###############################################################################
# Version string of this script.
__version__ = '0.0.12'

# Default MFA config file location (~/.aws/mfa-config) with "~" expanded to
# the current user's home directory.
DEFAULT_MFA_CONFIG_FILE_DIR = "~/.aws/"
DEFAULT_MFA_CONFIG_FILE_NAME = "mfa-config"
DEFAULT_MFA_CONFIG_FILE_PATH = os.path.expanduser(os.path.join(DEFAULT_MFA_CONFIG_FILE_DIR,
                                DEFAULT_MFA_CONFIG_FILE_NAME))
# Fallback session lifetime in seconds (36000 s = 10 hours); sent to STS as
# DurationSeconds when the profile does not set mfa_duration.
DEFAULT_MFA_DURATION = 36000
# Credentials file that main() hands to update_config() to be rewritten.
AWS_CREDENTIALS_FILE_PATH = os.path.expanduser("~/.aws/credentials")

# Executable name of the ykman command-line tool used to read YubiKey codes.
YKMAN_BINARY_NAME = "ykman"


################################################################################
#
# Functions
################################################################################

# Handle signal
def signal_handler(signal, frame):
    """Report that SIGINT was received and terminate with exit status 1.

    The arguments are the ones the ``signal`` module passes to every handler
    (signal number and current stack frame); neither is used.
    """
    print("Error: SIGINT received.")
    raise SystemExit(1)

# Input Validators
def arg_config_file(arg_string):
    """argparse ``type=`` callback that checks the config file exists.

    Environment variables and a leading ``~`` are expanded only for the
    existence check; the argument is returned exactly as it was given.

    Raises:
        argparse.ArgumentTypeError: if the expanded path is not a regular file.
    """
    expanded_path = os.path.expanduser(os.path.expandvars(arg_string))
    if os.path.isfile(expanded_path):
        return arg_string
    raise argparse.ArgumentTypeError("Config file not found: {0}".format(expanded_path))

def arg_profile(arg_string):
    """argparse ``type=`` callback for the profile name.

    No validation rule is implemented yet, so every value is accepted and
    returned unchanged; the error branch is kept as a placeholder.
    """
    # Placeholder: replace with a real check on arg_string when one is defined.
    is_valid = True
    if is_valid:
        return arg_string
    raise argparse.ArgumentTypeError(
            "Not a valid profile name '{0}'.".format(arg_string))

def validate_token_code(arg_string):
    """Check that a token code is exactly six ASCII digits.

    Surrounding whitespace (for example from a pasted code) is ignored.

    Args:
        arg_string: the token code as entered by the user.

    Returns:
        The token code with surrounding whitespace removed.

    Raises:
        ValueError: if the value is not exactly six digits.
    """
    token_code = arg_string.strip()
    # fullmatch anchors both ends; a plain match() would also accept longer
    # strings such as "1234567" or "123456abc".
    if not re.fullmatch(r"[0-9]{6}", token_code):
        raise ValueError(
                "token_code must be 6 digits: '{0}'.".format(arg_string))
    return token_code


def parse_args():
    """Define the command-line options and parse ``sys.argv``.

    Options: ``-d/--debug`` (flag), ``-c/--config-file`` (validated by
    ``arg_config_file``) and ``-p/--profile`` (validated by ``arg_profile``).

    Returns:
        The ``argparse.Namespace`` with ``debug``, ``config_file`` and
        ``profile`` attributes.
    """
    cli = argparse.ArgumentParser(
        usage="%(prog)s [options]",
        description="updates aws credentials file with temporary sts credentials obtained with mfa",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument(
        '-d', '--debug',
        dest='debug',
        action='store_true',
        default=False,
        help='Enable debug',
    )
    cli.add_argument(
        '-c', '--config-file',
        dest='config_file',
        action="store",
        type=arg_config_file,
        default=DEFAULT_MFA_CONFIG_FILE_PATH,
        help='config file to load mfa details [~/.aws/mfa-config]',
    )
    cli.add_argument(
        '-p', '--profile',
        dest='profile',
        action="store",
        type=arg_profile,
        default="default",
        help='profile to be loaded from the config file [default]',
    )
    return cli.parse_args()


def read_config(config_file_path):
    """Load the MFA config file and return its usable profiles.

    Sections named ``profile <name>`` are returned under ``<name>``; the
    ``default`` section is returned under ``default``. Profile sections that
    lack a non-empty ``mfa_serial`` or ``dest_profile``, or whose header does
    not split into exactly two words, are skipped.

    Args:
        config_file_path: path to the config file; environment variables and
            a leading ``~`` are expanded.

    Returns:
        A dict mapping profile name to that section's configparser proxy.

    Raises:
        ValueError: if the file does not exist or contains no sections.
    """
    config_file_path = os.path.expanduser(os.path.expandvars(config_file_path))
    if not os.path.isfile(config_file_path):
        raise ValueError("Config file not found: {0}".format(config_file_path))
    config = configparser.ConfigParser()
    config.read(config_file_path)
    if not config.sections():
        raise ValueError("Error parsing config file: {0}".format(config_file_path))

    parsed_config = {}
    for section in config.sections():
        if section == "default":
            # default is a special section in boto3 that doesn't require
            # "profile" as a prefix:
            # https://github.com/boto/botocore/blob/c6ebb3be3acc946e3b706333294320dc7e304dd7/botocore/configloader.py#L265
            parsed_config[section] = config[section]
            continue
        if not section.startswith("profile"):
            continue
        options = config[section]
        # .get() returns None for a missing option; indexing would raise
        # KeyError and abort the whole load instead of skipping the section.
        if not options.get("mfa_serial") or not options.get("dest_profile"):
            continue
        try:
            split_section = shlex.split(section)
        except ValueError:
            # Unbalanced quotes in the section header: ignore the section.
            continue
        if len(split_section) == 2:
            parsed_config[split_section[1]] = options
    return parsed_config


def read_credentials(aws_credentials_file):
    """Parse the credentials file and return the resulting ConfigParser.

    ``ConfigParser.read`` ignores files it cannot open, so a missing file
    yields a parser with no sections.
    """
    parser = configparser.ConfigParser()
    parser.read(aws_credentials_file)
    return parser


def update_config(aws_credentials_file, dest_profile, credential_dict):
    """Store temporary credentials under ``dest_profile`` and rewrite the file.

    An existing section keeps its other options and only the three credential
    options are overwritten; a missing section is created. The whole
    credentials file is then written back.

    Args:
        aws_credentials_file: path of the credentials file to read and rewrite.
        dest_profile: name of the section that receives the credentials.
        credential_dict: mapping with ``AccessKeyId``, ``SecretAccessKey`` and
            ``SessionToken`` keys.
    """
    credentials = read_credentials(aws_credentials_file)
    if dest_profile not in credentials.sections():
        credentials[dest_profile] = {}
    target = credentials[dest_profile]
    option_sources = (
        ("aws_access_key_id", "AccessKeyId"),
        ("aws_secret_access_key", "SecretAccessKey"),
        ("aws_session_token", "SessionToken"),
    )
    for option_name, source_key in option_sources:
        target[option_name] = credential_dict[source_key]
    with open(aws_credentials_file, 'w') as credfile:
        credentials.write(credfile)
    return


def get_credential_dict(source_profile, mfa_serial, region, token_code, mfa_duration):
    """Request temporary session credentials from STS using an MFA token.

    Args:
        source_profile: boto3 profile used to sign the request; omitted from
            the session arguments when falsy.
        mfa_serial: serial number / ARN of the MFA device.
        region: region name for the session; omitted when falsy.
        token_code: current code from the MFA device.
        mfa_duration: session lifetime in seconds, as an int or a numeric
            string (values read from the config file are strings). Falls back
            to DEFAULT_MFA_DURATION when falsy.

    Returns:
        The ``Credentials`` entry of the GetSessionToken response.

    Raises:
        ValueError: if ``mfa_duration`` is not an integer, or if the STS call
            fails for any reason other than ``AccessDenied``.
        SystemExit: with status 1 when STS answers ``AccessDenied``.
    """
    conn_args = {}
    if source_profile:
        conn_args["profile_name"] = source_profile
    if region:
        conn_args["region_name"] = region
    if not mfa_duration:
        mfa_duration = DEFAULT_MFA_DURATION
    try:
        # DurationSeconds must be an integer; configparser hands back strings.
        duration_seconds = int(mfa_duration)
    except (TypeError, ValueError) as err:
        raise ValueError(
            "mfa_duration must be an integer number of seconds: '{0}'".format(mfa_duration)) from err

    session = boto3.session.Session(**conn_args)
    sts_client = session.client("sts")
    try:
        response = sts_client.get_session_token(
            DurationSeconds=duration_seconds,
            SerialNumber=mfa_serial,
            TokenCode=token_code,
        )
    except ClientError as err:
        if err.response['Error']['Code'] == 'AccessDenied':
            print("Invalid Token Code. Exiting")
            sys.exit(1)
        # Any other service error used to be swallowed and then surfaced as an
        # AttributeError on None; report it like every other failure instead.
        raise ValueError("{0}".format(err)) from err
    except Exception as err:
        raise ValueError("{0}".format(err)) from err
    return response.get('Credentials')


def run_command(cmd, env=os.environ, cmd_input=None, stdin=PIPE, stdout=PIPE):
    """Run ``cmd`` and return its standard output decoded as UTF-8.

    Args:
        cmd: argument list passed to ``Popen``.
        env: environment for the child process.
        cmd_input: optional text that is encoded and sent to the child's stdin.
        stdin: ``stdin`` argument forwarded to ``Popen``.
        stdout: ``stdout`` argument forwarded to ``Popen``.

    Raises:
        Exception: with the decoded stderr text if ``communicate`` returns
            stderr data. NOTE(review): stderr is not redirected here, so
            ``communicate`` returns ``None`` for it and this branch does not
            trigger as written.
    """
    process = Popen(cmd, env=env, stdin=stdin, stdout=stdout)
    if cmd_input is None:
        out_bytes, err_bytes = process.communicate()
    else:
        out_bytes, err_bytes = process.communicate(input=cmd_input.encode())
    if err_bytes is not None:
        raise Exception(err_bytes.decode("utf-8"))
    return out_bytes.decode("utf-8")

def is_ykman_installed():
    """Return True when the ykman executable can be found on PATH.

    The lookup uses ``shutil.which`` rather than launching the program, so the
    check has no side effects and a file that exists but is not executable is
    reported as missing instead of raising.
    """
    # Local import keeps this function self-contained without touching the
    # module's import block.
    import shutil

    return shutil.which(YKMAN_BINARY_NAME) is not None


def get_token_code(config):
    """Obtain an MFA token code for the given profile configuration.

    When ``yubikey_credential_name`` is set the code is read from a YubiKey;
    otherwise the user is prompted for it.
    """
    credential_name = config.get("yubikey_credential_name", False)
    if credential_name:
        return get_token_code_from_yubikey(yubikey_credential_name=credential_name)
    return get_token_code_from_user(mfa_serial=config["mfa_serial"])
    

def get_token_code_from_yubikey(yubikey_credential_name):
    """Read the current token code for a credential stored on a YubiKey.

    Requires the ykman utility on the system. ``yubikey_credential_name`` is
    of the form ISSUER:ACCOUNT_NAME, e.g. ``Github:theorlandog``; run
    ``ykman oath list`` for the configured names.

    Returns:
        The token code printed by ``ykman oath code`` for the credential.

    Raises:
        Exception: if ykman is not on PATH, if the output does not start with
            the credential name, or if no numeric code follows the name.
    """
    if not is_ykman_installed():
        raise Exception("{ykman_binary_name} not found in path".format(ykman_binary_name=YKMAN_BINARY_NAME))
    command = [YKMAN_BINARY_NAME, "oath", "code", yubikey_credential_name]
    result = run_command(command)

    output_lines = result.splitlines()
    if not output_lines or not output_lines[0].startswith(yubikey_credential_name):
        raise Exception("YubiKey credential named {yubikey_credential_name} not found.".format(yubikey_credential_name=yubikey_credential_name))

    # The code is the last whitespace-separated field after the name. Taking
    # field [1] of the whole line breaks for names containing spaces
    # (e.g. "Amazon Web Services:user").
    trailing_fields = output_lines[0][len(yubikey_credential_name):].split()
    if not trailing_fields or not trailing_fields[-1].isdigit():
        raise Exception("No token code returned for YubiKey credential named {yubikey_credential_name}.".format(yubikey_credential_name=yubikey_credential_name))
    return trailing_fields[-1]

def get_token_code_from_user(mfa_serial):
    """Prompt on stdin for the token code of ``mfa_serial`` and validate it."""
    prompt = "token code for {0}: ".format(mfa_serial)
    return validate_token_code(input(prompt))


################################################################################
#
# Main
################################################################################

def main():
    """Fetch MFA-backed temporary credentials and store them for a profile.

    Steps: parse the command line, load the MFA config, obtain a token code,
    call STS, and write the returned credentials to the AWS credentials file
    under the profile's ``dest_profile``.

    Exits with status 1 and a message when the requested profile is not in
    the config file or lacks ``mfa_serial`` / ``dest_profile``.
    """
    # Parse Args and read config
    options = parse_args()
    config = read_config(options.config_file)

    if options.profile not in config:
        print("Error: profile '{0}' not found in config file: {1}".format(
            options.profile, options.config_file))
        sys.exit(1)
    profile_config = config[options.profile]

    # Check required settings before asking for a token, so a one-time code is
    # not spent on a run that cannot store its result.
    for required in ("mfa_serial", "dest_profile"):
        if not profile_config.get(required):
            print("Error: no {0} defined for profile '{1}'".format(required, options.profile))
            sys.exit(1)

    token_code = get_token_code(config=profile_config)
    credential_dict = get_credential_dict(
        source_profile=profile_config.get("source_profile", options.profile),
        mfa_serial=profile_config["mfa_serial"],
        mfa_duration=profile_config.get("mfa_duration", DEFAULT_MFA_DURATION),
        region=profile_config.get("region", None),
        token_code=token_code)
    update_config(dest_profile=profile_config["dest_profile"],
                  credential_dict=credential_dict,
                  aws_credentials_file=AWS_CREDENTIALS_FILE_PATH)

# Script entry point: runs only when the file is executed, not when imported.
if __name__ == '__main__':
    # Install the SIGINT handler first so Ctrl-C at any later point prints an
    # error and exits with status 1.
    signal.signal(signal.SIGINT, signal_handler)
    main()
    # main() returned normally: exit with status 0.
    sys.exit(0)

