#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# 2018-02-21 Cornelius Kölbel <cornelius.koelbel@netknights.it>
#            Allow to import PSKC file
# 2017-11-21 Cornelius Kölbel <cornelius.koelbel@netknights.it>
#            export to CSV including usernames
# 2017-10-18 Cornelius Kölbel <cornelius.koelbel@netknights.it>
#            Add token export (HOTP and TOTP) to PSKC
# 2017-05-02 Friedrich Weber <friedrich.weber@netknights.it>
#            Improve token matching
# 2017-04-25 Cornelius Kölbel <cornelius.koelbel@netknights.it>
#
# Copyright (c) 2017, Cornelius Kölbel
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
from dateutil import parser
from dateutil.tz import tzlocal, tzutc
from flask.ext.script.commands import InvalidCommand

from privacyidea.lib.policy import ACTION
from privacyidea.lib.utils import parse_legacy_time
from privacyidea.lib.importotp import export_pskc

# Actions accepted by the --action option of the "find" command. The list is
# checked in find() and is also interpolated into the module help text and
# into the help string of the --action option.
ALLOWED_ACTIONS = ["disable", "delete", "unassign", "mark", "export", "listuser"]

# The module documentation is assigned explicitly and run through format(),
# which interpolates the allowed actions into the usage text.
# NOTE(review): the "Actions" list in the text below does not mention "export"
# and "listuser", although both are contained in ALLOWED_ACTIONS - confirm
# whether the text should be extended.
__doc__ = """
This script can be used to clean up the token database.

It can list, disable, delete or mark tokens based on

Conditions:

* last_auth
* orphaned tokens
* any tokeninfo
* unassigned token / token has no user
* tokentype

Actions:

* list
* unassign
* mark
* disable
* delete


    privacyidea-token-janitor find 
        --last_auth=10h|7d|2y
        --tokeninfo-key=<key>
        --tokeninfo-value=<value>
        --tokeninfo-value-greater-than=<value>
        --tokeninfo-value-less-than=<value>
        --tokeninfo-value-after=<value>
        --tokeninfo-value-before=<value>
        --unassigned
        --orphaned
        --tokentype=<type>
        --serial=<regexp>
        --description=<regexp>
        
        --action={0!s}
        
        --set-description="new description"
        --set-tokeninfo-key=<key>
        --set-tokeninfo-value=<value>    
    
.. note:: If you fail to redirect the output of this command at the commandline
   to e.g. a file with a UnicodeEncodeError, you need to set the environment
   variable like this:
       export PYTHONIOENCODING=UTF-8
   Here is a good read about it:
   https://stackoverflow.com/questions/4545661/unicodedecodeerror-when-redirecting-to-file

""".format("|".join(ALLOWED_ACTIONS))
from privacyidea.lib.token import (get_tokens, remove_token, enable_token,
                                   unassign_token, import_token)
from privacyidea.app import create_app
from flask.ext.script import Manager
import re
import sys

__version__ = "0.1"



# Create the application with the "production" configuration. The commands
# defined below are attached to this manager via @manager.option and are
# dispatched by manager.run() at the end of the file.
app = create_app(config_name='production', silent=True)
manager = Manager(app)

def _try_convert_to_integer(given_value_string):
    try:
        return int(given_value_string)
    except ValueError:
        raise InvalidCommand('Not an integer: {}'.format(given_value_string))


def _compare_greater_than(key, given_value_string):
    """
    Build a comparator for the --tokeninfo-value-greater-than criterion.

    :param key: the tokeninfo key (not used by this comparator)
    :param given_value_string: the user-supplied threshold, converted to an integer
    :return: a function which returns True if its parameter (converted to an integer) is
        greater than the threshold. Values which can not be converted to an integer
        never match.
    """
    threshold = _try_convert_to_integer(given_value_string)

    def exceeds_threshold(value):
        try:
            matched = int(value) > threshold
        except ValueError:
            # A non-numeric tokeninfo value simply does not match
            matched = False
        return matched

    return exceeds_threshold


def _compare_less_than(key, given_value_string):
    """
    Build a comparator for the --tokeninfo-value-less-than criterion.

    :param key: the tokeninfo key (not used by this comparator)
    :param given_value_string: the user-supplied threshold, converted to an integer
    :return: a function which returns True if its parameter (converted to an integer) is
        less than the threshold. Values which can not be converted to an integer
        never match.
    """
    threshold = _try_convert_to_integer(given_value_string)

    def below_threshold(value):
        try:
            matched = int(value) < threshold
        except ValueError:
            # A non-numeric tokeninfo value simply does not match
            matched = False
        return matched

    return below_threshold


def _parse_datetime(key, value):
    """
    Parse a tokeninfo value into a datetime.

    :param key: the tokeninfo key the value belongs to. ``ACTION.LASTAUTH`` is
        handled specially.
    :param value: the tokeninfo value as a string
    :return: the parsed datetime
    :raises ValueError: if the value can not be parsed (propagated from the parser)
    """
    if key == ACTION.LASTAUTH:
        # Special case for last_auth: Legacy values are given in UTC time!
        last_auth = parser.parse(value)
        if not last_auth.tzinfo:
            # The "tzinfos" argument of parser.parse() is meant to translate timezone
            # *names* found in the string, so handing it the tzutc class only worked
            # as a side effect of old dateutil versions. Attach UTC explicitly instead.
            last_auth = last_auth.replace(tzinfo=tzutc())
        return last_auth

    # Other values are given in local time
    return parser.parse(parse_legacy_time(value))


def _try_convert_to_datetime(given_value_string):
    """
    Convert a user-supplied string to a timezone-aware datetime.

    The string is parsed with ``dayfirst=True``. If it does not carry any timezone
    information, the local timezone is attached.

    :param given_value_string: the datetime given on the command line
    :return: a timezone-aware datetime
    :raises InvalidCommand: if the string can not be parsed as a datetime
    """
    try:
        parsed = parser.parse(given_value_string, dayfirst=True)
    except (ValueError, OverflowError):
        # The parser may also raise OverflowError for out-of-range numbers
        raise InvalidCommand('Not a valid datetime format: {}'.format(given_value_string))
    if not parsed.tzinfo:
        # The "tzinfos" argument of parser.parse() only translates timezone *names*
        # found in the string. Passing the tzlocal class there makes the parser call
        # tzlocal(name, offset), which raises a TypeError that was not caught.
        # Attach the local timezone explicitly instead.
        parsed = parsed.replace(tzinfo=tzlocal())
    return parsed


def _compare_after(key, given_value_string):
    """
    Build a comparator for the --tokeninfo-value-after criterion.

    :param key: the tokeninfo key, needed to decide how stored values are parsed
    :param given_value_string: the user-supplied datetime
    :return: a function which returns True if its parameter (converted to a datetime)
        occurs after the user-supplied datetime. A ValueError during parsing or
        comparison means the value does not match.
    """
    reference = _try_convert_to_datetime(given_value_string)

    def is_after(value):
        try:
            matched = _parse_datetime(key, value) > reference
        except ValueError:
            matched = False
        return matched

    return is_after


def _compare_before(key, given_value_string):
    """
    Build a comparator for the --tokeninfo-value-before criterion.

    :param key: the tokeninfo key, needed to decide how stored values are parsed
    :param given_value_string: the user-supplied datetime
    :return: a function which returns True if its parameter (converted to a datetime)
        occurs before the user-supplied datetime. A ValueError during parsing or
        comparison means the value does not match.
    """
    reference = _try_convert_to_datetime(given_value_string)

    def is_before(value):
        try:
            matched = _parse_datetime(key, value) < reference
        except ValueError:
            matched = False
        return matched

    return is_before


def _compare_regex(key, given_regex):
    def comparator(value):
        return re.search(given_regex, value)
    return comparator


def build_tokenvalue_filter(key,
                            tokeninfo_value,
                            tokeninfo_value_greater_than,
                            tokeninfo_value_less_than,
                            tokeninfo_value_after,
                            tokeninfo_value_before):
    """
    Assemble the token value filter from the user-supplied criteria.

    The filter is a list of comparator functions. Every comparator takes a tokeninfo
    value and tells whether one user-defined criterion matches. A record matches the
    filter only if *all* comparators match. Criteria which were not given (empty or
    None) do not contribute a comparator.

    :param key: user-supplied --tokeninfo-key= value
    :param tokeninfo_value: user-supplied --tokeninfo-value= value
    :param tokeninfo_value_greater_than: user-supplied --tokeninfo-value-greater-than= value
    :param tokeninfo_value_less_than: user-supplied --tokeninfo-value-less-than= value
    :param tokeninfo_value_after: user-supplied --tokeninfo-value-after= value
    :param tokeninfo_value_before: user-supplied --tokeninfo-value-before= value
    :return: list of functions
    """
    comparators = []
    register = comparators.append

    if tokeninfo_value:
        register(_compare_regex(key, tokeninfo_value))

    if tokeninfo_value_greater_than:
        register(_compare_greater_than(key, tokeninfo_value_greater_than))

    if tokeninfo_value_less_than:
        register(_compare_less_than(key, tokeninfo_value_less_than))

    if tokeninfo_value_after:
        register(_compare_after(key, tokeninfo_value_after))

    if tokeninfo_value_before:
        register(_compare_before(key, tokeninfo_value_before))

    return comparators


def _tristate_to_filter(flag):
    """
    Convert a user-supplied ``True|False|None`` string to a filter value.

    :param flag: the string given on the command line or None if the option was omitted
    :return: None (do not filter) if *flag* is None or the string "none",
        True if it is the string "true", False for any other string.
        The comparison is case-insensitive.
    """
    if flag is None:
        return None
    normalized = flag.lower()
    if normalized == "none":
        # The help text advertises "None" as a valid value. It must disable the
        # filter instead of being treated like "False".
        return None
    return normalized == "true"


def _get_tokenlist(last_auth, assigned, active, tokeninfo_key,
                   tokeninfo_value_filter, orphaned, tokentype, serial, description):
    """
    Fetch the tokens and return those which match all given conditions.

    :param last_auth: if set, tokens for which ``check_last_auth_newer(last_auth)``
        is true are skipped
    :param assigned: "True", "False", "None" or None; passed on to ``get_tokens``
        as a boolean or None
    :param active: "True", "False", "None" or None; passed on to ``get_tokens``
        as a boolean or None
    :param tokeninfo_key: the tokeninfo key whose value is checked with the filter
    :param tokeninfo_value_filter: list of comparator functions, all of which need to
        match the tokeninfo value. Only used together with *tokeninfo_key*.
    :param orphaned: if set, only tokens for which ``is_orphaned()`` is true are returned
    :param tokentype: passed on to ``get_tokens``
    :param serial: if set, a regular expression searched in the token serial
    :param description: if set, a regular expression searched in the token description
    :return: list of matching token objects
    """
    tokenobj_list = get_tokens(tokentype=tokentype,
                               active=_tristate_to_filter(active),
                               assigned=_tristate_to_filter(assigned))

    tlist = []
    for token_obj in tokenobj_list:
        if last_auth and token_obj.check_last_auth_newer(last_auth):
            continue
        if serial and not re.search(serial, token_obj.token.serial):
            continue
        if description and not re.search(description,
                                         token_obj.token.description):
            continue
        if tokeninfo_value_filter and tokeninfo_key:
            value = token_obj.get_tokeninfo(tokeninfo_key)
            # if the tokeninfo key is not even set, it does not match the filter
            if value is None:
                continue
            # If at least one comparator does not match, at least one user-supplied
            # criterion is not fulfilled, so the token does not match.
            if not all(comparator(value) for comparator in tokeninfo_value_filter):
                continue
        if orphaned and not token_obj.is_orphaned():
            continue

        # if everything matched, we append the token object
        tlist.append(token_obj)

    return tlist


def export_token_data(token_list):
    """
    Build one row of data per token.

    Each row is a list of unicode strings. It always starts with the serial and the
    token type. If the token has a user, the row additionally contains the username,
    given name, surname and user ID taken from the user info, followed by the
    resolver of the token. User attributes which are missing are exported as empty
    strings.

    :param token_list: list of token objects
    :return: list of rows, each row being a list of strings
    """
    user_attributes = ("username", "givenname", "surname", "userid")
    tokens = []
    for token_obj in token_list:
        token_data = [u"{0!s}".format(token_obj.token.serial),
                      u"{0!s}".format(token_obj.token.tokentype)]
        user = token_obj.user
        if user:
            # Read the user info only once per token
            user_info = user.info
            # All attributes default to "", so a missing user ID is not exported as
            # the literal string "None" (consistent with export_user_data)
            token_data.extend(u"{0!s}".format(user_info.get(attribute, ""))
                              for attribute in user_attributes)
            token_data.append(u"{0!s}".format(token_obj.token.resolver))
        tokens.append(token_data)
    return tokens


def export_user_data(token_list):
    """
    Group the serials of the given tokens by the user they are assigned to.

    The user is identified by the string
    ``username|givenname|surname|userid|resolver``. Tokens without a user are
    collected under the key ``N/A``.

    :param token_list: list of token objects
    :return: dictionary mapping the user identifier to the list of token serials
    """
    users = {}
    for token_obj in token_list:
        user = token_obj.user
        if user:
            user_info = user.info
            uid = u"{0!s}|{1!s}|{2!s}|{3!s}|{4!s}".format(
                user_info.get("username", ""),
                user_info.get("givenname", ""),
                user_info.get("surname", ""),
                user_info.get("userid", ""),
                token_obj.token.resolver
            )
        else:
            uid = u"N/A"

        # setdefault avoids the separate membership test on users.keys(), which
        # creates a new list for every token on Python 2
        users.setdefault(uid, []).append(token_obj.token.serial)

    return users


@manager.option('--last_auth', help='Can be something like 10h, 7d, or 2y')
@manager.option('--assigned', help='True|False|None')
@manager.option('--active', help='True|False|None')
@manager.option('--tokeninfo-value', metavar='REGEX', help='The tokeninfo value to match')
@manager.option('--tokeninfo-value-greater-than', metavar='INTEGER',
                help='Interpret tokeninfo values as integers and match only if they are greater than the given integer')
@manager.option('--tokeninfo-value-less-than', metavar='INTEGER',
                help='Interpret tokeninfo values as integers and match only if they are smaller than the given integer')
@manager.option('--tokeninfo-value-after', metavar='DATETIME',
                help='Interpret tokeninfo values as datetimes, match only if they occur after the given '
                     'ISO 8601 datetime')
@manager.option('--tokeninfo-value-before', metavar='DATETIME',
                help='Interpret tokeninfo values as datetimes, match only if they occur before the given '
                     ' ISO 8601 datetime')
@manager.option('--tokeninfo-key', help='The tokeninfo key to match')
@manager.option('--orphaned',
                help='Whether the token is an orphaned token. Set to 1')
@manager.option('--tokentype', help='The tokentype to search.')
@manager.option('--serial', help='A regular expression on the serial')
@manager.option('--description', help='A regular expression on the description')
@manager.option('--action', help='Which action should be performed on the '
                                 'found tokens. {0!s}'.format("|".join(ALLOWED_ACTIONS)))
@manager.option('--set-description', help='')
@manager.option('--set-tokeninfo-key', help='')
@manager.option('--set-tokeninfo-value', help='')
@manager.option('--sum', help='In case of the action "listuser", this switch specifies if the output'
                              ' should only contain the number of tokens owned by the user.',
                dest='sum', action='store_true')
@manager.option('--csv', dest='csv', action='store_true',
                help='In case of a simple find, the output is written as CSV instead of the '
                     'formatted output.')
def find(last_auth, assigned, active, tokeninfo_key, tokeninfo_value,
         tokeninfo_value_greater_than, tokeninfo_value_less_than,
         tokeninfo_value_after, tokeninfo_value_before,
         orphaned, tokentype, serial, description, action, set_description,
         set_tokeninfo_key, set_tokeninfo_value, sum, csv):
    """
    finds all tokens which match the conditions

    Without an action the matching tokens are printed, either formatted or - with
    *csv* - as quoted, comma separated values. Otherwise the action is performed:

    * ``listuser`` prints one line per token including the user data or, with *sum*,
      one line per user with the number of tokens
    * ``export`` prints the PSKC export to stdout and the encryption key to stderr
    * ``disable``, ``delete``, ``unassign`` and ``mark`` are applied to each token. A
      failure for one token is reported and does not stop the processing of the
      remaining tokens.

    An unknown action terminates the script with exit code 1.

    :param last_auth, assigned, active, orphaned, tokentype, serial, description:
        conditions passed on to ``_get_tokenlist``
    :param tokeninfo_key: the tokeninfo key to which the tokeninfo value conditions apply
    :param tokeninfo_value, tokeninfo_value_greater_than, tokeninfo_value_less_than,
        tokeninfo_value_after, tokeninfo_value_before: tokeninfo value conditions
        passed on to ``build_tokenvalue_filter``
    :param action: one of ALLOWED_ACTIONS or None
    :param set_description: new description for the action ``mark``
    :param set_tokeninfo_key: tokeninfo key to set with the action ``mark``
    :param set_tokeninfo_value: tokeninfo value to set with the action ``mark``
    :param sum: only print the number of tokens per user (action ``listuser``)
    :param csv: print the found tokens as CSV (only without an action)
    """
    if action and action not in ALLOWED_ACTIONS:
        sys.stderr.write("Unknown action. Allowed actions are {0!s}\n".format(
            ", ".join(["'{0!s}'".format(x) for x in ALLOWED_ACTIONS])
        ))
        sys.exit(1)

    # Named "tokeninfo_filter" to avoid shadowing the builtin filter()
    tokeninfo_filter = build_tokenvalue_filter(tokeninfo_key,
                                               tokeninfo_value,
                                               tokeninfo_value_greater_than,
                                               tokeninfo_value_less_than,
                                               tokeninfo_value_after,
                                               tokeninfo_value_before)

    tlist = _get_tokenlist(last_auth, assigned, active, tokeninfo_key,
                           tokeninfo_filter, orphaned, tokentype, serial,
                           description)

    if not action:
        if not csv:
            print("Token serial\tTokeninfo")
            print("="*42)
            for token_obj in tlist:
                # Unicode template like in the CSV branch: on Python 2 a byte string
                # template fails with a UnicodeEncodeError as soon as e.g. the
                # description contains non-ASCII characters.
                print(u"{0!s} ({1!s})\n\t\t{2!s}\n\t\t{3!s}".format(
                    token_obj.token.serial,
                    token_obj.token.tokentype,
                    token_obj.token.description,
                    token_obj.get_tokeninfo()))
        else:
            for token_obj in tlist:
                print(u"'{!s}','{!s}','{!s}','{!s}'".format(
                    token_obj.token.serial,
                    token_obj.token.tokentype,
                    token_obj.token.description,
                    token_obj.get_tokeninfo()
                ))

    elif action == "listuser":
        if not sum:
            tokens = export_token_data(tlist)
            for token in tokens:
                print(u",".join([u"'{0!s}'".format(x) for x in token]))
        else:
            users = export_user_data(tlist)
            # items() is available on Python 2 and Python 3, iteritems() only on Python 2
            for user, tokens in users.items():
                print(u"{0!s},{1!s}".format(user, len(tokens)))

    elif action == "export":
        key, token_num, soup = export_pskc(tlist)
        # Status and key go to stderr, so stdout only contains the export itself
        sys.stderr.write("\n{0!s} tokens exported.\n".format(token_num))
        sys.stderr.write("\nThis is the AES encryption key of the token seeds.\n"
                         "You need this key to import the tokens again:\n\n\t{0!s}\n\n".format(key))
        print("{0!s}".format(soup))
    else:
        for token_obj in tlist:
            try:
                if action == "disable":
                    enable_token(serial=token_obj.token.serial, enable=False)
                    print("Disabling token {0!s}".format(token_obj.token.serial))
                elif action == "delete":
                    remove_token(serial=token_obj.token.serial)
                    print("Deleting token {0!s}".format(token_obj.token.serial))
                elif action == "unassign":
                    unassign_token(serial=token_obj.token.serial)
                    print("Unassigning token {0!s}".format(token_obj.token.serial))
                elif action == "mark":
                    if set_description:
                        print("Setting description for token {0!s}: {1!s}".format(
                            token_obj.token.serial, set_description))
                        token_obj.set_description(set_description)
                        token_obj.save()
                    if set_tokeninfo_value and set_tokeninfo_key:
                        print("Setting tokeninfo for token {0!s}: {1!s}={2!s}".format(
                            token_obj.token.serial, set_tokeninfo_key, set_tokeninfo_value))
                        token_obj.add_tokeninfo(set_tokeninfo_key, set_tokeninfo_value)
                        token_obj.save()
            except Exception as exx:
                # Best effort: report the failing token and continue with the next one
                print("Failed to process token {0}.".format(
                    token_obj.token.serial))
                print(u"{0}".format(exx))


@manager.option('--pskc', dest='pskc',
                help='Import this PSKC file.')
@manager.option('--preshared_key_hex', dest='preshared_key_hex',
                help='The AES encryption key.')
def loadtokens(pskc, preshared_key_hex):
    """
    Import the tokens contained in a PSKC file.

    Every token is imported on its own. A token which fails to import is reported
    and does not stop the import of the remaining tokens. At the end the number
    of imported tokens and the serials of the failed tokens are printed.

    If no PSKC file is given, the script terminates with exit code 1.

    :param pskc: name of the PSKC file to import
    :param preshared_key_hex: the AES encryption key, passed on to ``parsePSKCdata``
    """
    if not pskc:
        # Without this check open(None) fails with a TypeError that does not tell
        # the user which option is missing
        sys.stderr.write("Please specify the PSKC file to import with --pskc.\n")
        sys.exit(1)

    from privacyidea.lib.importotp import parsePSKCdata

    with open(pskc, 'r') as pskcfile:
        file_contents = pskcfile.read()

    tokens = parsePSKCdata(file_contents, preshared_key_hex)
    success = 0
    failed_tokens = []
    for serial in tokens:
        try:
            print("Importing token {0!s}".format(serial))
            import_token(serial, tokens[serial])
            success = success + 1
        except Exception as e:
            # Best effort: remember the failed serial and continue with the next token
            failed_tokens.append(serial)
            print("--- Failed to import token. {0!s}".format(e))
    print("Successfully imported {0!s} tokens.".format(success))
    print("Failed to import {0!s} tokens: {1!s}".format(len(failed_tokens), failed_tokens))


# Only hand control to the command line manager when executed as a script
if __name__ == '__main__':
    manager.run()


