#!/usr/bin/python
import re
import argparse
from loguru import logger

def debug(message):
    """Emit ``message`` at DEBUG level, but only when ``--debug`` was passed."""
    if not args.debug:
        return
    logger.debug(message)

def info(message): 
    """Emit ``message`` at INFO level through the loguru ``logger``."""
    logger.info(message)


# Command-line interface.
parser = argparse.ArgumentParser()
# Table file to read (and to overwrite when -i is given).
parser.add_argument("file")
# File whose non-blank lines are "<pattern> <replacement>" pairs.
parser.add_argument("-v", "--varfile")
parser.add_argument("--nostata", action="store_true")
# Write the result back to ``file`` instead of printing it.
parser.add_argument("-i", action="store_true")
# Whitespace-separated patterns identifying control rows.
parser.add_argument("-c")
# Comma-separated ranges, each emitted as \cline{...} under the header.
parser.add_argument("--cline")
parser.add_argument("--noheader", action="store_true")
parser.add_argument("--condensed", action="store_true")
parser.add_argument("--debug", action="store_true")

args = parser.parse_args()

# Always expose the -c patterns as a list, empty when the option is absent.
args_c = args.c.split() if args.c else []

# Literal (old, new) substitutions applied, in order, to every kept row.
# "Adjusted R-squared" must come before "R-squared": the latter is a
# substring of the former and would otherwise be rewritten first.
changes = [
    ("Observations", "Obs."),
    ("Adjusted R-squared", "Adj. $R^2$"),
    ("R-squared", "$R^2$"),
    #  ('VARIABLES', ''),
]

# Optional renaming table: every non-blank line of the varfile is split on
# its first space into [pattern, replacement].  ``var`` stays None when no
# varfile was requested or when it cannot be read.
var = None
if args.varfile:
    try:
        with open(args.varfile, "r") as v:
            var = [x.strip().split(" ", 1) for x in v if x.strip()]
    except (OSError, UnicodeError):
        # Best effort: a missing or unreadable varfile only disables renaming.
        info("cannot open varfile ...")

# Raw lines of the table to process.
with open(args.file, "r") as f:
    c = f.readlines()


def varname(row, mapping=None):
    r"""Apply the variable-renaming table to ``row`` and return the result.

    Each entry is a ``(pattern, replacement)`` pair.  ``pattern`` is used as a
    regular expression wrapped in ``\b`` word boundaries; if it does not
    compile, a plain substring replacement is used instead.

    The replacement is always inserted verbatim.  Passing it to ``re.sub`` as
    a template string would interpret backslash escapes, turning LaTeX such as
    ``$\beta$`` into a backspace character followed by ``eta``.

    Args:
        row: The text to rewrite.
        mapping: Iterable of ``(pattern, replacement)`` pairs.  Defaults to
            the module-level ``var`` table.

    Returns:
        ``row`` with all replacements applied, or ``row`` unchanged when the
        table is empty or missing.
    """
    if mapping is None:
        mapping = var
    if not mapping:
        return row
    for x, y in mapping:
        try:
            # A callable replacement bypasses template escape processing.
            row = re.sub(r"\b" + x + r"\b", lambda _match, _text=y: _text, row)
        except re.error:
            # ``x`` is not a valid regular expression: treat it literally.
            row = row.replace(x, y)
    return row


def formatrow(row):
    r"""Rewrite the cell separators and significance stars of one table row.

    The substitutions run in a fixed order and each one works on the output of
    the previous one, so the sequence must not be reordered.

    Args:
        row: One line of the table.

    Returns:
        The rewritten line.
    """
    # Split stars off the digits they follow: "0.12***" -> "0.12&***".
    row = re.sub(r"(\d+)(\*+)", r"\1&\2", row)
    # Double every "&" (swallowing whitespace before it) unless it is preceded
    # by "*" or a backslash, or is followed by a star.
    row = re.sub(r"(?<![\*\\])\s*&(?!\s*\*)", r"&&", row)
    # Put "&" in front of a "\\" that follows whitespace, except when that
    # whitespace itself follows a star or a word character.
    row = re.sub(r"(?<![\*\w])\s+\\\\", r"&\\\\", row)
    # Drop the whitespace and "&" that directly follow a star.
    row = re.sub(r"\*\s+&", "*", row)
    # Undo the doubling for the first "&&" only.  ``count`` is passed by
    # keyword because the positional form is deprecated since Python 3.13.
    row = re.sub(r"&&", r"&", row, count=1)
    # adjust star styles: wrap each run of up to three stars as a superscript
    row = re.sub(r"(\*{1,3})", r"$^{\1}$", row)
    return row


def repeat_title(row):
    r"""Split a header row whose cells contain "first;second" into two rows.

    Args:
        row: Header line with cells separated by "&".

    Returns:
        A ``(row1, row2)`` tuple.  ``row1`` is ``row`` with each "a;b" span
        replaced by "a".  ``row2`` is built from only the cells that contain
        ";" or are blank, with each "a;b" span replaced by "b".  When at least
        one span is found, ``\hline`` is removed from both rows and ``\\``
        from ``row2``.
    """
    row1 = row
    row2 = "&".join(
        cell for cell in row.split("&")
        if ";" in cell or re.search(r"^\s*$", cell))
    for x in re.findall(r"[\w\(\)\.\s\+\-]+;[\w\(\)\.\s\+\-]+", row):
        # The character classes exclude ";", so each match has exactly one.
        first, second = x.split(";")
        row1 = row1.replace(x, first).replace(r"\hline", "")
        row2 = row2.replace(x, second).replace(r"\hline", "").replace(r"\\", "")
    return row1, row2


def process_column_title(l):
    r"""Build the grouped header and its ``\cline`` rules from a title row.

    Cells are numbered from 1 and the first cell is ignored.  Each remaining
    cell is reduced to its words; a non-empty cell with the same words as the
    previous non-empty cell extends that group, otherwise it starts a new one.

    Args:
        l: Header line with cells separated by "&".

    Returns:
        ``(None, "delete_header")`` when exactly one group was found;
        otherwise ``(title_out, cline_out)``, where ``title_out`` holds one
        ``\multicolumn`` entry per group and ``cline_out`` the matching
        ``\cline`` commands.
    """
    cells = l.replace(r"\\", "").replace(r"hline", "").split("&")
    ctitles = []
    # Flat list of [start, end, start, end, ...] cell indices, one pair per group.
    cline = []
    current = " "
    for i, cell in enumerate(cells, start=1):
        if i == 1:
            continue
        words = " ".join(re.findall(r"\w+", cell)).strip()
        if not words:
            continue
        debug(f"clines ===={words}")
        if words == current:
            # Same title as the previous non-empty cell: extend that group.
            cline[-1] = i
        else:
            cline.extend((i, i))
            ctitles.append(cells[i - 1])
        current = words
    if len(cline) == 2:
        return None,"delete_header"
    debug(f"cline ----- {cline}")
    cline_out = ""
    title_out = ""
    for ctitle, c1, c2 in zip(ctitles, cline[0::2], cline[1::2]):
        # A group made of a single cell gets a rule that extends one cell further.
        rule_end = c2 if c1 != c2 else c2 + 1
        cline_out += r"\cline{%s-%s}" % (c1, rule_end)
        title_out += r"\multicolumn{%s}{c}{%s} & " % (c2 - c1 + 2, ctitle)
    debug(f"{title_out} ======= {cline_out}")
    return title_out, cline_out


# A row containing any of these substrings is dropped from the table body.
# NOTE(review): "1o" and "0b" look like Stata omitted/base-level prefixes -
# confirm against real input.
_SKIP_MARKERS = (
    "1o",
    "0b",
    "tabular",
    "{document}",
    "documentclass",
    "setlength",
    "(.)",
    "& . &",
    "& . \\",
    "& (.) &",
    "& - &",
)

output = []
ctrls = []
for _n, row in enumerate(c):
    for ctrl in args_c:
        if re.search(r"\b" + ctrl + r"\b", row):
            ctrls.append(formatrow(varname(row)))
            # The line right after a control row moves with it.  Deleting it
            # from ``c`` during iteration makes the outer loop skip it.  The
            # bounds check covers a control row that is the last line.
            if _n + 1 < len(c):
                ctrls.append(formatrow(c[_n + 1]))
                del c[_n + 1]
            break
    else:
        # Not a control row: rename, relabel, then filter.
        therow = varname(row).strip()
        for x, y in changes:
            therow = therow.replace(x, y)
        if not args.nostata and "multicolumn" in therow:
            continue
        if any(marker in therow for marker in _SKIP_MARKERS):
            continue
        # Rows made only of "&" separators followed by "\\".
        if re.search(r"^(&\s*)+\\\\$", therow):
            continue
        output.append(formatrow(therow))
        # NOTE: no '_add_empty_line123' marker row is emitted after the
        # 'VARIABLES' row here.

main = []
annotations = []
buffers = []
# Number of "... FE" rows stored so far; they are kept at the front of
# ``buffers`` in the order encountered.
FE = 0
# Column-number row ("& (1) ..."); stays empty when none is seen, for example
# under --nostata, so the header branch below never hits an undefined name.
col_num = ""
if output:
    output[-1] = output[-1].replace(r"\hline", "")

for row in output:
    r = row.strip()
    if ("& (1)" in row) and (not args.nostata):
        # Held back and re-emitted below the rebuilt header.
        col_num = row
        continue
    if "VARIABLES" in row:
        if args.noheader:
            continue
        r = r.replace("VARIABLES", "").replace(r"\hline", "")
        the_title, the_cline = process_column_title(r)
        keep_header = the_cline != "delete_header"
        if keep_header:
            if ";" in r:
                # "first;second" titles are spread over two header rows.
                row1, row2 = repeat_title(r)
                main.append(row1)
                main.append(row2 + r"\\")
            else:
                main.append("&" + the_title + r"\\")
        if args.cline:
            # User-supplied rules override the computed ones.
            main.append("".join(
                r"\cline{%s}" % spec for spec in args.cline.split(",")))
        else:
            main.append(the_cline if keep_header else "")
        main.append(col_num)
        main.append(r"\hline")
        continue
    if not args.nostata:
        # Summary statistics are collected and emitted after everything else.
        if (re.search(r"^Obs", r) or re.search(r"^\$R\^2", r)
                or re.search(r"Adj. \$R\^2\$", r)):
            buffers.append(row)
            continue
        if re.search(r"^[^&]* FE", r):
            buffers.insert(FE, row)
            FE += 1
            continue
    # Once the body has started, a row with no digit and no "hline" that is
    # not just backslashes is treated as an annotation.
    if ((len(main) > 0) and (not re.search(r"\d+", r)) and ("hline" not in r)
            and (not re.search(r"^\\+$", r))):
        annotations.append(r)
    else:
        main.append(r)

# Dump the three row groups through debug(), one message per row.
for _label, _rows in (
    ("buffers", buffers),
    ("annotations", annotations),
    ("main rows", main),
):
    for x in _rows:
        debug(f"{_label}: {x}")

def clean_line(l, nostata=None):
    r"""Normalise one output line.

    The line is stripped of surrounding whitespace.  A line made only of
    backslashes becomes ``\\`` under ``nostata`` and an empty string
    otherwise.  When "hline" occurs more than once, every ``\hline\\``
    sequence is removed.

    The backslash-only pattern was previously written without the raw-string
    prefix, so it matched a literal "+" and never a backslash-only line.

    Args:
        l: The line to clean.
        nostata: Overrides ``args.nostata`` when given.  The global is only
            consulted for backslash-only lines.

    Returns:
        The cleaned line.
    """
    l = l.strip()
    if re.search(r"^\\+$", l):
        if nostata is None:
            nostata = args.nostata
        return r"\\" if nostata else ""
    if l.count("hline") > 1:
        l = l.replace(r"\hline\\", "")
    return l


def condenser(row):
    """Report whether ``row`` is blank once every "//" is removed.

    Returns:
        True when nothing but whitespace remains, otherwise None.
    """
    remainder = row.replace(r"//", "").strip()
    return None if remainder else True


# Assemble the output: [spacer] main rows, control rows, [spacer],
# annotations, summary rows.  Rows that condenser() reports as blank are
# dropped and the rest go through clean_line().
final = []

# With --nostata the spacer row goes first; otherwise it follows the controls.
if args.nostata and not args.condensed:
    final.append("\\\\\n")
for i, row in enumerate(main):
    if condenser(row):
        continue
    # Skip a placeholder row that directly follows a rule.
    if i > 0 and r"\hline" in main[i - 1] and "_add_empty_line123" in row:
        continue
    final.append(clean_line(row))
final.extend(clean_line(row) for row in ctrls if not condenser(row))

if not args.nostata and not args.condensed:
    final.append("\\\\\n")
final.extend(clean_line(row) for row in annotations if not condenser(row))
final.extend(clean_line(row) for row in buffers if not condenser(row))


# The last line must not end the row or draw a rule.  Only the "\\" row
# terminator and "\hline" are removed; stripping every single backslash would
# also mangle other LaTeX commands and leave a bare "hline" behind.  This is
# done before the input file is opened for writing, so a failure here cannot
# truncate it.
if final:
    final[-1] = final[-1].replace(r"\\", "").replace(r"\hline", "")

# With -i the result overwrites the input file; otherwise print() falls back
# to standard output because ``file=None``.
f = open(args.file, "w") if args.i else None
try:
    for l in final:
        print(l, file=f)
finally:
    if f is not None:
        f.close()

