#!/usr/bin/env python3

from argparse import ArgumentParser, Namespace
from locale import strxfrm
from shutil import get_terminal_size
from signal import SIG_DFL, SIGPIPE, signal
from sys import stderr, stdin, stdout
from typing import Any, Callable

from yaml import (
    BaseDumper,
    SafeDumper,
    YAMLError,
    add_representer,
    safe_dump_all,
    safe_load_all,
)
from yaml.nodes import ScalarNode
from yaml.scanner import ScannerError


def _key_order(key: Any) -> tuple:
    """Return a total sort key for a mapping key of any YAML-loadable type."""
    # Strings first, in strxfrm (collation) order -- same order as before for
    # all-string mappings.
    if isinstance(key, str):
        return (0, strxfrm(key))
    # Then numbers (bool is an int subclass), compared numerically.
    if isinstance(key, (int, float)):
        return (1, key)
    # Then anything else (None, dates, ...) by its text form.
    return (2, strxfrm(str(key)))


def recur_sort(data: Any) -> Any:
    """Return *data* with the keys of every nested dict sorted.

    Dicts and lists are rebuilt recursively; any other value is returned as-is.
    Keys are ordered by ``_key_order``: string keys first via
    ``locale.strxfrm``, then numeric keys numerically, then remaining keys by
    ``str()``. Non-string keys need this explicit handling because
    ``strxfrm`` raises ``TypeError`` on them, and YAML mappings may use ints,
    null, dates, etc. as keys.

    NOTE(review): ``strxfrm`` follows the process's LC_COLLATE; nothing visible
    here calls ``locale.setlocale``, so this is presumably the C locale
    (code-point order) -- confirm that is intended.
    """
    if type(data) is dict:
        return {k: recur_sort(data[k]) for k in sorted(data, key=_key_order)}
    if type(data) is list:
        return [recur_sort(el) for el in data]
    return data


def parse_args() -> Namespace:
    """Parse the command-line options.

    ``-i/--indent`` is the YAML indentation step (default 2). ``-w/--width`` is
    the preferred output line width; its default is the column count reported
    by ``shutil.get_terminal_size``, which falls back to 80.
    """
    parser = ArgumentParser()
    terminal_cols = get_terminal_size((80, -1)).columns
    parser.add_argument("-i", "--indent", type=int, default=2)
    parser.add_argument("-w", "--width", type=int, default=terminal_cols)
    return parser.parse_args()


def repr_str(break_pt: int) -> Callable[[BaseDumper, str], ScalarNode]:
    """Build a PyYAML ``str`` representer that folds long strings.

    The returned callable requests folded block style (``>``) for strings
    longer than *break_pt* characters and passes an empty style for shorter
    ones, so the dumper picks its usual style for them.
    """

    def _represent(dumper: BaseDumper, data: str) -> ScalarNode:
        if len(data) <= break_pt:
            chosen = ""
        else:
            chosen = ">"
        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style=chosen)

    return _represent


def main() -> None:
    """Read YAML documents from stdin, sort their mapping keys, dump to stdout.

    Every document in the stream is parsed with ``safe_load_all``, passed
    through ``recur_sort``, and re-emitted with an explicit ``---`` start
    marker using the ``--indent`` and ``--width`` options. Strings longer than
    half the width are requested in folded (``>``) style.

    Exits with status 1 after echoing the unparsable input (and the parser's
    error message) to stderr when stdin is not valid YAML.
    """
    args = parse_args()
    data = stdin.read()

    # Registered on the SafeDumper class itself, so it applies to every
    # safe_dump* call made later in this process, not just the one below.
    fold_pt = args.width // 2
    add_representer(str, repr_str(fold_pt), Dumper=SafeDumper)

    try:
        docs = [*safe_load_all(data)]
    except YAMLError as err:
        # Catch the common base class: ScannerError is only one of several
        # parse failures (ParserError, ComposerError, ConstructorError,
        # ReaderError); the others would otherwise escape as tracebacks.
        print(f"ERROR! -- Failed to parse:\n\n{data}", file=stderr)
        print(err, file=stderr)
        # SystemExit rather than the site-provided exit(), which is meant for
        # the interactive shell and is missing under `python -S`.
        raise SystemExit(1)

    safe_dump_all(
        recur_sort(docs),
        stdout,
        allow_unicode=True,
        explicit_start=True,
        width=args.width,
        indent=args.indent,
    )


if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse recur_sort) does
    # not start reading stdin.
    try:
        # Restore the default SIGPIPE action so that writing into a closed
        # pipe (e.g. `... | head`) terminates quietly instead of raising
        # BrokenPipeError.
        signal(SIGPIPE, SIG_DFL)
        main()
    except KeyboardInterrupt:
        # 130 = 128 + SIGINT, the conventional shell status for Ctrl-C.
        raise SystemExit(130)
