#!/usr/bin/env python3

from argparse import ArgumentParser, Namespace
from os import devnull, dup2
from shutil import get_terminal_size
from sys import exit, stderr, stdin, stdout
from typing import Callable

from yaml import (
    BaseDumper,
    SafeDumper,
    YAMLError,
    add_representer,
    safe_dump_all,
    safe_load_all,
)
from yaml.nodes import ScalarNode
from yaml.scanner import ScannerError

from lib import recur_sort, sig_trap


def parse_args() -> Namespace:
    """Parse the command-line options.

    ``-i/--indent`` (int, default 2) sets the YAML indentation step.
    ``-w/--width`` (int) sets the preferred output line width; it defaults to
    the width reported by ``shutil.get_terminal_size``, or 80 columns when
    that cannot be determined.
    """
    term_width = get_terminal_size((80, -1)).columns
    cli = ArgumentParser()
    cli.add_argument("-i", "--indent", default=2, type=int)
    cli.add_argument("-w", "--width", default=term_width, type=int)
    return cli.parse_args()


def repr_str(break_pt: int) -> Callable[[BaseDumper, str], ScalarNode]:
    """Build a PyYAML ``str`` representer that folds long strings.

    The returned callable tags every string as ``tag:yaml.org,2002:str``.
    Strings longer than *break_pt* characters request the folded block style
    (``>``). Shorter strings pass an empty style so the emitter chooses one
    itself. The emitter may still fall back to a quoted style where block
    scalars are not allowed, e.g. in mapping keys.
    """

    def _represent(dumper: BaseDumper, data: str) -> ScalarNode:
        if len(data) > break_pt:
            style = ">"
        else:
            style = ""
        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style=style)

    return _represent


def main() -> None:
    """Read every YAML document from stdin and re-emit them on stdout.

    The loaded documents are passed through ``recur_sort`` from the project's
    ``lib``. It presumably sorts mapping keys recursively; confirm in ``lib``.
    The result is dumped with ``safe_dump_all`` using the ``--width`` and
    ``--indent`` options, and each document gets an explicit ``---`` start.
    Strings longer than half the width are requested in folded (``>``) style.

    On malformed input this prints "parse error!" plus PyYAML's diagnostic to
    stderr and exits with status 1.
    """
    args = parse_args()

    # Strings longer than half the output width are emitted folded.
    fold_pt = args.width // 2
    # NOTE: this registers on the SafeDumper class itself, so it affects
    # every safe_dump* call in this process, not only the one below.
    add_representer(str, repr_str(fold_pt), Dumper=SafeDumper)

    try:
        # Load all documents up front so any parse error surfaces here,
        # before anything has been written to stdout.
        docs = [*safe_load_all(stdin)]
    except YAMLError as err:
        # YAMLError is the base of ScannerError, ParserError, ComposerError
        # (e.g. undefined alias), ConstructorError (e.g. unknown tag under
        # safe_load) and ReaderError. Catching only ScannerError let those
        # escape as raw tracebacks.
        print("parse error!", err, sep="\n", file=stderr)
        exit(1)

    safe_dump_all(
        recur_sort(docs),
        stdout,
        allow_unicode=True,
        explicit_start=True,
        width=args.width,
        indent=args.indent,
    )


# Script entry: install the project's signal handling, run, and map
# interruptions to exit statuses (130 for Ctrl-C, 13 for a closed output pipe).
try:
    sig_trap()
    main()
except KeyboardInterrupt:
    exit(130)
except BrokenPipeError:
    # The reader of stdout went away. Python flushes stdout once more at
    # interpreter shutdown; that flush would raise a second BrokenPipeError
    # and print an "Exception ignored" warning on stderr. Point stdout's file
    # descriptor at devnull so the final flush succeeds silently (recipe from
    # the Python `signal` docs, "Note on SIGPIPE").
    try:
        with open(devnull, "w") as sink:
            dup2(sink.fileno(), stdout.fileno())
    except (OSError, ValueError):
        # Best effort only: stdout may have no real descriptor or may
        # already be closed; exit with the intended status regardless.
        pass
    exit(13)
