#!/usr/bin/env python3
import datetime
import gzip
import json
import operator
import re
import pytz
import time
import sys

from functools import reduce  # forward compatibility for Python 3


# Timestamp layout expected inside the JSON entries, for example:
# datetime.datetime.strptime("2020-04-06T00:42:23.121+0200", _DATETIME_FORMAT)
_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%z"
# Layout of the -start_datetime / -end_datetime arguments: the same pattern
# without the fractional seconds and the UTC offset ("%Y-%m-%dT%H:%M:%S").
_DATETIME_FORMAT_INPUT = _DATETIME_FORMAT.rsplit(".", 1)[0]
# Upper bound of what a single read() asks for while following a file.
_BUFSIZE = 4096
# Separator between two JSON entries while following a file.
_REALTIME_DELIM = "\n"


def _get_fblocks(fstr):
    """
    Split a filter expression into OR-groups of AND-ed terms.

    from 'a and b or c OR d and e or gol'
    returns
    [['a', 'b'], ['c'], ['d', 'e'], ['gol']]

    A connective is only recognised when written as ' or ', ' OR ',
    ' and ' or ' AND ', i.e. surrounded by single spaces. Every group in
    the result is a list, even when it holds a single term, and every term
    is stripped of surrounding whitespace.
    """
    # Non-capturing groups keep the connectives out of the split result, so
    # nothing has to be filtered out afterwards and a term that is literally
    # the word 'or' or 'and' is never dropped by mistake.
    groups = []
    for or_term in re.split(r' (?:or|OR) ', fstr):
        and_terms = re.split(r' (?:and|AND) ', or_term.strip())
        groups.append([term.strip() for term in and_terms])
    return groups
        

class PyJQ(object):
    """
    Wrapper around decoded JSON entries offering key lookup and filtering.

    Attributes:
        dline: list of the wrapped entries (a list even for one entry).
        datetime: datetime.datetime parsed from ``datetime_field``, or None
            when no field was requested or the entry does not contain it.
    """

    # Comparison operators accepted by filter().
    # https://docs.python.org/3/library/operator.html
    _OPS = {
        '==': operator.eq,
        '!=': operator.ne,
        'in': operator.contains,
        '<=': operator.le,
        '<': operator.lt,
        '>': operator.gt,
        '>=': operator.ge,
    }

    def __init__(self, json_obj, datetime_field=None,
                 datetime_format=_DATETIME_FORMAT):
        """
        json_obj: JSON text (str or bytes), an already decoded dict, or a
                  list of entries
        datetime_field: optional key (nested with '__') whose value is
                        parsed into self.datetime
        datetime_format: strptime format used for that value

        Raises TypeError for any other kind of json_obj; errors raised by
        json.loads and strptime are propagated unchanged.
        """
        if isinstance(json_obj, (str, bytes, bytearray)):
            self.dline = [json.loads(json_obj)]
        elif isinstance(json_obj, dict):
            self.dline = [json_obj]
        elif isinstance(json_obj, list):
            self.dline = json_obj
        else:
            # fail here instead of with an AttributeError on first use
            raise TypeError('Unsupported json_obj type "{}"'.format(
                type(json_obj).__name__))

        # always defined, so callers can test it without hasattr()
        self.datetime = None
        if datetime_field:
            dt_value = self._get_value(datetime_field)
            # not all the json entries carry the datetime field
            if dt_value is not None:
                self.datetime = datetime.datetime.strptime(dt_value,
                                                           datetime_format)

    def _get_value(self, key):
        """
        Look ``key`` up in the first wrapped entry.

        key should be a single key or a nested one like 'key1__key2__key3'

        Returns the value found, or None when a key of the path is missing
        or when there is no entry at all. Only the first entry of
        self.dline is consulted.

        Raises Exception when a step of the path lands on a value that
        cannot be indexed with the next key.
        """
        if not self.dline:
            return None
        keys = key.split('__')
        try:
            return reduce(operator.getitem, keys, self.dline[0])
        except KeyError:
            # not all the json entries have the same structure ...
            return None
        except TypeError as exc:
            keys_repr = ''.join(['[{}]'.format(k) for k in keys])
            raise Exception(
                'self.dline[{}] does not exists'.format(keys_repr)) from exc

    def filter(self, key, op='==', value=None):
        """
        Compare the value stored under ``key`` with ``value``.

        key = 'agent__id'
        op = one of  == | != | in | <= | < | > | >=
        value = expected

        Both sides are compared as the UTF-8 bytes of their str() form, so
        the ordering operators compare text, not numbers. A missing key is
        compared as the text 'None'.

        usage:
            wa.filter('agent__ip') -> the value, as the bytes of its text form
            wa.filter('agent__ip', '==', '172.16.16.2') -> True if match
            wa.filter('rule__description', 'in', 'login') -> True if 'login' is contained in the rule description

        Returns True on a match and None otherwise; without ``value`` the
        encoded value itself is returned.

        Raises Exception when ``op`` is not a supported operator.
        """
        if op not in self._OPS:
            raise Exception('Invalid operator "{}"'.format(op))

        value_got = str(self._get_value(key)).encode()

        # Nothing to compare with: hand the value back before it gets
        # compared against the text 'None'.
        if value is None:
            return value_got

        if self._OPS[op](value_got, str(value).encode()):
            return True
        return None

    def info(self):
        """
        Print, for every dict entry, its 'data' member as indented JSON
        followed by its 'full_log' member (empty text when absent).
        """
        # self.dline is a list, so the members are read entry by entry
        for entry in self.dline:
            if not isinstance(entry, dict):
                continue
            elems = (json.dumps(entry.get('data', ''), indent=2),
                     str(entry.get('full_log', '')))
            print('\n'.join(elems))

    def __str__(self):
        """Return all wrapped entries as indented JSON."""
        return json.dumps(self.dline, indent=2)


def _parse_bound(text):
    """
    Parse a -start_datetime / -end_datetime value and mark it as UTC.
    """
    # strptime is used because fromisoformat is only available from py37
    parsed = datetime.datetime.strptime(text, _DATETIME_FORMAT_INPUT)
    return parsed.replace(tzinfo=pytz.UTC)


def _in_time_window(args, entry_dt):
    """
    Tell whether entry_dt lies inside the optional start/end window.

    An entry without a datetime cannot be placed in a requested window and
    is therefore rejected; without any bound everything is accepted.
    """
    if not (args.start_datetime or args.end_datetime):
        return True
    if entry_dt is None:
        return False
    if entry_dt.tzinfo is None:
        # the bounds are taken as UTC, so a naive entry datetime is too;
        # comparing naive with aware datetimes would raise TypeError
        entry_dt = entry_dt.replace(tzinfo=pytz.UTC)
    if args.start_datetime and entry_dt < _parse_bound(args.start_datetime):
        return False
    if args.end_datetime and entry_dt > _parse_bound(args.end_datetime):
        return False
    return True


def _matches(jq, filters):
    """
    Evaluate a filter expression against jq.

    The expression matches when all the terms of at least one OR-group
    match. A term is 'key', or 'key op value' where the value may contain
    spaces.
    """
    for group in _get_fblocks(filters):
        # maxsplit=2 keeps a value with spaces in one piece
        if all(jq.filter(*term.split(' ', 2)) for term in group):
            return True
    return False


def jqpy(args, line, filters=None):
    """
    Print ``line`` when it passes the datetime window and the filters.

    args: parsed command line options; datetime_field, datetime_format,
          start_datetime and end_datetime are read here
    line: one JSON entry, as text or already decoded
    filters: filter expression like 'a == 1 and b in x or c != 2'

    Returns True when the entry was printed, False otherwise, so that the
    caller can count results (the -limit option is enforced by the caller).

    Raises ValueError when a datetime bound is given without
    -datetime_field; errors raised while decoding ``line`` are propagated.
    """
    if (args.start_datetime or args.end_datetime) and not args.datetime_field:
        raise ValueError(
            '-start_datetime/-end_datetime require -datetime_field')

    jq = PyJQ(line, datetime_field=args.datetime_field,
              datetime_format=args.datetime_format)

    if not _in_time_window(args, getattr(jq, 'datetime', None)):
        return False
    if filters and not _matches(jq, filters):
        return False

    print(jq)
    return True


def _open_text(path):
    """
    Open ``path`` for reading text and return (handle, whole content).

    A file that cannot be decoded as text is reopened as gzip. The caller
    owns the returned handle and has to close it.
    """
    handle = open(path)
    try:
        return handle, handle.read()
    except UnicodeDecodeError:
        # try to handle it as gzip
        handle.close()
    handle = gzip.open(path, 'rt', encoding='utf-8')
    try:
        return handle, handle.read()
    except Exception:
        handle.close()
        raise


def _document_entries(content, realtime):
    """
    Return the entries of a file that is one single JSON document, or None
    when the content has to be read line by line instead.

    A top-level array yields its items. A top-level object is one entry,
    except in realtime mode where such a file is followed for new lines.
    """
    try:
        document = json.loads(content)
    except json.decoder.JSONDecodeError:
        return None
    if isinstance(document, list):
        return document
    if isinstance(document, dict) and not realtime:
        return [document]
    return None


def _tail_lines(handle):
    """
    Yield, forever, every complete line appended to ``handle``.
    """
    pending = ''
    while True:
        data = handle.read(_BUFSIZE)
        if not data:
            # nothing new yet: wait instead of spinning
            time.sleep(1)
            continue
        pending += data
        pieces = pending.split(_REALTIME_DELIM)
        # the last piece is an unfinished line, keep it for the next read
        pending = pieces.pop()
        for piece in pieces:
            yield piece


def handle_stream(args):
    """
    Read args.json and run every entry through jqpy.

    The file is either one JSON document (an array of entries or a single
    object) or multiple JSON objects delimited by newline, optionally gzip
    compressed.

    With args.echo the entries are printed raw and nothing is filtered.
    With args.realtime a newline delimited file is followed and only the
    entries appended from now on are processed; lines that are not valid
    JSON are skipped in that mode. With args.limit processing stops after
    that many printed results.
    """
    # got something like 'agent__ip == 172.16.16.254 and rule__level > 3'
    filters = args.filters
    limit = args.limit if args.limit and args.limit > 0 else None

    alert_file, file_content = _open_text(args.json)
    try:
        entries = _document_entries(file_content, args.realtime)
        tailing = entries is None and bool(args.realtime)
        if tailing:
            # if realtime it takes only the latest entries
            alert_file.read()
            entries = _tail_lines(alert_file)
        elif entries is None:
            alert_file.seek(0)
            entries = alert_file

        if args.echo and not tailing:
            for entry in entries:
                print(entry)
            return

        printed = 0
        for entry in entries:
            if isinstance(entry, str) and not entry.strip():
                # blank lines carry no entry
                continue
            try:
                matched = jqpy(args, entry, filters)
            except json.decoder.JSONDecodeError:
                if not tailing:
                    raise
                continue
            if matched and limit is not None:
                printed += 1
                if printed >= limit:
                    break
    finally:
        alert_file.close()

if __name__ == '__main__':
    import argparse
    # Command line entry point: every option is handed over untouched to
    # handle_stream() through the parsed namespace.
    parser = argparse.ArgumentParser(
        description="Filter and print the JSON entries of a file.")
    parser.add_argument('-json', required=True,
                        help="a json file: one JSON array, or one JSON "
                             "object per line (may be gzip compressed)")
    parser.add_argument('-filters', required=False,
                        help="terms 'key op value' joined by and/or, e.g. "
                             "'agent__id == 001 and rule__level > 3'; "
                             "nested keys are joined by '__'; op is one of "
                             "==, !=, in, <=, <, >, >=")
    parser.add_argument('-datetime_field', default=None,
                        required=False,
                        help="key (nested with '__') holding the datetime "
                             "of an entry; needed by -start_datetime and "
                             "-end_datetime")
    parser.add_argument('-datetime_format', default=_DATETIME_FORMAT,
                        required=False,
                        help="strptime format of the datetime field "
                             "(default: %(default)s)")
    parser.add_argument('-start_datetime', #type=datetime.datetime.fromisoformat, only supported in py37
                        required=False,
                        help="skip entries older than this UTC datetime, "
                             "e.g. 2020-04-06T10:22:00")
    parser.add_argument('-end_datetime', #type=datetime.datetime.fromisoformat, only supported in py37
                        required=False,
                        help="skip entries newer than this UTC datetime, "
                             "e.g. 2020-04-06T13:22:00")
    parser.add_argument('-limit', type=int,
                        required=False,
                        help="how many results to return")
    parser.add_argument('-realtime', action='store_true',
                        required=False,
                        help="continuous and realtime reading, takes only the latest entries")
    parser.add_argument('-echo', action='store_true',
                        required=False,
                        help="print processed lines in raw format and exit")
    args = parser.parse_args()
    handle_stream(args)
