#!/usr/bin/env python

"""Embed source code in markdown files."""
import os
import re
import sys

# Matches a markdown line that carries an embed directive. Any text may precede
# the HTML comment; group 1 captures everything between "embed:" and " -->".
EMBED_RE = re.compile(r".*<!-- embed:(.*) -->")


# Language tag for each recognized file extension (extension given without the
# leading dot).
_SNIPPET_LANG_BY_EXTENSION = {
    "js": "javascript",
    "cjs": "javascript",
    "mjs": "javascript",
    "java": "java",
    "jav": "java",
    "c": "c_cpp",
    "h": "c_cpp",
    "cpp": "c_cpp",
    "hpp": "c_cpp",
    "c++": "c_cpp",
    "cxx": "c_cpp",
    "mm": "objc++",
    "m": "objc",
    "cs": "csharp",
    "csx": "csharp",
    "linq": "csharp",
    "py": "python",
}


def snippet_lang(filename):
    """Return the code-fence language tag for a source file name.

    The tag is chosen from the text after the last "." in `filename` (or the
    whole name when it contains no dot). The comparison is case-sensitive.

    Unknown extensions yield an empty string on purpose: markdown rendering
    engines may detect the language automatically, and an empty tag leaves
    that auto-detection free to run.
    """
    suffix = filename.rsplit(".", 1)[-1]
    return _SNIPPET_LANG_BY_EXTENSION.get(suffix, "")


def find_scope_end(initial_line, it, open_seq, close_seq):
    """
    Collect lines until the scope that begins at initial_line is closed.

    Occurrences of open_seq raise the nesting depth and occurrences of
    close_seq lower it. Only open_seq is counted on initial_line; every line
    subsequently pulled from iterator "it" is counted for both sequences.
    Collection stops after the first pulled line on which the depth has been
    positive at some point and has dropped to zero or below.

    Returns the list of collected lines, starting with initial_line and ending
    with the closing line. If "it" runs out first, everything consumed so far
    is returned. The iterator is left positioned just after the last
    collected line.
    """
    collected = [initial_line]
    depth = initial_line.count(open_seq)
    scope_opened = False
    for current in it:
        collected.append(current)
        depth += current.count(open_seq)
        # Remember that a scope was entered, even if this same line closes it.
        scope_opened = scope_opened or depth > 0
        depth -= current.count(close_seq)
        if scope_opened and depth <= 0:
            break
    return collected


def fetch_scope(script_content, scopes):
    """Fetch lines enclosed in a (possibly nested) brace-delimited scope.

    Each entry of `scopes` is searched for in turn as a substring of a line,
    always continuing from where the previous match was found, so
    ["Foo", "bar"] locates the first line containing "bar" that comes after
    the first line containing "Foo". The lines of the innermost match and of
    its "{" ... "}" body are returned.

    Args:
        script_content: The source file as a list of lines.
        scopes: Non-empty list of substrings identifying the nested scopes,
            outermost first.

    Returns:
        The list of lines from the line matching the last scope through the
        line that closes its braces.

    Raises:
        Exception: If `scopes` is empty or one of the scopes cannot be found.
    """
    # TODO: support python, html
    if not scopes:
        # Without a scope name there is no start line to hand to
        # find_scope_end; fail with a readable message instead of an
        # AttributeError on None.
        raise Exception("No scope names given to search for")

    it = iter(script_content)
    start_line = None
    for scope in scopes:
        # Consumes the shared iterator up to and including the matching line,
        # so the next scope is only looked for further down the file.
        start_line = next((line for line in it if scope in line), None)
        if start_line is None:
            raise Exception(
                "Failed to find scope {s} of {ss}".format(s=scope, ss=scopes)
            )
    return find_scope_end(start_line, it, "{", "}")


def fetch_section(script_content, sections):
    """Fetch lines enclosed in a section.

    A section is the run of lines between the first line containing the start
    pattern and the line containing the matching end pattern. The two marker
    lines themselves are not returned.

    Args:
        script_content: The source file as a list of lines.
        sections: A two-element sequence (start_pattern, end_pattern).

    Returns:
        The list of lines strictly between the start and end marker lines.

    Raises:
        AssertionError: If the start pattern is a substring of the end
            pattern.
        Exception: If the start marker is not found, or the file ends before
            the section is closed by an end marker.
    """
    start_pattern, end_pattern = sections

    # Raised explicitly (not via `assert`) so the check is still performed
    # when Python runs with -O, which strips assert statements.
    if start_pattern in end_pattern:
        raise AssertionError(
            "Start pattern {s} must not be a substring of end pattern "
            "{e}".format(s=start_pattern, e=end_pattern)
        )

    not_found_message = "Unable to locate section from {start} to {end}".format(
        start=start_pattern, end=end_pattern
    )

    it = iter(script_content)
    ret = None
    for line in it:
        if start_pattern in line:
            ret = find_scope_end(line, it, start_pattern, end_pattern)
            break
    if not ret:
        raise Exception(not_found_message)

    # A properly closed section always ends on a line containing the end
    # pattern. If it does not, the file ran out first, and slicing below
    # would silently drop a real content line.
    if end_pattern not in ret[-1]:
        raise Exception(not_found_message)

    # Cut away start_pattern and end_pattern
    return ret[1:-1]


def run_commands(root, commands):
    """Execute one embed directive and return the extracted source lines.

    `commands` is the directive split on ":": the file path (relative to
    `root`), the command name ("scope" or "section"), and then the arguments
    for that command.

    Raises an Exception for an unknown command name or when the command
    yields no lines.
    """
    relative_path, action = commands[:2]
    arguments = commands[2:]
    source_path = os.path.join(root, relative_path)

    with open(source_path, "r") as source_file:
        source_lines = source_file.read().split("\n")

    if action == "scope":
        extracted = fetch_scope(source_lines, arguments)
    elif action == "section":
        extracted = fetch_section(source_lines, arguments)
    else:
        raise Exception("Unhandled commands {cmd}".format(cmd=action))

    if extracted:
        return extracted
    raise Exception(
        "Did not find class {p} from {f}.".format(p=arguments, f=source_path)
    )


def embed(root, md_content, should_strip=False):
    """Process all embed commands in a markdown file content.

    Every line matching EMBED_RE is treated as an embed directive of the form
    "<file>:<command>:<args...>". If the directive sits on the opening fence
    of a previously generated snippet, that snippet is discarded first. The
    directive is then either re-expanded into a fresh fenced snippet or, when
    stripping, reduced to the bare directive comment.

    Args:
        root: Directory that file paths inside directives are joined to.
        md_content: The markdown file as a list of lines.
        should_strip: When truthy, remove generated snippets and keep only
            the "<!-- embed:... -->" comments.

    Returns:
        The rewritten markdown as a list of lines.
    """
    out_content = []
    it = iter(md_content)
    for line in it:
        embed_re = EMBED_RE.match(line)
        if not embed_re:
            out_content.append(line)
            continue

        directive = embed_re.group(1)
        commands = directive.split(":")
        snippet_start = "```" + snippet_lang(commands[0])
        if line.startswith(snippet_start):
            # Drop the stale snippet body, up to and including its closing
            # fence. The closing fence is matched as a whole line rather than
            # by balancing open/close counts: for files with an unrecognized
            # extension the opening fence is a bare "```", identical to the
            # closing fence, so counting never balances and would swallow the
            # rest of the document.
            for stale_line in it:
                if stale_line.strip() == "```":
                    break

        if should_strip:
            out_content.append("<!-- embed:{g} -->".format(g=directive))
        else:
            out_content.append(
                "{ss} <!-- embed:{g} -->".format(ss=snippet_start, g=directive)
            )
            out_content.extend(run_commands(root, commands))
            out_content.append("```")
    return out_content


def process_docs(root, docs, strip):
    """Rewrite, in place, every markdown file named in `docs`.

    Names that do not end in ".md" are ignored. Each remaining file under
    `root` is read, passed through embed() with the given `strip` flag, and
    overwritten with the result.
    """
    markdown_names = (name for name in docs if name.endswith(".md"))
    for name in markdown_names:
        with open(os.path.join(root, name), "r+") as md_file:
            original_lines = md_file.read().split("\n")
            new_lines = embed(root, original_lines, strip)
            # Rewind and clear the file before writing the new content.
            md_file.seek(0)
            md_file.truncate()
            md_file.write("\n".join(new_lines))


def process_directory():
    """
    Walk the current working directory and process the markdown files in it.

    The first command-line argument, when present, is passed through as the
    strip flag; otherwise the flag is False. Directories named "node_modules"
    are pruned from the walk.
    """
    cli_args = sys.argv[1:]
    strip = cli_args[0] if cli_args else False

    for root, dirs, files in os.walk(os.getcwd()):
        # Removing the entry in place stops os.walk from descending into it.
        try:
            dirs.remove("node_modules")
        except ValueError:
            pass
        process_docs(root, files, strip)


if __name__ == "__main__":
    # Use sys.exit rather than the bare `exit` helper: the latter is injected
    # by the `site` module for interactive use and is not available when the
    # interpreter is started without it (e.g. `python -S`).
    sys.exit(process_directory())
