import subprocess
import os
import re
import csv


class MalletError(Exception):
    '''
    Raised when a Mallet command did not produce the file it was expected to generate
    '''


class MalletTopicModel(object):
    '''
    Wrapper around the Mallet command line tool for topic modelling
    Runs the import-dir / import-file / train-topics commands inside the Mallet directory
    and parses the files written by train-topics into python objects
    '''

    def __init__(self, mallet_dir, memory=1):
        '''
        Constructor
        mallet_dir : str : file path of Mallet-3 directory
        memory : int, float : maximum gigabytes of memory to allocate to Mallet
        '''
        self.memory = memory
        self.mallet_dir = mallet_dir

        self.set_memory()
        self.set_output_dir()

    @staticmethod
    def _launcher():
        '''
        Returns the path of the Mallet launcher script, relative to the Mallet directory
        '''
        return r'bin\mallet.bat' if os.name == 'nt' else 'bin/mallet'

    @staticmethod
    def _normalize_options(options):
        '''
        Returns a copy of the options with every key in Mallet's hyphenated form
        Keyword arguments can only be spelled with underscores (num_topics), while the values
        set by this class use hyphens (num-topics); normalizing avoids emitting the same flag twice
        '''
        return {key.replace('_', '-'): value for key, value in options.items()}

    def _memory_setting(self):
        '''
        Formats self.memory as a heap size such as "2g" or "1536m"
        Fractional gigabytes are converted to megabytes because the JVM rejects values like "1.5g"
        '''
        try:
            gigabytes = float(self.memory)
        except (TypeError, ValueError):
            # not a number: keep the historical formatting
            return str(self.memory) + 'g'
        if gigabytes.is_integer():
            return str(int(gigabytes)) + 'g'
        return str(int(round(gigabytes * 1024))) + 'm'

    def set_memory(self):
        '''
        Sets the memory limit by editing a line in the launcher script
        Handles both the "set MALLET_MEMORY=" line of the Windows batch file and the
        "MEMORY=" line of the Unix launcher (as shipped with Mallet 2.0.x)
        Default from constructor is 1gb
        '''
        full_path = os.path.join(self.mallet_dir, self._launcher())
        with open(full_path, 'r') as file:
            original = file.read()

        limit = self._memory_setting()
        updated = re.sub(r'set MALLET_MEMORY=.*', 'set MALLET_MEMORY=' + limit, original)
        updated = re.sub(r'(?m)^MEMORY=.*$', 'MEMORY=' + limit, updated)

        # only touches the launcher when the limit actually changed
        if updated != original:
            with open(full_path, 'w') as file:
                file.write(updated)

    def set_output_dir(self):
        '''
        Creates (if not already present) an output folder in the Mallet directory
        Sets stdout file path
        '''
        # os.path.join inserts the separator that plain string concatenation would miss
        # when mallet_dir has no trailing slash
        self.output_dir = os.path.join(self.mallet_dir, 'output')
        os.makedirs(self.output_dir, exist_ok=True)

        self.out_full_path = os.path.join(self.output_dir, 'output.txt')
        self.stdout = self.out_full_path

    def call(self, command):
        '''
        Executes a command within the Mallet directory
        Redirects all output (stdout and stderr) to a txt file
        command : list, str : argument list, or a single string to be run through the shell
        Returns the exit code of the command
        '''
        print(command)

        # On POSIX a list combined with shell=True only runs its first element (the remaining
        # items become arguments of the shell itself), so lists are spawned directly there.
        # Windows keeps the shell so that the .bat launcher is run the way it was before.
        use_shell = os.name == 'nt' or isinstance(command, str)
        if not isinstance(command, str):
            command = [str(part) for part in command]

        with open(self.stdout, 'w') as log:
            return subprocess.call(command, cwd=self.mallet_dir, stderr=subprocess.STDOUT, stdout=log,
                                   shell=use_shell)

    def execute(self, command, args):
        '''
        Cannot determine success/failure of Mallet command execution from exit codes (0 or 1)
        Checks whether the file that was supposed to be generated by a command is actually generated
        If the file did not previously exist, checks if it exists after command execution
        Otherwise, checks whether the modified timestamp of the file has changed
        Raises MalletError if the file was not (re)generated or if args names no output file
        '''
        printable = command if isinstance(command, str) else ' '.join(str(part) for part in command)

        # for train_topics, this will be either 'output-topic-keys' or 'output-doc-topics' (not 'topic-word-weights-file')
        # it should not matter which of the three are checked; if the command fails, neither should be generated
        if 'output' in args:
            target_key = 'output'
        else:
            target_key = next((key for key in args if 'output' in key), None)
        if target_key is None:
            raise MalletError('\n\nNo output file argument was supplied for the command:\n\n"' + printable + '"\n')
        target = args[target_key]

        start = os.path.getmtime(target) if os.path.isfile(target) else None
        self.call(command)
        stop = os.path.getmtime(target) if os.path.isfile(target) else None

        if stop is None or start == stop:
            raise MalletError('\n\n' + str(target) + ' was not generated from the command:\n\n"' + printable +
                              '"\n\nCheck ' + self.out_full_path + ' for details.\n')

    def build_command(self, operation, kwargs):
        '''
        Converts the command arguments from dictionary to list format
        Underscores in keys become hyphens; list values are joined with spaces
        Every value is converted to str so the result can be joined or spawned on any platform
        '''
        command = [self._launcher(), operation]
        for key, value in kwargs.items():
            if isinstance(value, (list, tuple)):
                value = ' '.join(str(item) for item in value)
            command.extend(['--' + key.replace('_', '-'), str(value)])
        return command

    def import_dir(self, **kwargs):
        '''
        Calls the import-dir Mallet command
        Overwrites the keep-sequence and output arguments
        '''
        kwargs = self._normalize_options(kwargs)
        # topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.
        kwargs['keep-sequence'] = True
        kwargs['output'] = os.path.join(self.output_dir, 'import.mallet')
        command = self.build_command(operation='import-dir', kwargs=kwargs)
        self.execute(command, kwargs)

    def import_file(self, **kwargs):
        '''
        Calls the import-file Mallet command
        Overwrites the output, keep-sequence and label-as-features arguments
        '''
        kwargs = self._normalize_options(kwargs)
        kwargs['output'] = os.path.join(self.output_dir, 'import.mallet')
        # topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.
        kwargs['keep-sequence'] = True
        kwargs['label-as-features'] = True
        command = self.build_command(operation='import-file', kwargs=kwargs)
        self.execute(command, kwargs)

    def train_topics(self, **kwargs):
        '''
        Calls the train-topics Mallet command
        Overwrites the input and the output file arguments (the parsers rely on those paths)
        num-top-words (30), num-iterations (2000) and num-topics (10) are defaults
        that the caller may override, e.g. train_topics(num_topics=25)
        '''
        kwargs = self._normalize_options(kwargs)

        kwargs['input'] = os.path.join(self.output_dir, 'import.mallet')
        self.word_weights_file = os.path.join(self.output_dir, 'topic_word_weights.txt')
        kwargs['topic-word-weights-file'] = self.word_weights_file
        self.topic_keys_file = os.path.join(self.output_dir, 'topic_keys.txt')
        kwargs['output-topic-keys'] = self.topic_keys_file  # these can also be derived from topic-word-weights-file
        self.doc_topics_file = os.path.join(self.output_dir, 'doc_topics.txt')
        kwargs['output-doc-topics'] = self.doc_topics_file
        self.parameter_file = os.path.join(self.output_dir, 'topic_parameter.txt')
        kwargs['parameter-filename'] = self.parameter_file

        kwargs.setdefault('num-top-words', 30)
        kwargs.setdefault('num-iterations', 2000)
        kwargs.setdefault('num-topics', 10)

        # sets the doc-topics-threshold to a very low number if not specified (or specified as 0)
        # output-doc-topics format changes when this parameter is not specified
        # see https://github.com/mimno/Mallet/issues/41
        if not kwargs.get('doc-topics-threshold', False):
            kwargs['doc-topics-threshold'] = 0.0001

        command = self.build_command(operation='train-topics', kwargs=kwargs)
        self.execute(command, kwargs)

        # parses the files generated by the command into python objects
        self.parse_topic_keys()
        self.parse_doc_topics()
        self.parse_word_weights()

    def parse_topic_keys(self):
        '''
        Parses the output-topic-keys txt file generated by train_topics into a python object
        {topic # (int): {dirichlet parameter: float, words: list}, ... }
        '''
        self.topic_keys = dict()
        with open(self.topic_keys_file, mode='r', encoding='utf8') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue  # tolerates blank lines
                # cols 0, 1, and 2 are tab delimited, the rest are space delimited
                data = re.split(r'[\t ]', line)
                self.topic_keys[int(data[0])] = {'dirichlet': float(data[1]), 'words': self.process_words(data[2:])}

    def parse_doc_topics(self):
        '''
        Parses the output-doc-topics txt file generated by train_topics into a python object
        {document # (int): {document name: e.g. file path (str), topics: {topic # (int): weight (float), ... }}, ... }
        '''
        self.doc_topics = dict()
        with open(self.doc_topics_file, mode='r', encoding='utf8') as file:
            next(file, None)  # skips the first (header) line; an empty file yields an empty result
            for line in file:
                line = line.strip()
                if not line:
                    continue  # tolerates blank lines
                data = line.split('\t')
                self.doc_topics[int(data[0])] = {'name': data[1], 'topics': self.process_topics(data[2:])}

    def parse_word_weights(self):
        '''
        Parses the topic-word-weights txt file generated by train_topics into a python object
        {topic # (int): {word (str): weight (float)}, ... }
        '''
        self.word_weights = dict()
        with open(self.word_weights_file, mode='r', encoding='utf8') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue  # tolerates blank lines
                data = line.split('\t')
                self.word_weights.setdefault(int(data[0]), dict())[data[1]] = float(data[2])

    @staticmethod
    def process_words(words):
        '''
        Replaces the underscores in each word with spaces
        '''
        return [x.replace('_', ' ') for x in words]

    @staticmethod
    def process_topics(topics):
        '''
        Converts a flat [topic, weight, topic, weight, ...] list of strings into {topic (int): weight (float)}
        '''
        return {int(topics[i]): float(topics[i + 1]) for i in range(0, len(topics), 2)}


if __name__ == '__main__':
    # Example run against a local Windows installation of Mallet.
    # Raw strings keep the backslashes of the Windows paths literal
    # ('\m' in a normal string is an invalid escape sequence).
    model = MalletTopicModel(r'C:\mallet')
    # alternative sample input:
    # model.import_file(input=r'C:\mallet\topic_input\dblp_sample.txt')
    model.import_file(input=r'C:\mallet\topic_input\sample_dmr_input.txt')
    model.train_topics()

    print(model.topic_keys)  # see output_topic_keys parameter in Train Topics documentation
    print(model.doc_topics)  # see output_doc_topics parameter in Train Topics documentation
    print(model.word_weights)  # see topic_word_weights_file parameter in Train Topics documentation