#!/usr/bin/env python3
#
# Main program for extracting a dictionary from wiktionary.  This has
# mostly been used with enwiktionary, but should be usable with other
# wiktionaries as well.
#
# Copyright (c) 2018-2021 Tatu Ylonen.  See LICENSE and https://ylonen.org
#
# For pre-extracted data files, see https://kaikki.org/dictionary/

import os
import re
import sys
import html
import json
import pstats
import hashlib
import cProfile
import argparse
import collections
from wikitextprocessor import Wtp, ALL_LANGUAGES
from wiktextract import (WiktionaryConfig, parse_wiktionary,
                         reprocess_wiktionary, parse_page,
                         extract_namespace)
from wiktextract import extract_thesaurus_data
from wiktextract import extract_categories

# Names of all languages listed in ALL_LANGUAGES.  Used below to validate
# --language arguments and to print the list for --list-languages.
all_langs = {lang["name"] for lang in ALL_LANGUAGES}

# Title prefixes (the part before the first colon) of pages that are
# ignored.
ignore_prefixes = {
    "Index",
    "Help",
    "MediaWiki",
    "Citations",
    "Reconstruction",
    "Concordance",
    "Rhymes",
    "Thread",
    "Summary",
    "File",
    "Transwiki",
}

# Title prefixes of pages that are captured.  Any other prefix causes
# capture_page() to print an "UNRECOGNIZED PREFIX" notice.
recognized_prefixes = {
    "Category",
    "Appendix",
    "Wiktionary",
    "Thesaurus",
    "Module",
    "Template",
}


def capture_page(model, orig_title, text, pages_dir):
    """Classify a page by its title and optionally save it under pages_dir.

    A title of the form "Prefix:rest" is treated as a namespaced page;
    anything else is treated as a word.  When ``pages_dir`` is not None,
    the page text (HTML-unescaped) is written to a file below that
    directory, preceded by a "TITLE: <original title>" line.

    Returns True only for titles without a namespace prefix, i.e. pages
    that should be processed normally as dictionary entries."""
    assert isinstance(model, str)
    assert isinstance(orig_title, str)
    assert isinstance(text, str)
    assert pages_dir is None or isinstance(pages_dir, str)

    ns_match = re.match(r"^([A-Z][a-z][-a-zA-Z0-9_]+):(.+)$", orig_title)
    is_word = ns_match is None

    if is_word:
        # Ordinary word.  Very long titles are shortened to 100 characters
        # plus a short hash of the full title to keep the name distinct.
        stem = orig_title
        if len(stem) > 100:
            digest = hashlib.sha256(stem.encode("utf-8")).hexdigest()
            stem = stem[:100] + "-" + digest[:10]
        # Words are stored under Words/<first two characters>/<title>.
        save_title = "Words:" + stem[:2] + "/" + stem
    else:
        save_title = orig_title
        prefix = ns_match.group(1)
        known = prefix in ignore_prefixes or prefix in recognized_prefixes
        if not known:
            print("UNRECOGNIZED PREFIX", orig_title)

    if pages_dir is not None:
        # Map the title to a relative path: "//" is escaped and the
        # namespace colon becomes a directory separator.
        relative = save_title.replace("//", "__slashslash__")
        relative = relative.replace(":", "/")
        path = pages_dir + "/" + relative + ".txt"
        # Escape dots that directly follow a slash, so that no path
        # component starts with "." (this also covers "." and "..").
        path = re.sub(r"/\.+",
                      lambda dots: dots.group(0).replace(".", "__dot__"),
                      path)
        # Collapse any remaining runs of slashes.
        path = re.sub(r"//+", "/", path)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            f.write("TITLE: {}\n".format(orig_title))
            f.write(html.unescape(text))

    return is_word


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(
        description="Multilingual Wiktionary data extractor")
    parser.add_argument("path", type=str, nargs="?", default=None,
                        help="Input file (.../enwiktionary-<date>-"
                        "pages-articles.xml[.bz2])")
    parser.add_argument("--out", type=str, default=None,
                        help="Path where to write output (- for stdout)")
    parser.add_argument("--errors", type=str,
                        help="File in which to save error information")
    parser.add_argument("--language", type=str, action="append", default=[],
                        help="Language to capture (can specify multiple times, "
                        "defaults to English and Translingual)")
    parser.add_argument("--all-languages", action="store_true", default=False,
                        help="Extract words for all languages")
    parser.add_argument("--list-languages", action="store_true", default=False,
                        help="Print list of supported languages")
    parser.add_argument("--pages-dir", type=str, default=None,
                        help="Directory under which to save all pages")
    parser.add_argument("--all", action="store_true", default=False,
                        help="Capture everything for the selected languages")
    parser.add_argument("--translations", action="store_true", default=False,
                        help="Capture translations")
    parser.add_argument("--pronunciations", action="store_true", default=False,
                        help="Capture pronunciation information")
    parser.add_argument("--linkages", action="store_true", default=False,
                        help="Capture linkages (hypernyms, synonyms, etc)")
    parser.add_argument("--compounds", action="store_true", default=False,
                        help="Capture compound words using each word")
    parser.add_argument("--redirects", action="store_true", default=False,
                        help="Capture redirects")
    parser.add_argument("--examples", action="store_true", default=False,
                        help="Capture usage examples")
    parser.add_argument("--etymologies", action="store_true", default=False,
                        help="Capture etymologies")
    parser.add_argument("--statistics", action="store_true", default=False,
                        help="Print statistics")
    parser.add_argument("--page", type=str,
                        help="Parse a single Wiktionary page (for debugging)")
    parser.add_argument("--cache", type=str,
                        help="File prefix where phase1 results are saved; "
                        "speeds up processing a single page tremendously")
    parser.add_argument("--num-threads", type=int, default=None,
                        help="Number of parallel processes (default: #cpus)")
    parser.add_argument("--verbose", action="store_true", default=False,
                        help="Print verbose status messages (for debugging)")
    parser.add_argument("--human-readable", action="store_true", default=False,
                        help="Write output in human-readable JSON")
    parser.add_argument("--override", type=str, action="append",
                        help="Override module by one in file (for debugging)")
    parser.add_argument("--use-thesaurus", action="store_true", default=False,
                        help="Include thesaurus in single page mode")
    parser.add_argument("--profile", action="store_true", default=False,
                        help="Enable CPU time profiling")
    parser.add_argument("--categories-file", type=str,
                        help="Extract category tree as JSON in this file")
    parser.add_argument("--modules-file", type=str,
                        help="Extract all modules and save in this .tar file")
    parser.add_argument("--templates-file", type=str,
                        help="Extract all templates and save in this .tar file")
    args = parser.parse_args()

    # The --all option turns on capturing all data types
    if args.all:
        args.translations = True
        args.pronunciations = True
        args.linkages = True
        args.compounds = True
        args.redirects = True
        args.examples = True
        args.etymologies = True

    # If --list-languages has been specified, just print the list of supported
    # languages
    if args.list_languages:
        print("Supported languages:")
        for lang in sorted(all_langs):
            print("    {}".format(lang))
        sys.exit(0)

    # Some input is required: either a dump file or a previously saved cache.
    if not args.path and not args.cache:
        print("The PATH argument for wiktionary dump file is normally "
              "mandatory.")
        print("Alternatively, --cache with --page can be used.")
        sys.exit(1)

    # Default to English and Translingual if language not specified;
    # otherwise every requested language must be a known language name.
    if not args.language:
        args.language = ["English", "Translingual"]
    else:
        for x in args.language:
            if x not in all_langs:
                print("Invalid language:", x)
                sys.exit(1)

    # --all-languages takes precedence over any --language options; None is
    # passed on as the language selection in that case.
    if args.all_languages:
        args.language = None
        print("Capturing words for all available languages")
    else:
        print("Capturing words for:", ", ".join(args.language))

    # Open output file.  Regular files are first written to "<out>.tmp" (the
    # temporary file is moved into place later in this script); /dev/ paths
    # are written directly.  Without --out (or with "-") output goes to
    # stdout.
    out_path = args.out
    if out_path and out_path != "-":
        if out_path.startswith("/dev/"):
            out_tmp_path = out_path
        else:
            out_tmp_path = out_path + ".tmp"
        out_f = open(out_tmp_path, "w", buffering=1024*1024, encoding="utf-8")
    else:
        out_tmp_path = out_path
        out_f = sys.stdout

    # Number of entries written so far (incremented by word_cb).
    word_count = 0

    # Create expansion context
    ctx = Wtp(cache_file=args.cache, num_threads=args.num_threads)
    if args.override:
        for path in args.override:
            with open(path, "r", encoding="utf-8") as f:
                text = f.read()
            # The override file must start with a "TITLE: <page title>" line
            # naming the page it replaces; the rest of the file is the body.
            m = re.match(r"(?s)^TITLE: ([^\n]*)\n", text)
            if m:
                title = m.group(1)
                text = text[m.end():]
            else:
                print("First line of file supplied with --override must be "
                      "\"TITLE: <page title>\"")
                print("(The page title for this would normally start "
                      "with Module:)")
                sys.exit(1)
            # Load it as a transient page, overriding the normal page
            ctx.add_page("Scribunto", title, text, transient=True)

    def word_cb(data):
        """Write one extracted entry to the output as a line of JSON.

        With --human-readable the JSON is indented with sorted keys;
        otherwise each entry is a single compact line.  Also counts the
        entries written in ``word_count``."""
        global word_count
        word_count += 1
        dump_options = {"ensure_ascii": False}
        if args.human_readable:
            dump_options.update(indent=2, sort_keys=True)
        out_f.write(json.dumps(data, **dump_options))
        out_f.write("\n")
        # Flush after every entry when writing to stdout, and after every
        # 1000th entry when writing to a file.
        to_stdout = not out_path or out_path == "-"
        if to_stdout or word_count % 1000 == 0:
            out_f.flush()

    def capture_cb(model, title, text):
        """Page callback: delegate to capture_page() using the directory
        given with --pages-dir (None when the option was not given)."""
        pages_dir = args.pages_dir
        return capture_page(model, title, text, pages_dir=pages_dir)

    # Extraction settings derived from the command-line options.
    config = WiktionaryConfig(capture_languages=args.language,
                              capture_translations=args.translations,
                              capture_pronunciation=args.pronunciations,
                              capture_linkages=args.linkages,
                              capture_compounds=args.compounds,
                              capture_redirects=args.redirects,
                              capture_examples=args.examples,
                              capture_etymologies=args.etymologies,
                              verbose=args.verbose)

    if args.profile:
        pr = cProfile.Profile()
        pr.enable()

    try:
        if args.path:
            # Parse the normal full Wiktionary data dump.  When --page is
            # also given, phase1_only is set, and the requested page is
            # parsed separately below.
            parse_wiktionary(ctx, args.path, config, word_cb, capture_cb,
                             phase1_only=(args.page is not None))

        if args.page:
            # Parse a single Wiktionary page (extracted using --pages-dir)
            if not args.cache:
                print("NOTE: you probably want to use --cache with --page or "
                      "otherwise processing will be very slow.")
            # Load the page wikitext from the given file
            with open(args.page, "r", encoding="utf-8") as f:
                text = f.read()
            # An optional leading "TITLE: ..." line gives the page title.
            m = re.match(r"(?s)^TITLE: ([^\n]*)\n", text)
            if m:
                title = m.group(1)
                text = text[m.end():]
            else:
                title = "Test page"
            # Extract Thesaurus data (this is a bit slow for a single page, but
            # needed for debugging linkages with thesaurus extraction).  This
            # is disabled by default to speed up single page testing.
            if args.use_thesaurus:
                config.thesaurus_data = extract_thesaurus_data(ctx, config)
            # Parse the page
            ret = parse_page(ctx, title, text, config)
            for x in ret:
                word_cb(x)

        if not args.path and not args.page:
            # Parse again from the cache file
            reprocess_wiktionary(ctx, config, word_cb, capture_cb)

    finally:
        # Always close a real output file (never stdout), even on failure.
        if out_path and out_path != "-":
            out_f.close()

    if args.modules_file:
        extract_namespace(ctx, "Module", args.modules_file)
    if args.templates_file:
        extract_namespace(ctx, "Template", args.templates_file)
    if args.categories_file:
        # Announce the step (and flush) before starting the extraction so
        # that the message is visible while the work is in progress.
        print("Extracting category tree")
        sys.stdout.flush()
        tree = extract_categories(ctx, config)
        with open(args.categories_file, "w", encoding="utf-8") as f:
            json.dump(tree, f, indent=2, sort_keys=True)

    if args.profile:
        pr.disable()
        ps = pstats.Stats(pr).sort_stats(pstats.SortKey.CUMULATIVE)
        ps.print_stats()

    # Output was written to a temporary file; move it into place now.
    # os.replace overwrites an existing destination in a single step, so
    # there is no window in which the output file is missing.
    if out_path != out_tmp_path:
        os.replace(out_tmp_path, out_path)

    if args.statistics:
        print("")
        print("LANGUAGE COUNTS")
        # Languages in descending order of count; the listing stops after
        # the first language with fewer than 1000 entries.
        for k, cnt in sorted(config.language_counts.items(),
                             key=lambda x: -x[1]):
            print("  {:>7d} {}".format(cnt, k))
            if cnt < 1000:
                break
        print("  ...")
        print("")

        print("")
        print("POS HEADER USAGE")
        for k, cnt in sorted(config.pos_counts.items(),
                             key=lambda x: -x[1]):
            print("  {:>7d} {}".format(cnt, k))

        print("")
        print("POS SUBSECTION HEADER USAGE")
        for k, cnt in sorted(config.section_counts.items(),
                             key=lambda x: -x[1]):
            print("  {:>7d} {}".format(cnt, k))

        print("")
        print("{} WORDS CAPTURED".format(word_count))

    # Save collected errors, warnings and debug messages as JSON.
    if args.errors:
        with open(args.errors, "w", encoding="utf-8") as f:
            json.dump({
                "errors": config.errors,
                "warnings": config.warnings,
                "debugs": config.debugs,
                }, f, sort_keys=True, indent=2)

    def dump_un(title, limit, counts, samples):
        """Print the ``limit`` highest-count keys of ``counts`` under ``title``.

        ``counts`` maps a key to its count.  ``samples[key]`` is a mapping
        whose values are iterables of sample items; every sample is printed
        indented below its key.  Nothing is printed when ``counts`` is
        empty."""
        ranked = sorted(counts.items(), key=lambda item: item[1],
                        reverse=True)
        if not ranked:
            return
        print(title)
        for key, cnt in ranked[:limit]:
            print("{:5d} {}".format(cnt, key))
            for sample_group in samples[key].values():
                for sample in sample_group:
                    print("        {}".format(sample))

    def dump_val(title, limit, counts):
        """Print the ``limit`` most frequent (name, value) pairs in ``counts``.

        ``counts`` is a three-level mapping: outer key (apparently a
        language) -> name -> value -> count.  Counts for the same
        (name, value) pair are summed over all outer keys, and the pairs
        are printed under ``title`` in descending order of their total.
        Nothing is printed when there are no pairs."""
        totals = collections.defaultdict(int)
        for per_name in counts.values():
            for name, per_value in per_name.items():
                for value, cnt in per_value.items():
                    # Accumulate; assigning here would discard the counts
                    # already collected from other outer keys.
                    totals[name, value] += cnt
        ranked = sorted(totals.items(), key=lambda item: item[1],
                        reverse=True)
        if ranked:
            print("")
            print(title)
            for (name, value), total in ranked[:limit]:
                print("  {:5d} {}: {}".format(total, name, value))
