#!/usr/bin/env python
"""
Creates a frequency distribution of values from columns of the input file and prints it out in
columns - the first being the unique key and the last being the count of occurrences.


Usage: gristle_freaker [options]


{see: helpdoc.HELP_SECTION}


Main Options:
    -i, --infiles [INFILES [INFILES ...]]
                         Input filenames or default to '-' for stdin.
    -o, --outfile OUTFILE
                         Output filename or '-' for stdout (the default).
    -c, --columns [COLUMNS [COLUMNS ...]]
                         Columns to collect distributions on.
                         Columns are identified based on a zero-offset, and multiple columns
                         can be provided.
    --col-type COL_TYPE  Process columns individually or together as a group.
                         Valid values are: 'specified', 'all', or 'each'. The default is
                         'specified'.


Limiting Options:
    --max-freq MAX_FREQ  Identifies the max number of items in freq dict.
                         Default is 10m which would require approx 300MB of
                         mem to store 10 million 20 byte keys.
    --sampling-method SAMPLING_METHOD
                         Method to use for sampling: 'non' or 'interval'.
                         The default is 'non'.
    --sampling-rate SAMPLING_RATE
                         The number of records to skip if using 'interval' sampling-method.


Output Formatting Options:
    --max-key-len MAX_KEY_LEN
                         Max length of key written to output.
                         The default is 50.
    --write-limit WRITE_LIMIT
                         Max records to write.
                         Default of 0 indicates no limit.
    --sort-col SORT_COL  Column to sort by in the output, either 0 (the key) or 1 (count).
                         Default is 1.
    --sort-order SORT_ORDER
                         Specifies sort order of forward or reverse.
                         The default is reverse.
    --verbosity VERBOSITY
                         Stdout level of detail with valid values of quiet, normal, high, debug.


{see: helpdoc.CSV_SECTION}


{see: helpdoc.CONFIG_SECTION}


Examples:
    $ gristle_freaker -i sample.csv -d '|'  -c 0
        Creates two columns from the input - the first with unique keys from column 0, the second
        with a count of how many times each exists.
    $ gristle_freaker -i sample.csv -d '|'  -c 0 --sort-col 1 --sort-order forward
            --write-limit 25
        In addition to what was described in the first example, this example adds sorting of the
        output by count ascending and just prints the first 25 entries.
    $ gristle_freaker -i sample.csv -d '|'  -c 0 --sampling-rate 3
             --sampling-method interval
        In addition to what was described in the first example, this example adds a sampling in
        which it only references every third record.
    $ gristle_freaker -i sample.csv -d '|'  -c 0 1
        Creates three columns from the input - the first two with unique key combinations from
        columns 0 & 1, the third with the number of times each combination exists.
    $ gristle_freaker -i sample.csv -d '|'  -c -1
        Creates two columns from the input - the first with unique keys from the last column of
        the file (negative numbers wrap), then a second with the number of times each exists.
    $ gristle_freaker -i sample.csv -d '|'  --col-type all
        Creates two columns from the input - all columns combined into a key, then a second with
        the number of times each combination exists.
    $ gristle_freaker -i sample.csv -d '|'  --col-type each
        Unlike all other examples, this one performs a separate analysis for every single column
        of the file.  Each analysis produces three columns from the input - the first is a column
        number, second is a unique value from the column, and the third is the number of times
        that value appeared.  This output is repeated for each column.
    Many more examples can be found here:
        https://github.com/kenfar/DataGristle/tree/master/examples/gristle_freaker


Licensing and further Info:
    This source code is protected by the BSD license.  See the file "LICENSE"
    in the source code root directory for the full language or refer to it here:
        http://opensource.org/licenses/BSD-3-Clause
    Copyright 2011-2021 Ken Farmer
"""

import sys
import operator
from os.path import basename
import errno
from pprint import pprint as pp
from typing import Dict, List, Tuple, Optional, Any
from signal import signal, SIGPIPE, SIG_DFL

from datagristle.common import abort
import datagristle.configulator as configulator
import datagristle.file_io as file_io
import datagristle.helpdoc as helpdoc

# Restore the default SIGPIPE disposition (SIG_DFL) so that a closed downstream
# pipe (e.g. `gristle_freaker ... | head`) does not surface as a Python exception.
# (http://docs.python.org/library/signal.html)
signal(SIGPIPE, SIG_DFL)

# Program name (this script's filename) plus the help texts handed to ConfigManager.
NAME = basename(__file__)
# The long help is derived from the module docstring; presumably helpdoc fills in
# the "{see: helpdoc....}" placeholders - confirm in helpdoc.expand_long_help.
LONG_HELP = helpdoc.expand_long_help(__doc__)
SHORT_HELP = helpdoc.get_short_help_from_long(LONG_HELP)



def main() -> int:
    """ Runs the frequency analysis described by the command line / config.

        Returns:
            - 0              - success
            - errno.ENODATA  - no records were read from the input
            - errno.ENOMEM   - the freq dict hit max_freq and was truncated
        Exits directly with errno.ENODATA if building the config raises
        EOFError.
    """
    return_code = 0 # success

    try:
        config_manager = ConfigManager(NAME, SHORT_HELP, LONG_HELP)
        nconfig, _ = config_manager.get_config()
    except EOFError:
        sys.exit(errno.ENODATA) # 61: empty file

    # Run once for each col_set.  Only col_type of 'each' will have > 1 col_set,
    # in which they'll run the following loop once for each column.  Other col_types
    # will just run a single iteration of the loop.
    total_read_cnt = 0
    while True:
        input_handler = file_io.InputHandler(nconfig.infiles,
                                             nconfig.dialect)
        output_handler = None
        try:
            output_handler = file_io.OutputHandler(nconfig.outfile,
                                                   input_handler.dialect)

            colset_freaker = ColSetFreaker(input_handler,
                                           output_handler,
                                           nconfig.col_type,
                                           nconfig.max_freq,
                                           nconfig.sampling_method,
                                           nconfig.sampling_rate,
                                           nconfig.sort_order,
                                           nconfig.sort_col,
                                           nconfig.max_key_len)
            colset_freaker.create_freq_from_column_set(nconfig.col_set)
            total_read_cnt += colset_freaker.read_cnt

            colset_freaker.write_output(nconfig.outfile,
                                        nconfig.write_limit,
                                        nconfig.col_set[0])
        finally:
            # Always release the handlers - including on the final pass and on
            # errors - rather than only between passes.
            input_handler.close()
            if output_handler is not None:
                output_handler.close()

        if colset_freaker.more_columns():   # col_type is 'each'
            nconfig.col_set[0] += 1
            continue

        if total_read_cnt == 0:
            return_code = errno.ENODATA # 61: empty file
        elif colset_freaker.truncated:
            return_code = errno.ENOMEM  # 12: too many unique values
        break

    return return_code




class ColSetFreaker(object):
    """ Builds, sorts and writes the frequency distribution for one column set.

        Typical use is create_freq_from_column_set() followed by
        write_output(); more_columns() then tells the caller whether another
        pass (col_type of 'each') is worthwhile.
    """

    def __init__(self,
                 input_handler,
                 output_handler,
                 col_type: str,
                 number: int,
                 sampling_method: str,
                 sampling_rate: float,
                 sort_order: str,
                 sort_col: int,
                 max_key_len: int) -> None:
        """ Inputs:
                - input_handler   - iterable of records (lists of fields) that
                                    also exposes dialect.has_header
                - output_handler  - stored, but not used by this class
                - col_type        - 'specified', 'all' or 'each'
                - number          - max number of unique keys to keep
                - sampling_method - 'interval' turns on record sampling
                - sampling_rate   - keep only records whose position is a
                                    multiple of this (interval sampling)
                - sort_order      - 'reverse' sorts descending
                - sort_col        - 0 sorts by key, 1 sorts by count
                - max_key_len     - max combined width of the printed keys
        """
        self.input_handler = input_handler
        self.output_handler = output_handler
        self.col_type = col_type
        self.number = number
        self.sampling_method = sampling_method
        self.sampling_rate = sampling_rate
        self.sort_order = sort_order
        self.sort_col = sort_col
        self.max_key_len = max_key_len

        self.col_len_tracker: ColumnLengthTracker = None
        self.field_freq: Dict[Any, Any] = {}
        self.sort_freq: List[Any] = None
        self.truncated: bool = None
        self.read_cnt = 0


    def create_freq_from_column_set(self, col_set: List[int]) -> None:
        """ Builds the freq distribution for col_set, sorts it into
            self.sort_freq and prepares self.col_len_tracker for output.
        """
        #----- build the frequency distribution -----
        self.build_freq(col_set)

        #----- sort the field_freq ------
        self.sort_freq = sorted(self.field_freq.items(),
                                key=operator.itemgetter(self.sort_col),
                                reverse=(self.sort_order == 'reverse'))

        #----- calculate field lengths for output formatting -----
        self.col_len_tracker = ColumnLengthTracker()
        self.col_len_tracker.add_all_values(self.sort_freq)
        self.col_len_tracker.trunc_all_col_lengths(self.max_key_len)


    def build_freq(self, columns: List[int]) -> None:
        """ Inputs:
                - columns    - list of columns to put into key
            Updates instance variables:
                - field_freq - freq distribution dict - can be empty if column
                               doesn't exist in file.  This happens when user
                               is doing a freq on "each" columns.
                - truncated  - True/False boolean that indicates if the dict was truncated
                - read_cnt   - number of records read, including any header
                               and any records skipped by sampling
                Tested via test-harness
        """
        self.read_cnt = 0
        self.field_freq = {}
        self.truncated = False
        freq = self.field_freq

        for record in self.input_handler:
            self.read_cnt += 1
            if self.read_cnt == 1 and self.input_handler.dialect.has_header:
                continue
            if (self.sampling_method == 'interval'
                    and self.read_cnt % self.sampling_rate != 0):
                continue

            key = self._create_key(record, columns)
            if not key:
                continue   # none of the requested columns exist in this record

            if key in freq:
                freq[key] += 1
                continue

            freq[key] = 1
            # Only a brand-new key can grow the dict, so only check the limit here.
            if len(freq) >= self.number:
                self.truncated = True
                print('WARNING: freq dict is too large - will truncate')
                break


    def _create_key(self,
                    fields: List[str],
                    required_field_indexes: List[int]) -> Tuple[str, ...]:
        """ input:
                - fields - a single record in the form of a list of fields
                - required_field_indexes - identifies which fields to put into
                  output key; ignored when col_type is 'all'
            output:
                - the required fields from fields in tuple format, each one
                  stripped of surrounding whitespace.  For col_type of 'all'
                  the whole record is returned unstripped.
                Tested via test-harness.
        """
        if self.col_type == 'all':
            return tuple(fields)

        key = []
        for field_idx in required_field_indexes:
            try:
                key.append(fields[field_idx].strip())
            except IndexError:
                break # is probably freaking all cols with 'each' option.
        return tuple(key)


    def more_columns(self) -> bool:
        """ Returns True if col_type is 'each' and the last pass still found
            values - an empty sort_freq means 'each' just ran out of cols.
        """
        return self.col_type == 'each' and self.sort_freq != []


    def write_output(self,
                     output: str,
                     write_limit: int,
                     col_set_num: int) -> None:
        """ Writes the sorted freq rows to stdout or appends them to a file.
            Inputs:
                - output      - filename, or '-' for stdout
                - write_limit - max rows to write, 0 means no limit
                - col_set_num - column number used in the 'each' row label
        """
        if output == '-':
            self._write_rows(sys.stdout, write_limit, col_set_num)
        else:
            # The with-block guarantees the file is closed even if a row fails.
            with open(output, 'a') as outfile:
                self._write_rows(outfile, write_limit, col_set_num)


    def _write_rows(self,
                    outfile,
                    write_limit: int,
                    col_set_num: int) -> None:
        """ Writes up to write_limit formatted rows from sort_freq to outfile.
        """
        for write_cnt, freq_tup in enumerate(self.sort_freq, start=1):
            if write_limit and write_cnt > write_limit:
                break
            outfile.write(self.create_output_row(freq_tup, col_set_num))


    def create_output_row(self,
                          freq_tup: Tuple[Tuple[str, ...], int],
                          col_set_num: int) -> str:
        """ input:
            - freq_tup - is a tuple consisting of two entries: key & count.
                Each key is also a tuple.  So it looks like this:
                    (('key1a', 'key1b', 'key1c'), 5   )
            - col_set_num - column number, only used for col_type of 'each'
            output:
            - an entire row to be written as output
            tested via test-harness
        """
        key_label = 'col:%03d - ' % col_set_num if self.col_type == 'each' else ''

        # Each key part is left-justified and cut to its tracked column width.
        widths = self.col_len_tracker.max_dict
        key_part = ''.join(' %-*.*s  - ' % (widths[colnum], widths[colnum], key)
                           for colnum, key in enumerate(freq_tup[0]))
        return '%s%s %s\n' % (key_label, key_part, freq_tup[1])





class ColumnLengthTracker(object):
    """ Remembers the widest value seen in each key column.
        The recorded widths are used when printing so that every row lines
        its columns up consistently.
        Tested via test harness.
    """
    def __init__(self):
        # key-column position -> widest value length seen so far
        self.max_dict: Dict[int, int] = {}

    def add_val(self, col_num: int, value: str) -> None:
        """ Record the length of value against col_num if it is the widest
            seen so far for that column.
            Inputs:
                 - col_num:  is the position of the column within the key
                 - value:    is the actual column value
            Outputs:
                 - none:     it updates the internal object length dictionary
        """
        width = len(value)
        self.max_dict[col_num] = max(self.max_dict.get(col_num, 0), width)

    def add_all_values(self, freq_list: List[Tuple[Tuple[str, ...], int]]) -> None:
        """ Walk a list of (key, count) tuples, where each key is itself a
            tuple of 1-many parts, and feed every key part to add_val.
        """
        for entry in freq_list:
            for position, part in enumerate(entry[0]):
                self.add_val(position, part)

    def trunc_all_col_lengths(self, maxkeylen: int) -> None:
        """ Shrinks the widths one character at a time, always taking from
            the currently widest column, until the total fits in maxkeylen.
        """
        widths = self.max_dict
        while self._get_tot_col_len() > maxkeylen:
            widest = max(widths, key=widths.get)
            widths[widest] -= 1

    def _get_tot_col_len(self) -> int:
        """ Returns the sum of the tracked widths of all columns.
        """
        return sum(self.max_dict.values())







class ConfigManager(configulator.Config):
    """ Defines, validates and extends the gristle_freaker configuration.
    """


    def define_obsolete_config(self) -> None:
        """ Register options that are no longer supported.
        """
        self.add_obsolete_metadata('number', short_name='n', msg='-n/--number have been replaced by --max-freq')


    def define_user_config(self) -> None:
        """ Get the configuration elements from arguments and config files
        """

        self.add_standard_metadata('infiles')
        self.add_standard_metadata('outfile')

        self.add_custom_metadata(name='columns',
                                 short_name='c',
                                 default=None,
                                 type=int,
                                 nargs='*')
        self.add_custom_metadata(name='col_type',
                                 default='specified',
                                 choices=['specified', 'all', 'each'],
                                 type=str)
        self.add_custom_metadata(name='max_freq',
                                 default=10_000_000,
                                 type=int)
        self.add_custom_metadata(name='max_key_len',
                                 type=int,
                                 minimum=1,
                                 maximum=500,
                                 default=50)
        self.add_custom_metadata(name='write_limit',
                                 type=int,
                                 default=0)
        self.add_custom_metadata(name='sort_col',
                                 type=int,
                                 default=1,
                                 choices=[0, 1])
        self.add_custom_metadata(name='sort_order',
                                 type=str,
                                 choices=['forward', 'reverse'],
                                 default='reverse')
        self.add_custom_metadata(name='sampling_method',
                                 type=str,
                                 choices=['non', 'interval'],
                                 default='non')
        self.add_custom_metadata(name='sampling_rate',
                                 type=int,
                                 default=None)

        self.add_standard_metadata('verbosity')
        self.add_all_config_configs()
        self.add_all_csv_configs()
        self.add_all_help_configs()



    def validate_custom_config(self,
                               config) -> None:
        """ Checks the relationships between options, calling abort() on the
            first violation found:
                - columns are required for col_type of 'specified' and are
                  not allowed for any other col_type
                - sampling_rate is required for sampling_method of 'interval'
                  and is not allowed for 'non'
        """

        # Handle columns & col_type relationship
        if config['col_type'] == 'specified':
            # An empty list (-c with no values) is as unusable as no option at all.
            if not config['columns']:
                abort("ERROR: Columns must be provided for col_type of 'specified'.")
        elif config['columns'] is not None:
            abort("ERROR: Columns can only be provided for col_type of 'specified'.")

        # Validate the sampling fields
        if config['sampling_method'] == 'non':
            if config['sampling_rate']:
                abort('ERROR: Sampling_rate populated',
                      'sampling_rate must not be provided if sampling_method is "non"')
        elif config['sampling_method'] == 'interval':
            if config['sampling_rate']:
                if config['sampling_rate'] != int(config['sampling_rate']):
                    abort('ERROR: Interval sampling rate must be an int')
            else:
                abort('ERROR: Integer sampling_rate missing',
                      'sampling_rate must be provided if sampling_method == "interval"')


    def extend_config(self):
        """ Add derived or extended elements to config.
            Note: csvhelper.get_dialect could raise EOFError
        """
        self.generate_csv_dialect_config()

        # 'each' and 'all' start from column 0; 'specified' uses the user's columns.
        # pylint: disable=no-member
        if self.nconfig.col_type in ('each', 'all'):
            self.update_config('col_set', [0])
        else:
            self.update_config('col_set', self.nconfig.columns)



class TooLargeError(Exception):
    """ Signals that a frequency dictionary has reached its size limit.
    """



# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    sys.exit(main())
