#!/usr/bin/env python3
# @Author: carlosgilgonzalez
# @Date:   2019-03-22T15:57:34+00:00
# @Last modified by:   carlosgilgonzalez
# @Last modified time: 2019-08-16T00:58:33+01:00

#
# complete rewrite of console webrepl client from aivarannamaa:
# https://forum.micropython.org/viewtopic.php?f=2&t=3124&p=29865#p29865
#
import sys
import readline
import getpass
import websocket
import argparse
import time

try:
    import thread
except ImportError:
    import _thread as thread
from time import sleep

try:                   # from https://stackoverflow.com/a/7321970
    input = raw_input  # Fix Python 2.x.
except NameError:
    pass


def help(rc=0):
    """Print the usage text to stdout, then exit the process.

    rc -- exit status handed to sys.exit (default 0).

    NOTE(review): the text advertises a positional <host>, while the argument
    parser in this file takes the host through -t and also requires -f;
    the wording is kept as-is here.
    """
    # Program name = last path component of argv[0] (split on "/").
    exename = sys.argv[0].rsplit("/", 1)[-1]
    usage_lines = [
        "%s - remote shell using MicroPython WebREPL protocol" % exename,
        "Arguments:",
        "  [-p password] [-dbg] [-r] <host> - remote shell (to <host>:8266)",
        "Examples:",
        "  %s 192.168.4.1" % exename,
        "  %s -p abcd 192.168.4.1" % exename,
        "  %s -p abcd -r 192.168.4.1 < <(sleep 1 && echo \"...\")" % exename,
        "Special command control sequences:",
        "  line with single characters",
        "    'A' .. 'E' - use when CTRL-A .. CTRL-E needed",
        '  just "exit" - end shell',
    ]
    for usage_line in usage_lines:
        print(usage_line)
    sys.exit(rc)


# ---- Session state shared by the websocket callbacks below ----
inp = ""                 # text last sent; on_message strips its echo
prompt = "Password: "    # prompt string the sender loop waits for
passwd = None            # overwritten from -p once arguments are parsed
normal_mode = True
raw_mode = paste_mode = prompt_seen = False
debug = redirect = response_snd = EOF_SIG = False

# ---- Command-line options ----
# -dbg and -r are value-taking options (no store_true action), so they are
# enabled by passing any non-empty value.
parser = argparse.ArgumentParser()
parser.add_argument("-c", required=False, help='command to send')
parser.add_argument("-p", required=True, help='host password')
parser.add_argument("-dbg", required=False, help='debug mode')
parser.add_argument("-t", required=True, help='host direction')
parser.add_argument("-r", required=False, help='redirect')
parser.add_argument("-f", required=True, help='file to debug')
parser.add_argument(
    "-wt", type=float, default=0.4, required=False,
    help='timeout in seconds to wait for an answer')
args = parser.parse_args()

# Copy the parsed options into the module-level names used by the callbacks.
passwd, debug, redirect = args.p, args.dbg, args.r
s_command, time_out = args.c, args.wt


def uparser_dec_script(long_command):
    """Convert script text into the '\\r'-separated line stream that is sent.

    The first '\\n'-separated line of long_command is skipped and empty lines
    are ignored. For every other line:

    * its indent level is the number of non-overlapping 4-space runs found
      anywhere in the line (str.count), not only at the start;
    * every emitted line is stripped of surrounding whitespace (presumably
      because the target REPL auto-indents -- TODO confirm);
    * a line that dedents to a non-zero level is emitted with one '\\b' per
      level dropped in front of it; a line that dedents to level 0 is not
      emitted at all;
    * a line ending in ',' is buffered and concatenated with the following
      lines until one that does not end in ','.

    Returns the emitted lines joined with '\\r' followed by '\\r\\r'.
    """
    emitted = []
    prev_level = 0          # level of the previous non-empty line
    pending = ''            # buffered text of an unfinished (',') line
    continuing = False      # True while inside a ','-continued line

    for raw in long_command.split('\n')[1:]:
        if raw == '':
            continue

        level_before = prev_level
        # Continuation lines keep the level of the line that opened them.
        if not continuing:
            level = raw.count('    ')
        prev_level = level
        text = raw.strip()

        if level_before > level:
            # Dedent: backspace once per level dropped; level 0 is skipped.
            if level > 0:
                emitted.append('\b' * (level_before - level) + text)
            continue

        if raw[-1] == ',':
            continuing = True
            pending += text
            continue

        if pending == '':
            emitted.append(text)
        elif continuing:
            emitted.append(pending + text)
            pending = ''
            continuing = False

    return '\r'.join(emitted) + '\r\r'


def uparser_dec_script_p(long_command):
    """Convert script text into the '\\r'-separated lines shown at the prompt.

    Same walk as uparser_dec_script, but lines keep their original
    indentation and no '\\b' characters are inserted. The first
    '\\n'-separated line of long_command is skipped and empty lines are
    ignored. For every other line:

    * its indent level is the number of non-overlapping 4-space runs found
      anywhere in the line (str.count);
    * a line that dedents to a non-zero level is emitted unchanged; a line
      that dedents to level 0 is not emitted;
    * a line ending in ',' is buffered (the first piece verbatim, later
      pieces stripped) and joined with the stripped closing line.

    Returns the emitted lines joined with '\\r', followed by one '\\r' per
    indent level of the last non-empty line. Returns '' when there is no
    non-empty line after the first one.
    """
    lines_cmd = []
    level_prev = 0
    # Bound before the loop: the return statement reads it even when the
    # loop body never runs (empty / blank-only input), which used to raise
    # UnboundLocalError.
    line_now = 0
    buffer_line = ''
    previous_incomplete = False

    for line in long_command.split('\n')[1:]:
        if line == '':
            continue

        line_before = level_prev
        # Continuation lines keep the level of the line that opened them.
        if not previous_incomplete:
            line_now = line.count('    ')
        level_prev = line_now

        if line_before > line_now:
            # Dedent: keep the line only when it is still inside a block.
            if line_now > 0:
                lines_cmd.append(line)
        elif line[-1] == ',':
            buffer_line += line.strip() if previous_incomplete else line
            previous_incomplete = True
        elif buffer_line == '':
            lines_cmd.append(line)
        elif previous_incomplete:
            lines_cmd.append(buffer_line + line.strip())
            buffer_line = ''
            previous_incomplete = False

    return '\r'.join(lines_cmd) + '\r' * line_now

# LOAD FILE:
# Read the script named by -f in 2000-byte chunks and append it to a
# two-newline prefix (both parsers skip the first '\n'-separated line).
string_bfile = b'\n\n'
buff = bytearray(2000)
with open(args.f, 'rb') as bupyfile:
    while True:
        try:
            # Slice assignment resizes buff to the number of bytes read.
            buff[:] = bupyfile.read(2000)
        except Exception as e:
            # Abort on a read error: retrying the same read inside this loop
            # can spin forever, and a partially read script must not be sent.
            print(e)
            sys.exit(1)
        if buff == b'':
            print('EOF')
            break
        string_bfile += buff

# Decode once and build the two parallel line tables: cmd_lines is what gets
# sent, cmd_p_lines is what gets shown at the prompt.
script_text = string_bfile.decode()
cmd_lines = uparser_dec_script(script_text).split('\r')
cmd_p_lines = uparser_dec_script_p(script_text).split('\r')
# Both tables start with a '#' run as long as the password
# (presumably a placeholder entry for the login step -- TODO confirm).
cmd_lines = ['#'*len(passwd)] + cmd_lines
cmd_p_lines = ['#'*len(passwd)] + cmd_p_lines
cmd_index = 0      # index of the next entry of cmd_lines to send
DEF_END = False    # output-suppression flags read by on_message
INDENT_L = False
def on_message(ws, message):
    """Process one text message received from the websocket.

    ws      -- the websocket app object (unused here).
    message -- received text.

    Steps, in order:
    1. Strip from the front of `message` whatever matches the front of the
       global `inp` (the text last sent), consuming both.
    2. In raw mode, normalise the "OK" / CTRL-D acknowledgement variants.
    3. Detect the prompt at the end of the message and set `prompt_seen`.
    4. Write the message to stdout: without the trailing prompt when a prompt
       has been seen, otherwise in full unless INDENT_L or DEF_END is set.

    Reads the globals raw_mode, normal_mode, paste_mode, debug, command_snd,
    EOF_SIG, INDENT_L and DEF_END; updates inp, prompt, prompt_seen and
    response_snd. Calls sys.exit() once a prompt is seen while EOF_SIG is set.
    """
    global inp
    global prompt
    global prompt_seen
    global response_snd

    # A lone control character (code <= 5) is replaced by the text expected
    # back for it: CTRL-D stays as is, anything else becomes "\r\n".
    if len(inp) == 1 and ord(inp[0]) <= 5:
        inp = "\r\n" if inp != '\x04' else "\x04"

    # Consume the common prefix of the sent text and the received text.
    while inp != "" and message != "" and inp[0] == message[0]:
        inp = inp[1:]
        message = message[1:]

    # Anything left in the message means the rest of `inp` is no longer
    # expected, except for a pending CTRL-D while in raw mode.
    if message != "" and (not raw_mode or inp != "\x04"):
        inp = ""

    if raw_mode:
        if message == "OK":
            inp = "\x04\x04"
        elif message == "OK\x04":
            message = "OK"
            inp = "\x04"
        elif message == "OK\x04\x04":
            message = "OK"
            inp = ""
        elif message == "OK\x04\x04>":
            message = "OK>"
            inp = ""

    if debug:
        # `message` can be empty here (fully consumed above); indexing it
        # unconditionally raised IndexError, so -1 is printed in that case.
        first_code = ord(message[0]) if message else -1
        print("[%s,%d,%s]" % (message, first_code, inp))

    if inp == '' and prompt != '':
        if message.endswith(prompt):
            prompt_seen = True
        elif normal_mode:
            if message.endswith("... "):
                prompt = "    "
            elif message.endswith(">>> "):
                prompt = ">>> "
                prompt_seen = True

    if prompt_seen:
        # Drop the trailing prompt. (With an empty prompt this slice is
        # message[:0], i.e. nothing is written.)
        sys.stdout.write(message[:-len(prompt)])
        if command_snd:
            response_snd = True
        if EOF_SIG:
            print('EOF_SIG')
            sys.exit()
    elif not INDENT_L and not DEF_END:
        sys.stdout.write(message)

    if paste_mode and message == "=== ":
        inp = "\n"


def on_error(ws, error):
    """Report a websocket error on stdout.

    ws    -- the websocket app object (unused here).
    error -- the error to report; may be a string or an exception object.

    %-formatting is used instead of string concatenation so that an
    exception object is rendered via str() rather than raising TypeError.
    """
    sys.stdout.write("### error(%s) ###\n" % (error,))
    sys.stdout.flush()


def on_close(ws, close_status_code=None, close_msg=None):
    """Announce that the connection closed, close `ws` and exit with status 1.

    ws                -- object whose close() method is called.
    close_status_code -- optional, ignored.
    close_msg         -- optional, ignored.

    The two optional parameters let the callback be invoked either as
    on_close(ws) or as on_close(ws, close_status_code, close_msg); the
    websocket-client library uses the three-argument form in newer releases.
    """
    sys.stdout.write("### closed ###\n")
    sys.stdout.flush()
    ws.close()
    sys.exit(1)


def on_open(ws):
    """Start the sender loop in a background thread once the socket is open.

    ws -- the websocket app object; its `sock` attribute and `send` method
          are used by the loop.

    The loop (`run`) repeatedly waits until on_message reports the expected
    prompt, picks the next thing to send (the password first, then the
    entries of cmd_lines one by one), translates single letters 'A'..'E'
    into CTRL-A..CTRL-E, and sends the result. It communicates with
    on_message exclusively through module-level globals.
    """
    def run(*args):
        global input
        global inp
        global raw_mode
        global normal_mode
        global paste_mode
        global prompt
        global prompt_seen
        global command_snd
        global s_command
        global running
        global response_snd
        global time_out
        global EOF_SIG
        global cmd_lines
        global cmd_index
        global DEF_END
        global INDENT_L
        running = True
        command_snd = False
        injected = False
        do_input = getpass.getpass
        while running:
            while ws.sock and ws.sock.connected:
                # Poll until on_message flags that the prompt has arrived
                # (skipped entirely while no prompt is expected).
                while prompt and not prompt_seen:
                    sleep(0.1)
                    if debug:
                        sys.stdout.write(":"+prompt+";")
                        sys.stdout.flush()
                prompt_seen = False

                if prompt == "Password: " and passwd is not None:
                    # Login step: answer the password prompt automatically.
                    inp = passwd
                    sys.stdout.write("Password: ")
                    sys.stdout.flush()
                else:
                    if cmd_index == len(cmd_lines):
                        EOF_SIG = True
                    try:
                        if prompt == "Password: ":
                            time.sleep(0.5)
                        else:
                            # Show the upcoming line at the prompt and wait
                            # for the user; each '\x08' is widened to four.
                            inp = do_input(prompt+cmd_p_lines[cmd_index].replace('\x08', '\x08\x08\x08\x08'))
                    except Exception:
                        # E.g. cmd_index past the end of cmd_p_lines.
                        EOF_SIG = True
                    command_snd = False
                    if inp == 'C':
                        # User typed 'C': send CTRL-C and flag end of run.
                        ws.send('\x03')
                        EOF_SIG = True
                    if not command_snd:
                        try:
                            # Whatever was typed is replaced by the next
                            # script line; an empty entry is sent as '\r'.
                            inp = cmd_lines[cmd_index]
                            if cmd_lines[cmd_index] == '':
                                inp = '\r'
                            # INDENT_L: displayed line contains a 4-space run
                            # and is not equal to the last displayed entry.
                            if '    ' in cmd_p_lines[cmd_index]:
                                INDENT_L = True
                                if cmd_p_lines[cmd_index] == cmd_p_lines[-1]:
                                    INDENT_L = False
                            else:
                                INDENT_L = False
                            # DEF_END: line starts with a backspace or ends
                            # with ':'.
                            if cmd_lines[cmd_index].startswith('\x08') or cmd_lines[cmd_index].endswith(':'):
                                DEF_END = True
                            else:
                                DEF_END = False
                        except Exception:
                            EOF_SIG = True
                        command_snd = True
                        cmd_index += 1
                        if cmd_index == len(cmd_lines):
                            EOF_SIG = True

                    if redirect:
                        sys.stdout.write(inp+"\n")
                        sys.stdout.flush()

                if len(inp) != 1 or inp[0] < 'A' or inp[0] > 'E':
                    inp += "\r\n"
                else:
                    # Single letter 'A'..'E' -> control character 1..5, and
                    # track the REPL mode that control character selects.
                    inp = chr(ord(inp[0])-64)
                    if raw_mode:
                        if inp[0] == '\x02':
                            normal_mode = True
                            raw_mode = False
                    elif normal_mode:
                        if inp[0] == '\x01':
                            raw_mode = True
                            normal_mode = False
                        elif inp[0] == '\x05':
                            paste_mode = True
                            normal_mode = False
                    else:
                        if inp[0] == '\x03' or inp[0] == '\x04':
                            normal_mode = True
                            paste_mode = False

                do_input = getpass.getpass if raw_mode else input

                if prompt == "Password: ":  # initial "CTRL-C CTRL-B" injection
                    prompt = ""
                else:
                    # "=== " in paste mode, "" in raw mode, ">>> " otherwise.
                    prompt = "=== " if paste_mode else ">>> "[4*int(raw_mode):]

                if inp == "exit\r\n":
                    running = False
                    break
                else:
                    if ws.sock and ws.sock.connected:
                        ws.send(inp)
                        if prompt == "" and not raw_mode and not injected:
                            injected = True
                    else:
                        running = False
            running = False
        # The inner loop also ends when ws.sock is None, so guard the close
        # instead of dereferencing it unconditionally.
        sock = ws.sock
        if sock is not None:
            sock.close()
        # NOTE: this runs in the worker thread started below, so SystemExit
        # raised here ends this thread rather than the whole process.
        sys.exit(1)
    thread.start_new_thread(run, ())


if __name__ == "__main__":
    # Connect to port 8266 of the host given with -t, wire up the callbacks
    # defined above, and run the websocket event loop until it returns.
    websocket.enableTrace(False)
    ws = websocket.WebSocketApp(
        "ws://%s:8266" % args.t,
        on_message=on_message,
        on_error=on_error,
        on_close=on_close,
    )
    ws.on_open = on_open
    ws.run_forever()
