#!/bin/sh

##############################################################################
# SQL query csv file(s).
#

##############################################################################
# BOOTSTRAP
#
# Include ../lib in the search path before calling python.
# (Thanks to https://unix.stackexchange.com/questions/20880)
#
if "true" : '''\'
then
    # Resolve ../lib relative to this script.  $0 is quoted so that a script
    # path containing spaces or glob characters is not word-split.
    export PYTHONPATH="$(dirname "$0")/../lib:$PYTHONPATH"
    pythoncmd=python

    if command -v python3 >/dev/null; then
        pythoncmd=python3
    fi

    # Re-run this same file under Python; exit 127 only if exec itself fails.
    exec "$pythoncmd" "$0" "$@"
    exit 127
fi
'''

##############################################################################
# PYTHON CODE BEGINS HERE

import os
import re
import sys
import errno
import shlex
import getopts
import sqlite3
import readline
import subprocess
from csvmagic import libcsv

__copyright__ = 'Copyright 2020-2022 Mark Kim'
__license__ = 'Apache 2.0'
__version__ = '2.0.0'
__author__ = 'Mark Kim'


##############################################################################
# GLOBALS

SCRIPTNAME = os.path.basename(__file__)

class opts:
    # Command-line state filled in by main(); used as a plain namespace of
    # class attributes and never instantiated.
    commands = []   # values of -c/--command, in the order given
    delim = None    # value of -d/--delim after arg_to_delim(), or None
    files = []      # positional FILE arguments
    tables = []     # values of -a/--as, index-aligned with `files`


##############################################################################
# USAGE

def usage():
    '''\
SQL query csv file(s).

usage: {SCRIPTNAME} [OPTIONS] [FILES]

Options:
  FILE                  csv file(s) to query.

  -a, --as=TABLE        Load FILE into the table named TABLE.  If this option
                        is omitted, the table name defaults to the base name of
                        the FILE.  This option may be specified once per FILE.

  -c, --command=SQL     SQL commands to execute.  May be specified multiple
                        times, once per SQL command.

  -d, --delim=DELIM     Use DELIM as the value delimiter, where DELIM may be
                        'p' for the pipe (|), 't' for the tab (\\t), 'a' for
                        the SOH (ASCII 1), or other string literal of one or
                        more characters and Python  string  escape  sequences.
                        DELIM  may  include  escape characters.  By default the
                        delimiter is guessed from the characters in the
                        CSV_DELIMS environment variable.

Environment Variables:
  CSV_DELIMS            A set of characters used to guess the delimiter of a
                        csv file.  The guesswork happens when reading  the
                        first  line  of  the first FILE, by testing each
                        character present in CSV_DELIMS for the character with
                        the most occurrence in the line.  If any of the
                        characters occur the same number  of  times  (including
                        zero),  the  earlier  character  in the variable is
                        chosen.  If the environment variable is not set, it
                        defaults to '{libcsv.DELIMS}'.  CSV_DELIMS may include
                        escape characters.
'''

    # The docstring above doubles as the help text.  It is a str.format()
    # template whose only placeholders are {SCRIPTNAME} and {libcsv.DELIMS},
    # so those two names are supplied explicitly.
    template = usage.__doc__
    rendered = template.format(SCRIPTNAME=SCRIPTNAME, libcsv=libcsv)

    print(rendered)


##############################################################################
# MAIN

def main():
    '''Parse the command line, load the csv files into the database, then
    either run the SQL/built-in commands given with -c/--command or, when
    there are none, start the interactive prompt.

    Exits with status 1 after printing a hint if any option was rejected.'''

    errcount = 0

    getopt = getopts.getopts(sys.argv, {
        'a' : 1, 'as'      : 1,
        'c' : 1, 'command' : 1,
        'd' : 1, 'delim'   : 1,
        'h' : 0, 'help'    : 0,
    })

    for c in getopt:
        # NOTE: compare with ==; `c in ('-')` is a substring test on a string.
        if c == '-'                     : opts.files.append(getopt.optarg)
        elif c in ('a', 'as')           : opts.tables.append(getopt.optarg)
        elif c in ('c', 'command')      : opts.commands.append(getopt.optarg)
        elif c in ('d', 'delim')        : opts.delim = arg_to_delim(getopt.optarg)
        elif c in ('h', 'help')         : usage(); sys.exit(0)
        else                            : errcount += 1

        # Keep files and tables index-aligned while parsing: the shorter list
        # is padded with None up to one element behind the longer one, so the
        # next FILE / --as value pairs with the right partner.
        while len(opts.tables) < len(opts.files)-1:
            opts.tables.append(None)

        while len(opts.files) < len(opts.tables)-1:
            opts.files.append(None)

    # Final padding so both lists have the same length
    while len(opts.tables) < len(opts.files):
        opts.tables.append(None)

    while len(opts.files) < len(opts.tables):
        opts.files.append(None)

    # Sanity check
    if errcount:
        sys.stderr.write('Type `{SCRIPTNAME} --help` for help.\n'.format(**globals()))
        sys.exit(1)

    # Instantiate the database
    db = Database()

    # Load all tables, honoring -d/--delim (previously parsed but never used)
    for f, t in zip(opts.files, opts.tables):
        db.load(f, t, opts.delim)

    # One CLI for the whole run so that state set by one -c command (current
    # table, variables) is still there for the next one.
    cli = CLI(db)

    if opts.commands:
        # Execute SQL commands; the built-in `exit` raises EOFError to stop.
        try:
            for c in opts.commands:
                cli.runcommand(c)
        except EOFError:
            pass
    else:
        # If no commands were specified, launch the CLI
        cli.mainloop()


def arg_to_delim(delim):
    '''Translate a DELIM command-line argument into the delimiter string.

    'p', 't' and 'a' are shorthands for pipe, tab and SOH (ASCII 1).  Any
    other value has its Python escape sequences decoded; the decoded form is
    used only when the argument or its decoded form is a single character,
    otherwise the argument is returned as typed.'''

    # Delimiter shorthands (exact match only; `delim in ('p')` would be a
    # substring test and turn the empty string into a pipe)
    shorthands = {
        'p': '|',
        't': '\t',
        'a': '\u0001',
    }
    delim = shorthands.get(delim, delim)

    # Translate escape characters.  unicode_escape reads its input as
    # Latin-1, so encode the same way (with \uXXXX escapes for anything
    # beyond it) to keep non-ASCII delimiters intact.
    try:
        decoded = delim.encode('latin-1', 'backslashreplace').decode('unicode_escape')
    except UnicodeDecodeError:
        # Not a valid escape sequence (e.g. a lone backslash): take it literally
        decoded = delim

    # Use the decoded string only if it's a single character
    if len(delim) == 1 or len(decoded) == 1:
        delim = decoded

    return delim


class CLI(object):
    '''Command interpreter for csvsql.

    A line whose first token is a built-in command is split on `;` and each
    segment is dispatched to its handler; any other line is passed to the
    database as SQL.'''

    def __init__(self, database):
        self.__db = database
        self.__rec = Record(database)
        self.__input = Input()
        self.__variables = {
            'filter': None,
        }
        # Built-in commands: name -> optional 'args' (help syntax), 'desc'
        # (help text) and 'exec' (handler taking the list of arguments).
        self.__commands = {
            '!': {
                'args': 'COMMAND [ARGS]',
                'desc': 'Execute the program COMMAND and pass ARGS.',
                'exec': self.__cmd_exec,
            },

            'exit': {
                'desc': 'Exit csvsql.',
                'exec': self.__cmd_exit,
            },

            'echo': {
                'args': '[ARGS]',
                'desc': 'Print ARGS to stdout.',
                'exec': self.__cmd_echo,
            },

            'set': {
                'args': '[NAME VALUE]',
                'desc': 'Set the variable NAME to VALUE.  If omitted, show all variables and their values.',
                'exec': self.__cmd_set,
            },

            'unset': {
                'args': 'NAME',
                'desc': 'Unset variable NAME.',
                'exec': self.__cmd_unset,
            },

            'load': {
                'args': 'FILE [--as=TABLE] [--delim=DELIM]',
                'desc': 'Load FILE into TABLE.  If omitted TABLE is the filename without the extension.',
                'exec': self.__cmd_load,
            },

            'showtables': {
                'desc': 'Print the list of tables.',
                'exec': self.__cmd_showtables,
            },

            'use': {
                'args': 'TABLE',
                'desc': 'Use TABLE as the default table.',
                'exec': self.__cmd_use,
            },

            'next': {
                'desc': 'Move to the next record.',
                'exec': self.__cmd_next,
            },

            'prev': {
                'desc': 'Move to the previous record.',
                'exec': self.__cmd_prev,
            },

            'read': {
                'args': 'INDEX',
                'desc': 'Go to the record with recid INDEX.',
                'exec': self.__cmd_read,
            },

            'search': {
                'args': 'CONDITION',
                'desc': 'Show all records matching SQL CONDITION.  CONDITION is typically quoted.',
                'exec': self.__cmd_search,
            },

            'count': {
                'args': '[CONDITION]',
                'desc': 'Print the number of records in the current table [that matches CONDITION].',
                'exec': self.__cmd_count,
            },

            'show': {
                'desc': 'Show the current record.',
                'exec': self.__cmd_show,
            },

            'showindexes': {
                'desc': 'Print the list of keys for the default table.',
                'exec': self.__cmd_showindexes,
            },

            'help': {
                'desc': 'Show the list of commands.',
                'exec': self.__cmd_help,
            },
        }

    def mainloop(self):
        '''Read and run commands until EOFError is raised, either by end of
        input or by the `exit` command.'''

        while True:
            try:
                line = self.__getcommand()
                self.runcommand(line)

            except EOFError as e:
                print(str(e))
                break

            except AttributeError as e:
                # Reported rather than fatal so the prompt survives
                print(str(e))

    def __getcommand(self):
        '''Prompt for one command line.  The prompt is `csvsql` until a table
        is in use, then `TABLE.RECNUM`.'''

        table = self.__rec.get_table()
        recnum = self.__rec.get_recnum()
        prompt = 'csvsql' if table is None else '%s.%s' % (table, recnum)

        # Tab completion offers the built-in commands and the table names
        completion_words = list(self.__commands.keys())
        completion_words += self.__db.get_tables()
        self.__input.set_completion_words(completion_words)

        return self.__input.get_command(prompt)

    def runcommand(self, line):
        '''Run one input line: built-in command(s) separated by `;`, or SQL.'''

        try:
            tokens = list(shlex.shlex(line, posix=True, punctuation_chars=True))
        except ValueError as e:
            # e.g. an unbalanced quote
            print('Parse error: %s' % e)
            return

        if not tokens:
            return

        if tokens[0] not in self.__commands:
            self.__db.execute(line)
            return

        # built-in commands may be separated by semicolon - parse and execute.
        # The trailing ';' sentinel flushes the last segment.
        parts = []

        for t in tokens + [';']:
            if t != ';':
                parts.append(t)
                continue

            segment, parts = parts, []

            # Nothing between two semicolons
            if not segment:
                continue

            command = segment[0]

            if command in self.__commands:
                self.__commands[command]['exec'](segment[1:])
            else:
                print('Unknown command: %s' % command)

    def __cmd_exit(self, args):
        '''exit: leave the program by raising EOFError.'''
        raise EOFError()

    def __cmd_echo(self, args):
        '''echo: print the arguments separated by single spaces.'''
        print(' '.join(args))

    def __cmd_load(self, args):
        '''load: load csv FILE(s) into tables, with optional --as / --delim.
        Returns 0 on success, 1 if the arguments were rejected.'''

        filenames = []
        tablenames = []
        delim = None
        exitcode = 0

        # NOTE(review): the leading 'read' presumably stands in for the
        # program name that getopts skips (main() passes sys.argv) -- confirm.
        getopt = getopts.getopts(['read'] + args, {
            'a' : 1, 'as'    : 1,
            'd' : 1, 'delim' : 1,
        })

        for c in getopt:
            if c == '-'              : filenames.append(getopt.optarg)
            elif c in ('a', 'as')    : tablenames.append(getopt.optarg)
            elif c in ('d', 'delim') : delim = arg_to_delim(getopt.optarg)
            else                     : exitcode = 1

            # Keep both lists index-aligned while parsing (see main())
            while len(tablenames) < len(filenames)-1:
                tablenames.append(None)

            while len(filenames) < len(tablenames)-1:
                filenames.append(None)

        while len(filenames) < len(tablenames):
            filenames.append(None)

        while len(tablenames) < len(filenames):
            tablenames.append(None)

        if exitcode == 0:
            for f, t in zip(filenames, tablenames):
                # --as given without a matching FILE
                if f is None:
                    print('No file specified for table: %s' % t)
                    exitcode = 1
                    continue

                try:
                    self.__db.load(f, t, delim)
                except FileNotFoundError:
                    print('File not found: %s' % f)

        return exitcode

    def __cmd_showtables(self, args):
        '''showtables: print each table name with its record count.'''

        counts = {}
        widths = [0, 0]

        for table in self.__db.get_tables():
            # query() returns None when the statement fails
            cursor = self.__db.query("select count(*) from %s" % table)
            row = cursor.fetchone() if cursor is not None else None
            count = row[0] if row is not None else 0

            widths[0] = max(widths[0], len(table))
            widths[1] = max(widths[1], len(str(count)))
            counts[table] = count

        for table in sorted(counts):
            print("%-*s  %-*s records" % (widths[0], table, widths[1], counts[table]))

    def __cmd_use(self, args):
        '''use: make TABLE the current table.'''

        if len(args) != 1:
            print('Expected 1 argument, got %s' % len(args))
            return

        try:
            self.__rec.use_table(args[0])
        except Record.RecordError as e:
            print(str(e))

    def __cmd_next(self, args):
        '''next: move to the next record.'''

        try:
            self.__rec.next()
        except Record.RecordError as e:
            print(str(e))

    def __cmd_prev(self, args):
        '''prev: move to the previous record.'''

        try:
            self.__rec.prev()
        except Record.RecordError as e:
            print(str(e))

    def __cmd_read(self, args):
        '''read: go to the record whose rowid is INDEX.'''

        if len(args) != 1:
            print('Expected 1 argument, got %s' % len(args))
            return

        try:
            recid = int(args[0])
        except ValueError:
            print('Not an integer: "%s"' % args[0])
            return

        try:
            self.__rec.read(recid)
        except Record.RecordError as e:
            print(str(e))

    def __cmd_search(self, args):
        '''search: show every record matching the SQL CONDITION.'''

        if len(args) != 1:
            print('Expected 1 argument, got %s' % len(args))
            return

        try:
            self.__rec.search(args[0])
        except Record.RecordError as e:
            print(str(e))

    def __cmd_count(self, args):
        '''count: print the number of records, optionally matching CONDITION.'''

        condition = args[0] if len(args) else None

        try:
            print(self.__rec.count(condition))
        except Record.RecordError as e:
            print(str(e))

    def __cmd_show(self, args):
        '''show: print the current record.'''

        try:
            self.__rec.show()
        except Record.RecordError as e:
            print(str(e))

    def __cmd_showindexes(self, args):
        '''showindexes: print the index list of the current table.'''

        try:
            self.__rec.showindexes()
        except Record.RecordError as e:
            print(str(e))

    def __cmd_help(self, args):
        '''help: describe the named commands, or all commands if none named.'''

        SYNTAX_WIDTH = 17
        commands = self.__commands

        for cmd in args or sorted(commands.keys()):
            if cmd not in commands:
                print('No help for: %s' % cmd)
                continue

            command = commands[cmd]
            help_args = command.get('args', '')
            help_desc = command['desc']
            syntax = '%s %s' % (cmd, help_args)

            if len(syntax) <= SYNTAX_WIDTH:
                print('    %-*s %s' % (SYNTAX_WIDTH, syntax, help_desc))
            else:
                # Syntax too wide for the column: description goes on its own line
                print('    %-*s' % (SYNTAX_WIDTH, syntax))
                print('    %-*s %s' % (SYNTAX_WIDTH, '', help_desc))

    def __cmd_exec(self, args):
        '''!: run an external program and wait for it to finish.'''

        if not args:
            print('Must specify a command')
        else:
            try:
                process = subprocess.Popen(args)
                process.wait()
            except OSError as e:
                # Not found, not executable, ...
                print(str(e))

        print()

    def __cmd_set(self, args):
        '''set: with no arguments list all variables; with NAME print its
        value; with NAME VALUE assign it.'''

        if len(args) < 1:
            output = FilteredOutput(self.__variables['filter'])
            delim = self.__db.get_delim()
            names = sorted(self.__variables.keys())
            cols = []
            values = []

            for field in names:
                cols.append(libcsv.Value(field).minquoted(delim))

            output.write('%s\n' % delim.join(cols))

            for field in names:
                v = libcsv.Value(self.__variables[field]).strquoted(delim)

                if v == '':
                    v = 'NULL'

                values.append(v)

            output.write('%s\n' % delim.join(values))

            # Dropping the last reference lets FilteredOutput finish its filter
            del(output)
            return

        if args[0] not in self.__variables:
            print('Invalid variable: %s' % args[0])
            return

        if len(args) < 2:
            print(self.__variables[args[0]])
            return

        self.__variables[args[0]] = args[1]
        self.__loadvars()

    def __cmd_unset(self, args):
        '''unset: reset variable NAME to None.'''

        if len(args) != 1:
            print('Expected 1 argument, got %s' % len(args))
            return

        if args[0] in self.__variables:
            self.__variables[args[0]] = None
            self.__loadvars()
        else:
            print('Invalid variable: %s' % args[0])

    def __loadvars(self):
        '''Push the current variable values to the database and the record.'''

        self.__db.set_filter(self.__variables['filter'])
        self.__rec.set_filter(self.__variables['filter'])


class Input(object):
    '''Interactive line reader with tab completion and multi-line input.'''

    def __init__(self):
        self.__buffer = []
        self.__completion_words = []

        readline.set_completer(self.__completer)
        readline.parse_and_bind('tab: complete')

    def set_completion_words(self, words):
        '''Set the list of words offered by tab completion.'''
        self.__completion_words = words

    def __completer(self, text, state):
        '''readline completer: return the state-th completion word that
        starts with text, or None when there are no more.'''

        matches = [w for w in self.__completion_words if w.startswith(text)]

        return matches[state] if state < len(matches) else None

    def get_command(self, prompt):
        '''Return the next complete command line, prompting if necessary.'''

        if not self.__buffer:
            self.__buffer_input(prompt)

        return self.__buffer.pop(0)

    def __buffer_input(self, prompt):
        '''Read one logical line into the buffer.  More input is requested
        while the line ends with a backslash or has an unclosed quote.'''

        line = ''

        while True:
            line += input('%*s> ' % (len(prompt), prompt))
            prompt = '... '

            # Test for the line continuation character
            if line.endswith('\\'):
                line = line[:-1]
                continue

            # Test for parsability.  shlex tokenizes lazily, so the lexer has
            # to be drained for an unclosed quote to raise ValueError.
            try:
                list(shlex.shlex(line, posix=True, punctuation_chars=True))

            except ValueError:
                line += '\n'
                continue

            # Done!
            break

        # There is no real buffering here since we only read one line, buffer
        # it, then pull it out, so the buffer never has more than one line.
        # But we'll leave the buffering code here for future extension.

        self.__buffer.append(line)


class FilteredOutput(object):
    '''Text sink that writes to stdout, or to the stdin of a shell command
    when a filter is given.

    The filter process is finished (its input closed and the process waited
    for) by close(), which also runs when the object is destroyed.'''

    def __init__(self, filter=None):
        # Set both attributes first so close()/__del__ are safe even if
        # starting the filter fails.
        self.__proc = None
        self.__file = sys.stdout

        if filter is not None:
            # Emit anything already buffered before the filter starts
            # writing to the same terminal.
            sys.stdout.flush()

            # Let Popen own the pipe so no descriptor is left open in this
            # process once the filter's stdin has been closed.
            self.__proc = subprocess.Popen(
                filter,
                stdin=subprocess.PIPE,
                shell=True,
                universal_newlines=True,
            )
            self.__file = self.__proc.stdin

    def __del__(self):
        self.close()

    def close(self):
        '''Close the filter's input and wait for it to exit.  Does nothing
        for plain stdout output or when already closed.'''

        if self.__proc is None:
            return

        proc, self.__proc = self.__proc, None

        try:
            self.__file.close()
        except BrokenPipeError:
            # The filter exited without reading everything
            pass

        proc.wait()

    def write(self, text):
        '''Write text to stdout or to the filter.'''

        try:
            self.__file.write(text)
        except BrokenPipeError:
            # Only tolerate a filter that stopped reading; a broken stdout
            # is still reported to the caller.
            if self.__proc is None:
                raise


class Record(object):
    '''Cursor over the rows of one table, addressed by rowid.

    Holds the current table name and the field names and values of the
    current record.  All failures are reported as RecordError subclasses.'''

    class RecordError(Exception): pass
    class BadArgumentError(RecordError): pass
    class DependencyError(RecordError): pass
    class NoSuchRecord(RecordError): pass

    def __init__(self, database):
        self.__db = database
        self.__table = None
        self.__fields = [ 'rowid' ]
        self.__values = [ 0 ]
        self.__fieldlen = 0
        self.__dataindex = 1    # index of the first field after rowid
        self.__filter = None

    def __getitem__(self, name):
        '''Return the current value of field NAME, or None if unknown.'''

        try:
            index = self.__fields.index(name)
            value = self.__values[index]
        except (ValueError, IndexError):
            value = None

        return value

    def get_table(self):
        '''Return the current table name, or None if none is in use.'''
        return self.__table

    def get_recnum(self):
        '''Return the rowid of the current record, or 0 if none was read.'''
        return self.__getitem__('rowid') or 0

    def set_filter(self, filter):
        '''Set the shell command that show() pipes its output through.'''
        self.__filter = filter

    def use_table(self, table):
        '''Switch to TABLE and reset the current record.

        Raises BadArgumentError if the database has no such table.'''

        if table in self.__db.get_tables():
            self.__table = table
            self.new()
        else:
            raise Record.BadArgumentError('No such table: "%s"' % table)

    def new(self):
        '''Reset to an empty record with the current table's field names.'''

        cursor = self.__query('select rowid, * from %s limit 1' % self.__table)
        self.__fields = []
        self.__values = []
        self.__fieldlen = 0

        for cell in cursor.description:
            self.__fields.append(cell[0])
            self.__values.append(None)
            self.__fieldlen = max(self.__fieldlen, len(cell[0]))

        cursor.close()

    def next(self):
        '''Move to the record whose rowid is one more than the current one.'''
        self.__goto(self.get_recnum() + 1, 'No more records')

    def prev(self):
        '''Move to the record whose rowid is one less than the current one.'''
        self.__goto(max(self.get_recnum() - 1, 0), 'No more records')

    def read(self, recid):
        '''Move to the record whose rowid is RECID.'''
        self.__goto(recid, 'No record with ID %d' % recid)

    def __require_table(self):
        '''Raise DependencyError unless a table is in use.'''

        if self.__table is None:
            raise Record.DependencyError('Must set table first')

    def __query(self, command, *args):
        '''Run a query, raising RecordError if the database rejected it.

        The database's query() reports the error itself and returns None;
        converting that to RecordError keeps callers from dereferencing
        None.'''

        cursor = self.__db.query(command, *args)

        if cursor is None:
            raise Record.RecordError('Query failed: %s' % command)

        return cursor

    def __goto(self, recid, notfound):
        '''Make the row with rowid RECID the current record, raising
        NoSuchRecord(notfound) if there is no such row.'''

        self.__require_table()

        rec = self.__query('select rowid, * from %s where rowid=?' % self.__table, recid).fetchone()

        if rec is None:
            raise Record.NoSuchRecord(notfound)

        self.__read_rec(rec)

    def __read_rec(self, rec):
        '''Copy a fetched row into the current values, field by field.'''

        for i in range(len(self.__fields)):
            self.__values[i] = rec[i]

    def count(self, condition=None):
        '''Return the number of rows in the current table, optionally
        limited to those matching the SQL CONDITION.'''

        self.__require_table()

        where = '1' if condition is None else condition
        rec = self.__query('select count(*) from %s where %s' % (self.__table, where)).fetchone()

        return 0 if rec is None else rec[0]

    def search(self, condition):
        '''Show every row of the current table matching the SQL CONDITION.
        The last row shown becomes the current record.'''

        self.__require_table()

        cursor = self.__db.query('select rowid, * from %s where %s' % (self.__table, condition))
        if cursor is None:
            print('No record found')
            return

        for rec in cursor:
            self.__read_rec(rec)
            self.show()

    def show(self):
        '''Print the current record as a header line and a value line,
        omitting the rowid and showing empty values as NULL.'''

        self.__require_table()

        output = FilteredOutput(self.__filter)
        delim = self.__db.get_delim()
        first = self.__dataindex

        cols = [libcsv.Value(field).minquoted(delim) for field in self.__fields[first:]]
        output.write('%s\n' % delim.join(cols))

        values = []

        for value in self.__values[first:]:
            v = libcsv.Value(value).strquoted(delim)

            if v == '':
                v = 'NULL'

            values.append(v)

        output.write('%s\n' % delim.join(values))

        # Dropping the last reference lets FilteredOutput finish its filter
        del(output)

    def showindexes(self):
        '''Print the index list of the current table.'''

        self.__require_table()
        self.__db.execute('pragma index_list(%s)' % self.__table)


class Database(object):
    '''SQLite database that csv files are loaded into and queried from.

    `delim` is the delimiter used when printing query results; it plays no
    part in the SQL that is generated.'''

    # SQLite failures that are reported on stderr instead of propagating
    __SQL_ERRORS = (sqlite3.Error, sqlite3.Warning)

    def __init__(self, delim=','):
        self.__conn = sqlite3.connect('')
        self.__tablenames = []
        self.__delim = delim
        self.__filter = None

    def get_delim(self):
        '''Return the output delimiter.'''
        return self.__delim

    def set_filter(self, filter):
        '''Set the shell command that execute() pipes its output through.'''
        self.__filter = filter

    def get_tables(self):
        '''Return the names of all tables, sorted by name.'''

        tables = []

        for row in self.query("select name from sqlite_master where type='table' order by name"):
            tables.append(row[0])

        return tables

    def load(self, filename, tablename=None, delim=None):
        '''Load the csv file FILENAME into TABLENAME, which defaults to the
        file's base name without its extension.  DELIM is passed to the csv
        reader.

        Returns True if at least one row was inserted.'''

        if tablename is None:
            tablename = os.path.basename(filename)
            tablename = os.path.splitext(tablename)[0]

        insert_count = 0

        with smart_open(filename) as file:
            reader = libcsv.Reader(file, delim=delim, has_header=True, is_multitable=False)

            for row in reader:
                if len(row):
                    insert_count += self.insert(tablename, row)

        return insert_count > 0

    def insert(self, tablename, row):
        '''Insert one csv row into TABLENAME, creating the table from the
        row's header on first use.  Returns the number of rows inserted.'''

        # SQL lists are always comma separated, whatever the output delimiter
        qs = ', '.join(['?'] * len(row))
        table_exists = tablename in self.__tablenames
        query_result = None
        insert_count = 0

        # Create the table if it does not exist
        if not table_exists:
            if self.create(tablename, row.header().as_list()):
                self.__tablenames.append(tablename)
                table_exists = True

        if table_exists:
            query_result = self.query('insert into %s values (%s)' % (tablename, qs), *row.as_stripped_list())

        if query_result:
            insert_count = query_result.rowcount

        return insert_count

    def create(self, tablenames, fieldnames):
        '''Create the table named by `tablenames` with the given columns.
        Returns True on success.'''

        columns = ', '.join([self.__quote(x) for x in fieldnames])

        return bool(self.query('create table %s (%s)' % (tablenames, columns)))

    def __quote(self, string):
        '''Return string wrapped in single quotes, with embedded single
        quotes doubled.'''

        string = string.replace("'", "''")

        return "'%s'" % string

    def query(self, command, *args):
        '''Execute one SQL statement with positional parameters.

        Returns the cursor, or None after writing the error to stderr if
        SQLite rejected the statement.'''

        result = None

        try:
            result = self.__conn.cursor().execute(command, args)
        except Database.__SQL_ERRORS as e:
            sys.stderr.write('%s %s -- %s\n' % (command, args, str(e)))

        return result

    def execute(self, command, *args):
        '''Execute one SQL statement and print any rows it returns, preceded
        by a header line, through the output filter.  Empty values are
        printed as NULL.'''

        cursor = self.query(command, *args)
        if cursor is None: return

        # Started only once the statement is known to be valid, so a failed
        # statement does not spawn the filter for nothing.
        output = FilteredOutput(self.__filter)
        delim = self.__delim

        try:
            # Print any rows the command returns
            for i, row in enumerate(cursor):
                # Print column names ahead of the first row
                if i == 0:
                    cols = [libcsv.Value(cell[0]).minquoted(delim) for cell in cursor.description]

                    output.write('%s\n' % delim.join(cols))

                # Print row values
                values = []

                for cell in row:
                    v = libcsv.Value(cell).strquoted(delim)

                    if v == '':
                        v = 'NULL'

                    values.append(v)

                output.write('%s\n' % delim.join(values))

        except Database.__SQL_ERRORS as e:
            sys.stderr.write('%s -- %s\n' % (command, str(e)))

        finally:
            # Dropping the last reference lets FilteredOutput finish its filter
            del(output)


def smart_open(filename, mode='r'):
    '''Open a file in the given mode, treating '-' as stdin when the mode
    contains 'r' and as stdout otherwise.

    The returned object can always be closed by the caller: for '-' it wraps
    a duplicate of the standard descriptor, not the descriptor itself.'''

    if filename == '-':
        # Duplicate stdin/stdout so the caller can close it without closing stdin.
        fd = sys.stdin.fileno() if 'r' in mode else sys.stdout.fileno()

        # Pass the mode on; fdopen() would otherwise open stdout read-only.
        fo = os.fdopen(os.dup(fd), mode)
    else:
        fo = open(filename, mode)

    return fo


##############################################################################
# ENTRY POINT

if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        print('')
        # Conventional shell status for death by SIGINT: 128 + 2.  Spelled
        # out because errno.EOWNERDEAD only happens to equal 130 on Linux.
        sys.exit(130)


# vim:ft=python
