"""Snakefile used by Snakemake."""
from pathlib import Path

from rich import box
from rich.highlighter import ReprHighlighter
from rich.padding import Padding
from rich.table import Table
from rich.text import Text
from snakemake.logging import logger

from sparv.api import SparvErrorMessage, util
from sparv.core import config as sparv_config
from sparv.core import paths, registry, snake_utils, snake_prints
from sparv.core.console import console

# Remove Snakemake's default log handler (only when the run was started by Sparv)
if config.get("run_by_sparv"):
    active_handlers = logger.log_handler
    if active_handlers and active_handlers[0] == logger.text_handler:
        logger.log_handler = []

# Placeholder rule without any input, so that nothing is done if no rule was specified
rule do_nothing:
    input: []

# Names of the rules that were explicitly selected through the "targets" config value
selected_targets = config.get("targets", [])

# ==============================================================================
# Dynamic Creation of Snakemake Rules
# ==============================================================================

def make_rules(config_missing: bool) -> None:
    """Create Snakemake rules for every registered annotator function and every custom annotation.

    Args:
        config_missing: Passed on unchanged to make_rule() for each rule that is created.

    Raises:
        SparvErrorMessage: If a socket is configured but the connection to it is refused.
    """
    # Fetch preloader info from the socket, unless this process is flagged as "preloader" in the config
    socket_path = config.get("socket")
    if socket_path and not config.get("preloader"):
        from sparv.core import preload
        try:
            snake_storage.preloader_info = preload.get_preloader_info(socket_path)
        except ConnectionRefusedError:
            raise SparvErrorMessage(f"Could not connect to the socket '{socket_path}'")

    # One rule per function in each registered module
    for module_name, sparv_module in registry.modules.items():
        for f_name, f in sparv_module.functions.items():
            make_rule(module_name, f_name, f, config_missing)

    # Custom rules: each entry refers to an existing annotator as "<module>:<function>"
    for custom_rule_obj in sparv_config.get("custom_annotations", []):
        module_name, f_name = custom_rule_obj["annotator"].split(":")
        annotator = registry.modules[module_name].functions[f_name]
        make_rule(module_name, f_name, annotator, config_missing, custom_rule_obj)

    # Register the rule order with the workflow; every pair means "first > second"
    ordered_rules = snake_utils.check_ruleorder(snake_storage)
    for preferred, other in ordered_rules:
        workflow.ruleorder(preferred.rule_name, other.rule_name)

    # In debug mode, also show the resulting rule order
    if config.get("debug") and ordered_rules:
        console.print("\n\n\n[b]ORDERED RULES:[/b]")
        for preferred, other in ordered_rules:
            print(f"    • {preferred.rule_name} > {other.rule_name}")
        print()


def make_rule(module_name: str, f_name: str, annotator_info: dict, config_missing: bool = False,
              custom_rule_obj: dict = None) -> None:
    """Create the Snakemake rule for one annotator function, plus its companion "all files" rule.

    Args:
        module_name: Name of the module the annotator belongs to.
        f_name: Name of the annotator function.
        annotator_info: Registry entry for the annotator function.
        config_missing: Passed on to snake_utils.rule_helper().
        custom_rule_obj: Entry from the 'custom_annotations' config section when creating a custom rule,
            otherwise None.
    """
    rule_storage = snake_utils.RuleStorage(module_name, f_name, annotator_info)

    # rule_helper() processes the rule parameters and updates rule_storage. Custom rules may replace the global
    # Sparv config while doing so, hence the config object is remembered here and put back right afterwards.
    saved_config = sparv_config.config
    create_rule = snake_utils.rule_helper(rule_storage, config, snake_storage, config_missing, custom_rule_obj)
    sparv_config.config = saved_config

    # Optional limit on the number of parallel threads for this rule, as requested in the config
    rule_resources = {}
    if "threads" in config:
        max_parallel = sparv_config.get(sparv_config.MAX_THREADS, {}).get(rule_storage.target_name)
        if max_parallel:
            rule_resources["threads"] = config["threads"] // max_parallel

    if not create_rule:
        return

    # The actual Snakemake rule for the annotator
    rule:
        name: rule_storage.rule_name
        message: rule_storage.target_name
        input: rule_storage.inputs
        output: rule_storage.outputs
        params:
            module_name=rule_storage.module_name,
            f_name=rule_storage.f_name,
            parameters=snake_utils.get_parameters(rule_storage),
            export_dirs=rule_storage.export_dirs,
            source_file=snake_utils.file_value(rule_storage),
            use_preloader=rule_storage.use_preloader,
            socket=config.get("socket"),
            force_preloader=config.get("force_preloader", False),
            compression=sparv_config.get("sparv.compression")
        resources: **rule_resources
        # "script" is used rather than "run": with "run", Snakemake would have to reload the whole Snakefile for
        # every single job, due to how it creates processes for run-jobs.
        script: "run_snake.py"

    # Named rule that runs this annotation on all input files
    make_all_files_rule(rule_storage)


def make_all_files_rule(rule_storage: snake_utils.RuleStorage) -> None:
    """Create a rule, named after the target, that runs an annotation on every input file.

    Args:
        rule_storage: Storage object of the annotator rule that the new rule should depend on.
    """
    # When run through Sparv, the rule is only needed if the target was explicitly selected
    if config.get("run_by_sparv"):
        if rule_storage.target_name not in selected_targets:
            return

    # Snakemake rule object of the annotator rule itself
    annotator_rule = getattr(rules, rule_storage.rule_name).rule

    # Abstract rules contribute their inputs, all other rules their outputs
    dependencies = rule_storage.inputs if rule_storage.abstract else rule_storage.outputs

    def with_work_dir(path):
        """Prepend the work dir to path, unless path is already below the work dir or the export dir."""
        already_rooted = paths.work_dir in path.parents or paths.export_dir in path.parents
        return path if already_rooted else paths.work_dir / path

    # The work dir is usually part of the {file} wildcard, but has to be explicit here. After that the {file}
    # wildcard (and any other wildcards) is expanded so there is one path per corpus file.
    all_file_targets = expand([with_work_dir(dep) for dep in dependencies],
                              file=snake_utils.get_file_values(config, snake_storage),
                              **snake_utils.get_wildcard_values(config))

    rule:
        name: rule_storage.target_name
        input: all_file_targets

    if rule_storage.abstract:
        return

    # Tell Snakemake explicitly which rule produces every input file, to resolve ambiguity. Converting the paths to
    # snakemake.io.IOFile objects would not be enough, since the file paths no longer match once the {file}
    # wildcard has been expanded.
    all_files_rule = getattr(rules, rule_storage.target_name).rule
    for expanded_file in all_files_rule.input:
        all_files_rule.dependencies[expanded_file] = annotator_rule

# ------------------------------------------------------------------------------
# Module-level setup, executed top to bottom when this file is loaded.
# NOTE(review): later steps read state set by earlier ones (e.g. make_rules() uses snake_storage and
# config_missing), so the order of the statements below should be kept.
# ------------------------------------------------------------------------------

# Init the storage for some essential variables involving all rules
snake_storage = snake_utils.SnakeStorage()

# Find and load corpus config. The returned flag is used below to skip the text annotation handling, and is passed
# on to make_rules().
config_missing = snake_utils.load_config(config)

# Find and load Sparv modules (find_custom is only enabled when the config has a non-empty 'custom_annotations')
registry.find_modules(find_custom=bool(sparv_config.get("custom_annotations")))

# Resolve presets in config
sparv_config.apply_presets()

# Add classes from config to registry (an empty dict if the config has no 'classes' section)
registry.annotation_classes["config_classes"] = sparv_config.config.get("classes", {})

# Set text class from text annotation (skipped when the config is missing)
if not config_missing:
    sparv_config.handle_text_annotation()

# Let exporters and importers inherit config values from 'export' and 'import' sections.
# inherit_config() is called once per importer/exporter function, i.e. possibly several times for the same module.
for module in registry.modules:
    for a in registry.modules[module].functions.values():
        if a["type"] == registry.Annotator.importer:
            sparv_config.inherit_config("import", module)
        elif a["type"] == registry.Annotator.exporter:
            sparv_config.inherit_config("export", module)

# Collect list of all explicitly used annotations (without class expansion).
# Only the first element of every parsed annotation entry is kept.
for key in registry.annotation_sources:
    registry.explicit_annotations_raw.update(a[0] for a in util.misc.parse_annotation_list(sparv_config.get(key, [])))

# Figure out classes from annotation usage
registry.find_implicit_classes()

# Collect list of all explicitly used annotations (with class expansion).
# Same config keys as above, but every annotation name is passed through registry.expand_variables() first.
for key in registry.annotation_sources:
    registry.explicit_annotations.update(
        registry.expand_variables(a[0])[0]
        for a in util.misc.parse_annotation_list(sparv_config.get(key, [])))

# Load modules and create automatic rules
make_rules(config_missing)

# Validate config usage in modules
sparv_config.validate_module_config()

# Validate config
sparv_config.validate_config()

# Get reverse_config_usage dict for look-ups (used by the 'modules' rule)
reverse_config_usage = snake_utils.get_reverse_config_usage()

# Abort when no source files are available and every selected target is one that requires source files
# (i.e. a named annotation target, "export_corpus" or "install_corpus")
if selected_targets and not (config_missing or snake_storage.source_files):
    named_targets = {target_name for target_name, _ in snake_storage.named_targets}
    named_targets |= {"export_corpus", "install_corpus"}
    if named_targets.issuperset(selected_targets):
        raise SparvErrorMessage("No source files available!")

# ==============================================================================
# Static Snakemake Rules
# ==============================================================================

# Rule to list all config options and their current values
rule config:
    run:
        # Either show only the requested options, or the complete config when none were requested
        requested_options = config.get("options")
        if requested_options:
            out_conf = {option: sparv_config.get(option) for option in requested_options}
        else:
            out_conf = sparv_config.config
        snake_prints.prettyprint_yaml(out_conf)


# Rule to list all modules and annotations
rule modules:
    run:
        # Without any "types" or "names" filter a summary is printed, otherwise detailed info for the selection
        wants_details = config.get("types") or config.get("names")
        if not wants_details:
            snake_prints.print_module_summary(snake_storage)
        else:
            snake_prints.print_module_info(
                config.get("types", []),
                config.get("names", []),
                snake_storage,
                reverse_config_usage)


# Rule to list all annotation classes
rule classes:
    run:
        # This rule has no inputs or outputs; the printing is delegated entirely to snake_prints
        snake_prints.print_annotation_classes()


# Rule to list all supported languages
rule languages:
    run:
        # This rule has no inputs or outputs; the printing is delegated entirely to snake_prints
        snake_prints.print_languages()


# Rule to list all annotation presets
rule presets:
    run:
        # Map every preset name to its resolved annotations (the first element returned by resolve_presets)
        resolved_presets = {
            preset: sparv_config.resolve_presets(annots, {}, {})[0]
            for preset, annots in sparv_config.presets.items()
        }
        snake_prints.prettyprint_yaml(resolved_presets)


# Rule to list all targets
rule list_targets:
    run:
        print()
        table = Table(title="Available rules", box=box.SIMPLE, show_header=False, title_justify="left")
        table.add_column(no_wrap=True)
        table.add_column()

        # Heading and targets of every section, in display order. Each target entry starts with the target name and
        # its description; some of the collections carry one additional element, which is not shown.
        sections = (
            ("[b]Exports[/b]", snake_storage.export_targets),
            ("[b]Installers[/b]", snake_storage.install_targets),
            ("[b]Uninstallers[/b]", snake_storage.uninstall_targets),
            ("[b]Importers[/b]", snake_storage.import_targets),
            ("[b]Annotators[/b]", snake_storage.named_targets),
            ("[b]Model Builders[/b]", snake_storage.model_targets),
            ("[b]Custom Rules[/b]", snake_storage.custom_targets),
        )
        for index, (heading, targets) in enumerate(sections):
            if index > 0:
                # Empty row as separator between two sections
                table.add_row()
            table.add_row(heading)
            for target, desc, *_ in sorted(targets):
                table.add_row("  " + target, desc)
        console.print(table)

        note = Text.from_markup("[i]Note:[/i] Custom rules need to be declared in the 'custom_annotations' section "
                                "of your corpus configuration before they can be used.")
        ReprHighlighter().highlight(note)
        console.print(Padding(note, (0, 4)))


# Rule to list all exports
rule list_exports:
    run:
        # Split all export targets into the ones listed under 'export.default' and the remaining ones
        selected_export_names = sparv_config.get("export.default", [])
        all_exports = sorted(snake_storage.export_targets)
        selected_exports = [(t, d, l) for t, d, l in all_exports if t in selected_export_names]
        other_exports = [(t, d, l) for t, d, l in all_exports if t not in selected_export_names]

        export_groups = (
            ("Selected corpus exports (output formats)", selected_exports),
            ("Other available corpus exports (output formats)", other_exports),
        )
        for title, exports in export_groups:
            if not exports:
                continue
            print()
            table = Table(title=title, box=box.SIMPLE, show_header=False, title_justify="left")
            table.add_column(no_wrap=True)
            table.add_column()
            for target, desc, language in exports:
                # Exports tied to a language are only shown if that language matches the corpus language
                if language and not registry.check_language(sparv_config.get("metadata.language"), language,
                                                            sparv_config.get("metadata.variety")):
                    continue
                table.add_row(target, desc)
            console.print(table)

        console.print("[i]Note:[/i] Use the 'export.default' section in your corpus configuration to select what "
                      "exports should be produced when running 'sparv run' without arguments. If 'export.default' is "
                      "not specified, 'xml_export:pretty' is run by default.")


# Rule to list all input files
rule files:
    run:
        from rich.columns import Columns

        print("Available input files:\n")
        # Every name is wrapped in a Text object to get rid of syntax highlighting
        file_names = [Text(name) for name in sorted(snake_storage.source_files)]
        console.print(Columns(file_names, column_first=True, padding=(0, 3)))


# Rule to remove dirs created by Sparv
rule clean:
    run:
        import shutil

        remove_all = config.get("all")
        remove_export = config.get("export")
        remove_logs = config.get("logs")

        # Pair every directory selected for removal with the error to report if its name is not configured
        selected = []
        if remove_export or remove_all:
            selected.append((paths.export_dir, "Export dir name not configured."))
        if remove_logs or remove_all:
            selected.append((paths.log_dir, "Log dir name not configured."))
        # The work dir is removed by default, i.e. when neither 'export' nor 'logs' was requested
        if remove_all or not (remove_export or remove_logs):
            selected.append((paths.work_dir, "Work dir name not configured."))

        # Check all names before deleting anything. This is an explicit check rather than an assert, because asserts
        # are stripped when Python runs with -O, and an empty name would make the path below point at the current
        # working directory itself.
        for d, problem in selected:
            if not d:
                raise SparvErrorMessage(problem)

        something_removed = False
        for d, _ in selected:
            full_path = Path.cwd() / d
            if full_path.is_dir():
                shutil.rmtree(full_path)
                snake_utils.print_sparv_info(f"'{d}' directory removed")
                something_removed = True
        if not something_removed:
            snake_utils.print_sparv_info("Nothing to remove")


# Rule to list all available installers
rule list_installs:
    run:
        # Same helper as in 'list_uninstalls', here called without the uninstall flag
        snake_prints.print_installers(snake_storage)


# Rule to list all available uninstallers
rule list_uninstalls:
    run:
        # Same helper as in 'list_installs', but with uninstall=True
        snake_prints.print_installers(snake_storage, uninstall=True)


# Rule for making the exports defined in the corpus config (only created when explicitly selected)
if "export_corpus" in selected_targets:
    # Pairs of (rule, files); the rule may be None
    export_targets = snake_utils.get_export_targets(snake_storage, rules,
        file=snake_utils.get_file_values(config, snake_storage), wildcards=snake_utils.get_wildcard_values(config))

    rule export_corpus:
        input: [export_file for _, export_files in export_targets for export_file in export_files]

    # Tell Snakemake which rule produces every file, so it knows which rule to use in case of ambiguity
    for export_rule, export_files in export_targets:
        if export_rule is not None:
            for export_file in export_files:
                rules.export_corpus.rule.dependencies[export_file] = export_rule

# Rule for making installations (only created when explicitly selected)
if "install_corpus" in selected_targets:
    install_inputs = snake_utils.get_install_outputs(snake_storage, config.get("install_types"))
    if install_inputs:
        rule install_corpus:
            input:
                install_inputs
    else:
        # Nothing to install: neither given as arguments nor listed in the corpus config
        raise SparvErrorMessage("Please specify what you would like to install, either by supplying arguments "
                                "(e.g. 'sparv install xml_export:install') or by adding installers to "
                                "the 'install' section of your corpus config file.\n"
                                "You can list available installers with 'sparv install -l'.")


# Rule for making uninstallations (only created when explicitly selected)
if "uninstall_corpus" in selected_targets:
    uninstall_inputs = snake_utils.get_install_outputs(snake_storage, config.get("uninstall_types"), uninstall=True)
    if uninstall_inputs:
        rule uninstall_corpus:
            input:
                uninstall_inputs
    else:
        # Nothing to uninstall: neither given as arguments nor derivable from the corpus config
        raise SparvErrorMessage("Please specify what you would like to uninstall, either by supplying arguments "
                                "(e.g. 'sparv uninstall xml_export:uninstall') or by adding uninstallers to "
                                "the 'uninstall' section of your corpus config file. If the 'uninstall' setting is not "
                                "set, any uninstallers connected to the installers in the 'install' section will be "
                                "used instead.\n"
                                "You can list available uninstallers with 'sparv uninstall -l'.")

# Rule to list all models that can be built/downloaded
rule list_models:
    run:
        print()
        corpus_language = sparv_config.get("metadata.language")

        # Models for the configured language (nothing more is printed if that language is not supported)
        if corpus_language:
            if corpus_language not in registry.languages:
                console.print("Unsupported language: " + corpus_language)
                return
            language_table = Table(title=f"Models for current language ({registry.languages[corpus_language]})",
                                   box=box.SIMPLE, show_header=False, title_justify="left")
            language_table.add_column(no_wrap=True)
            language_table.add_column()
            for target, desc, model_language in sorted(snake_storage.model_targets):
                if not model_language:
                    continue
                if registry.check_language(corpus_language, model_language, sparv_config.get("metadata.variety")):
                    language_table.add_row(target, desc)
            console.print(language_table)

        # Models that are not tied to any language
        generic_table = Table(title="Language-independent models", box=box.SIMPLE, show_header=False,
                              title_justify="left")
        generic_table.add_column(no_wrap=True)
        generic_table.add_column()
        for target, desc, model_language in sorted(snake_storage.model_targets):
            if not model_language:
                generic_table.add_row(target, desc)
        console.print(generic_table)

        # Without a supported language configured, show which languages can be chosen (sorted by name)
        if corpus_language not in registry.languages:
            console.print("To list or build models for a specific language, use the '--language' argument together "
                          "with the language code of one of the supported languages below:")
            languages_table = Table(box=box.SIMPLE, show_header=False)
            for code, name in sorted(registry.languages.items(), key=lambda item: item[1]):
                languages_table.add_row(name, code)
            console.print(languages_table)

        print()
        console.print("[i]Note:[/i] The 'build-models' command is entirely optional, as Sparv will download and build "
                      "models automatically when needed.")


# Rule to list all annotation files that can be created
rule list_files:
    run:
        # Unique (rule type, file) pairs. For abstract rules the inputs are listed in place of the outputs.
        outputs = {(r.type, o) for r in snake_storage.all_rules
                   for o in (r.inputs if r.abstract else r.outputs)}

        print()
        print("This is a list of files that can be created by Sparv. Please note that wildcards must be replaced "
              "with paths.")

        # Heading of every section together with the rule types whose files belong to it, in display order
        sections = (
            ("[i]Annotation files[/i]\n", ("annotator", "importer")),
            ("[i]Export files[/i]\n", ("exporter",)),
            ("[i]Model files[/i]\n", ("modelbuilder",)),
            ("[i]Installation files[/i]\n", ("installer",)),
        )
        for heading, rule_types in sections:
            print()
            console.print(heading)
            for file_path in sorted(o for t, o in outputs if t in rule_types):
                print(f"    {file_path}")


# Build all models
rule build_models:
    # No actions of its own: requesting this rule makes Snakemake produce every file in model_outputs
    input:
        snake_storage.model_outputs


rule preload:
    run:
        from sparv.core import preload

        # Start or stop the preloader, depending on the "preload_command" config value
        command = config["preload_command"]
        if command == "start":
            preload.serve(config["socket"], config["processes"], snake_storage)
        elif command == "stop":
            socket_file = config["socket"]
            if not Path(socket_file).is_socket():
                raise SparvErrorMessage(f"Socket file '{socket_file}' doesn't exist or isn't a socket.")
            if not preload.stop(socket_file):
                raise SparvErrorMessage(f"Could not connect to socket '{socket_file}'.")


rule preload_list:
    run:
        print()
        table = Table(title="Annotators available for preloading", box=box.SIMPLE, show_header=False,
            title_justify="left")
        table.add_column(no_wrap=True)
        table.add_column()

        # One row per preloadable annotator: "<module>:<annotator>" followed by the annotator's description
        rows = [
            (f"{module_name}:{annotator}", registry.modules[module_name].functions[annotator]["description"])
            for module_name, annotators in sorted(snake_storage.all_preloaders.items())
            for annotator in sorted(annotators)
        ]
        for name, description in rows:
            table.add_row(name, description)

        console.print(table)
