from typing import List, Iterator
from pathlib import Path
from multiprocessing import Pool

from eleve.memory import MemoryStorage as Storage
from eleve.memory import CSVStorage
from eleve import Segmenter

from eleve.preprocessing import chinese


def preproc(l: str) -> List[str]:
    """Tokenize one line and return the chunks retained by the CJK filter.

    The line is first split with ``chinese.tokenize_by_unicode_category``;
    the resulting chunks are then passed through ``chinese.filter_cjk``.

    Args:
        l: A single line of raw text.

    Returns:
        The chunks produced by ``chinese.filter_cjk``, materialized as a list.
    """
    tokenized = chinese.tokenize_by_unicode_category(l)
    return list(chinese.filter_cjk(tokenized))


def _ingest_batch(storage, pool, lines, voc, max_len=5):
    """Preprocess ``lines`` in parallel, then feed the chunks to ``storage``.

    Each chunk is added to ``storage`` as a sentence (one item per character),
    and every substring of the chunk whose length is between 1 and ``max_len``
    is added to the ``voc`` set.
    """
    for chunks in pool.map(preproc, lines):
        for chunk in chunks:
            storage.add_sentence(list(chunk))
            size = len(chunk)
            for start in range(size):
                # The upper bound is inclusive of ``size`` so that n-grams
                # ending on the last character of the chunk are recorded too.
                for stop in range(start + 1, min(start + max_len, size) + 1):
                    voc.add(chunk[start:stop])


def train_batch(corpus: Iterator[str], ng, prune):
    """Train a storage from a stream of lines, in batches of 100 000 lines.

    Lines are buffered and preprocessed with a ``multiprocessing`` pool. After
    each full batch the storage is optionally pruned. The vocabulary is
    collected from every batch, including the trailing partial one.

    Args:
        corpus: Iterator over the raw text lines to train on.
        ng: Value handed to the ``Storage`` constructor (the CLI documents it
            as the maximum n-gram length to be stored).
        prune: When greater than zero, ``storage.prune(prune)`` is called
            after each full batch. The trailing partial batch is not pruned.

    Returns:
        A ``(storage, voc)`` tuple, where ``voc`` is the set of substrings of
        length 1 to 5 seen in the preprocessed chunks.
    """
    storage = Storage(ng)
    voc = set()
    buf = []
    counter = 0
    # The context manager terminates the pool even if ingestion raises.
    with Pool() as pool:
        for line in corpus:
            buf.append(line)
            counter += 1
            if counter % 100000 == 0:
                print("read", counter)
                _ingest_batch(storage, pool, buf, voc)
                if prune > 0:
                    print("pruning")
                    storage.prune(prune)
                buf = []
        if buf:
            print("read", counter)
            _ingest_batch(storage, pool, buf, voc)
    return storage, voc


def train(file: Path, size: int, ng, prune):
    """Train on the first ``size`` lines of a text file.

    Args:
        file: Path of the corpus file, read line by line.
            NOTE(review): the file is opened with the platform default
            encoding; confirm whether UTF-8 should be forced.
        size: Maximum number of lines to read. A falsy value (``0`` or
            ``None``) means the whole file is read.
        ng: Forwarded unchanged to ``train_batch``.
        prune: Forwarded unchanged to ``train_batch``.

    Returns:
        Whatever ``train_batch`` returns for the selected lines.
    """
    def reader():
        with open(file) as f:
            for i, line in enumerate(f):
                # ``i`` lines have already been yielded at this point, so stop
                # as soon as it reaches ``size`` (exactly ``size`` lines read).
                if size and i >= size:
                    break
                yield line
    return train_batch(reader(), ng, prune)


def segment_batch(storage: Storage, batch: List[str]) -> List[str]:
    """Segment every non-blank line of ``batch`` sequentially.

    Args:
        storage: Trained storage used to build the ``Segmenter``.
        batch: Lines to segment; a single trailing newline, if present, is
            removed from each line before segmentation.

    Returns:
        The results of ``chinese.segment_with_preprocessing`` for each line
        that is not empty or whitespace-only, in input order.
    """
    segmenter = Segmenter(storage)
    result = []
    for line in batch:
        # Only strip an actual newline: blindly dropping the last character
        # would eat real text on lines that have no line terminator.
        if line.endswith("\n"):
            line = line[:-1]
        if line.strip():
            result.append(chinese.segment_with_preprocessing(segmenter, line))
    return result


def segment_file(storage, input_file: Path, output_file: Path, bies:bool=False):
    """Segment a text file and write one segmented line per non-blank input line.

    Lines are accumulated into batches of 100 000 and handed to
    ``chinese.segment_with_preprocessing_pool``; each returned item is written
    to ``output_file`` followed by a newline. Blank and whitespace-only lines
    are skipped.

    Args:
        storage: Trained storage used to build the ``Segmenter``.
        input_file: Path of the text file to segment.
        output_file: Path of the file to write (opened in ``"w"`` mode).
        bies: Forwarded to ``chinese.segment_with_preprocessing_pool``.
    """
    segmenter = Segmenter(storage)

    def flush(lines, out):
        # Segment one batch and write each result on its own line.
        for segmented in chinese.segment_with_preprocessing_pool(segmenter, lines, bies):
            out.write(segmented)
            out.write("\n")

    with open(input_file) as f, open(output_file, "w") as out:
        buf = []
        for line in f:
            # Only strip an actual newline: the last line of a file may have
            # no terminator, and dropping its final character loses text.
            if line.endswith("\n"):
                line = line[:-1]
            if not line.strip():
                continue
            buf.append(line)
            if len(buf) == 100000:
                flush(buf, out)
                buf = []
        if buf:
            flush(buf, out)


def main():
    """Command-line entry point: train a model or segment a corpus.

    ``train`` requires ``--corpus`` and writes the model to ``--model`` when
    that option is given. ``segment`` requires ``--corpus`` and ``--target``;
    it loads the model from ``--model`` if given, otherwise it trains on the
    corpus first. Invalid or missing arguments end the program with an
    argparse usage error (exit status 2).
    """
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('action',
                        help='train or segment',
                        choices=('train', 'segment'))
    parser.add_argument('-c', '--corpus',
                        help='source file (text corpus)',
                        default=None)
    parser.add_argument('-m', '--model',
                        help='model file (csv)',
                        default=None)
    parser.add_argument('-t', '--target',
                        help='target file (segmented text)',
                        default=None)
    parser.add_argument('-L', '--training_length',
                        help='number of lines read for training',
                        default=100000,
                        type=int,
                        required=False)
    parser.add_argument('-l', '--ngram_length',
                        help='ngram maximum length to be stored',
                        default=5,
                        type=int,
                        required=False)
    parser.add_argument('-p', '--prune',
                        help="prune the tries every 100 000 lines",
                        default=0,
                        type=int,
                        required=False)
    parser.add_argument('--bies', help="mark word segmentation with -B -I -E -S tags rather than adding spaces",
                        default=False,
                        action="store_true")
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts vanish under
    # ``python -O`` and give the user a traceback instead of a usage message.
    if not args.corpus:
        parser.error("the --corpus option is required")

    if args.action == "train":
        storage, voc = train(Path(args.corpus), args.training_length, args.ngram_length, args.prune)
        storage.update_stats()
        if args.model:
            CSVStorage.writeCSV(storage, voc, args.model)
    elif args.action == "segment":
        if not args.target:
            parser.error("the --target option is required for the segment action")
        if args.model:
            storage = CSVStorage(args.model)
        else:
            # No model supplied: train on the corpus that is about to be segmented.
            storage, _ = train(Path(args.corpus), args.training_length, args.ngram_length, args.prune)
            storage.update_stats()
        segment_file(storage, Path(args.corpus), Path(args.target), bies=args.bies)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
