#!/usr/bin/env python3
# @Author: carlosgilgonzalez
# @Date:   2018-10-15T12:00:02+01:00
# @Last modified by:   carlosgilgonzalez
# @Last modified time: 2019-11-17T02:52:16+00:00


from __future__ import print_function
import sys
import os
import struct
import time
import ssl
from upydevice import BASE_WS_DEVICE
from datetime import timedelta
try:
    import usocket as socket
except ImportError:
    import socket
import upydev as websocket_helper
from upydev import wss_helper_host

# Define to 1 to use builtin "uwebsocket" module of MicroPython
USE_BUILTIN_UWEBSOCKET = 0
# Treat this remote directory as a root for file transfers
SANDBOX = ""
# SANDBOX = "/tmp/webrepl/"
# Set to a truthy value to make debugmsg() print protocol traces.
DEBUG = 0

# struct layout of a WebREPL request record (little-endian, 82 bytes), in the
# order the fields are packed below: 2-byte signature b"WA", operation code,
# two fields that this client always sends as 0, a 32-bit size, the file name
# length and the file name padded to 64 bytes.
WEBREPL_REQ_S = "<2sBBQLH64s"
# Operation codes placed in the request record.
WEBREPL_PUT_FILE = 1
WEBREPL_GET_FILE = 2
WEBREPL_GET_VER = 3

# Partial-cell glyphs for the progress bar, from thinnest to widest.
bloc_progress = ["▏", "▎", "▍", "▌", "▋", "▊", "▉"]

# Probe the terminal attached to stdin. When stdin is not a terminal (pipe,
# cron job, test runner) the probe raises OSError; fall back to a width that
# is not greater than cnt_size so the progress bar is simply disabled instead
# of the whole module failing to import.
try:
    columns, rows = os.get_terminal_size(0)
except OSError:
    columns, rows = 80, 24
# Columns reserved for the textual part of the progress line.
cnt_size = 80
if columns > cnt_size:
    bar_size = int((columns - cnt_size))
    pb = True
else:
    bar_size = 1
    pb = False

# + " "*((bar_size+1) - len("█" *index + l_bloc)
# parser = argparse.ArgumentParser()
# parser.add_argument("op", metavar='Mode', help='operation')
# parser.add_argument("-size", help='size of file to get', required=False)
# parser.add_argument("-p", help='password', required=False)
# args = parser.parse_args()


def do_pg_bar(index, wheel, nb_of_total, speed, time_e, loop_l,
              percentage, ett, size_bar=bar_size):
    """Draw one frame of the transfer progress bar on the current line.

    The line is erased first and terminated with a carriage return, so
    successive calls overwrite each other.

    Args:
        index: number of completely filled bar cells.
        wheel: sequence of spinner characters; one is picked from ``index``.
        nb_of_total: preformatted "transferred/total" text.
        speed: preformatted transfer speed (shown with a "KB/s" suffix).
        time_e: elapsed time in seconds.
        loop_l: index into ``bloc_progress`` selecting the partial cell.
        percentage: completed fraction; it is multiplied by 100 for display.
        ett: estimated total time in seconds.
        size_bar: width of the bar in cells (defaults to the module-level
            ``bar_size``).
    """
    l_bloc = bloc_progress[loop_l]
    # Once every cell is filled the trailing partial cell becomes a full one.
    if index == size_bar:
        l_bloc = "█"
    filled = "█" * index + l_bloc
    # Pad with spaces so the closing delimiter stays in the same column.
    bar = filled + " " * ((size_bar + 1) - len(filled))
    # str(timedelta) looks like "H:MM:SS.ffffff": drop the fractional part
    # and the leading "H:" to display MM:SS.
    elapsed = str(timedelta(seconds=time_e)).split('.')[0][2:]
    total = str(timedelta(seconds=ett)).split('.')[0][2:]
    # ANSI "erase to end of line" so a shorter frame leaves no residue.
    sys.stdout.write("\033[K")
    print('▏{}▏{:>2}{:>5} % | {} | {:>5} KB/s | {}/{} s'.format(
        bar,
        wheel[index % len(wheel)],
        int((percentage)*100),
        nb_of_total, speed,
        elapsed,
        total), end='\r')
    sys.stdout.flush()


def debugmsg(msg):
    """Print *msg* only when the module-level DEBUG flag is truthy."""
    if not DEBUG:
        return
    print(msg)


if USE_BUILTIN_UWEBSOCKET:
    from uwebsocket import websocket
else:
    class websocket:
        """Minimal WebSocket framing on top of an already connected socket.

        Outgoing data is always sent as a single unmasked binary frame.
        Incoming binary frames (and text frames when requested) are buffered
        one frame at a time; every other frame type is skipped.
        """

        def __init__(self, s):
            """Wrap socket-like object *s*; the opening handshake must already be done."""
            self.s = s
            # Unconsumed payload of the most recently received frame.
            self.buf = b""

        def write(self, data):
            """Send *data* as one frame with header byte 0x82 (final, binary)."""
            l = len(data)
            if l < 126:
                # TODO: hardcoded "binary" type
                hdr = struct.pack(">BB", 0x82, l)
            elif l < 0x10000:
                # 126 announces a 16-bit extended payload length.
                hdr = struct.pack(">BBH", 0x82, 126, l)
            else:
                # 127 announces a 64-bit extended payload length; without it
                # payloads of 64 KiB or more made struct.pack raise.
                hdr = struct.pack(">BBQ", 0x82, 127, l)
            # sendall() retries until everything is written; a bare send()
            # may write only part of the frame and corrupt the stream.
            self.s.sendall(hdr)
            self.s.sendall(data)

        def recvexactly(self, sz):
            """Receive up to *sz* bytes, stopping early only if the peer closes.

            Returns the bytes received; the result is shorter than *sz* when
            the connection ended first, so callers must check its length.
            """
            chunks = []
            while sz:
                data = self.s.recv(sz)
                if not data:
                    break
                chunks.append(data)
                sz -= len(data)
            return b"".join(chunks)

        def read(self, size, text_ok=False):
            """Return exactly *size* bytes of frame payload.

            When the internal buffer is empty the next acceptable frame is
            read from the socket: a frame with header byte 0x82, or 0x81 as
            well when *text_ok* is true. Other frames are discarded.

            Raises:
                AssertionError: if the connection ends mid-frame, or if the
                    buffered frame holds fewer than *size* bytes.
            """
            if not self.buf:
                while True:
                    hdr = self.recvexactly(2)
                    assert len(hdr) == 2
                    fl, sz = struct.unpack(">BB", hdr)
                    if sz == 126:
                        hdr = self.recvexactly(2)
                        assert len(hdr) == 2
                        (sz,) = struct.unpack(">H", hdr)
                    elif sz == 127:
                        # 64-bit extended payload length.
                        hdr = self.recvexactly(8)
                        assert len(hdr) == 8
                        (sz,) = struct.unpack(">Q", hdr)
                    if fl == 0x82:
                        break
                    if text_ok and fl == 0x81:
                        break
                    debugmsg(
                        "Got unexpected websocket record of type %x, skipping it" % fl)
                    # Drain the unwanted payload. recvexactly() stops when the
                    # peer closes, so this cannot spin forever on a dead
                    # connection the way a bare recv() loop would.
                    skip = self.recvexactly(sz)
                    debugmsg("Skip data: %s" % skip)
                    assert len(skip) == sz
                data = self.recvexactly(sz)
                assert len(data) == sz
                self.buf = data

            d = self.buf[:size]
            self.buf = self.buf[size:]
            assert len(d) == size, len(d)
            return d

        def ioctl(self, req, val):
            """Accept only the (9, 2) request; nothing else is supported.

            This class always writes binary frames, so the call is a no-op
            apart from the argument check.
            """
            assert req == 9 and val == 2


def login(ws, passwd):
    """Wait for the password prompt on *ws* and answer it with *passwd*.

    Incoming bytes are discarded until a ":" is seen; the byte after it must
    be a space. The password is then sent UTF-8 encoded, followed by a
    carriage return.

    Raises:
        AssertionError: if the byte following ":" is not a space.
    """
    # Discard the greeting up to and including the ":" of the prompt.
    while ws.read(1, text_ok=True) != b":":
        pass
    # Read outside the assert: under "python -O" asserts are stripped, and
    # the byte must still be consumed to keep the stream in sync.
    separator = ws.read(1, text_ok=True)
    assert separator == b" "
    ws.write(passwd.encode("utf-8") + b"\r")


def read_resp(ws):
    """Read a 4-byte response record from *ws* and return its result code.

    The record is a 2-byte signature, which must be b"WB", followed by a
    little-endian unsigned 16-bit code.
    """
    signature, result = struct.unpack("<2sH", ws.read(4))
    assert signature == b"WB"
    return result


def send_req(ws, op, sz=0, fname=b""):
    """Pack a request record for operation *op* and write it to *ws*.

    *sz* and *fname* fill the size and file-name fields of the record; the
    packed record is also passed to debugmsg().
    """
    fields = (b"WA", op, 0, 0, sz, len(fname), fname)
    request = struct.pack(WEBREPL_REQ_S, *fields)
    debugmsg("%r %d" % (request, len(request)))
    ws.write(request)


def get_ver(ws):
    """Request the remote version and return it as a tuple of three ints.

    Sends a WEBREPL_GET_VER request, then decodes the 3-byte reply as three
    unsigned bytes.
    """
    send_req(ws, WEBREPL_GET_VER)
    return struct.unpack("<BBB", ws.read(3))


def put_file(ws, local_file, remote_file):
    """Upload *local_file* to *remote_file* over the WebREPL connection *ws*.

    A WEBREPL_PUT_FILE request carrying the file size and the destination
    name (prefixed with SANDBOX) is sent, then the file content follows in
    chunks of up to 1024 bytes. Progress is shown either as a bar (when the
    module-level ``pb`` flag is set) or as a plain byte counter.

    Raises:
        AssertionError: if the device answers either the request or the
            completed transfer with a non-zero result code.
        OSError: if *local_file* cannot be read.
    """
    wheel = ['|', '/', '-', "\\"]
    sz = os.stat(local_file).st_size
    dest_fname = (SANDBOX + remote_file).encode("utf-8")
    rec = struct.pack(WEBREPL_REQ_S, b"WA", WEBREPL_PUT_FILE,
                      0, 0, sz, len(dest_fname), dest_fname)
    debugmsg("%r %d" % (rec, len(rec)))
    # The request is sent as two frames (first 10 bytes, then the rest).
    # NOTE(review): presumably required by the receiving side - confirm
    # before merging the two writes.
    ws.write(rec[:10])
    ws.write(rec[10:])
    # Read outside the assert so the response is consumed even when asserts
    # are stripped by "python -O".
    resp = read_resp(ws)
    assert resp == 0
    cnt = 0
    t_start = time.time()
    print('\n')
    with open(local_file, "rb") as f:
        while True:
            buf = f.read(1024)
            if not buf:
                break
            ws.write(buf)
            cnt += len(buf)
            # Fractional number of filled cells: the integer part is the
            # count of full cells, the remainder selects one of the 7
            # partial-cell glyphs (index 0..6).
            loop_index_f = (cnt/sz)*bar_size
            loop_index = int(loop_index_f)
            loop_index_l = int(round(loop_index_f-loop_index, 1)*6)
            nb_of_total = "{:.2f}/{:.2f} KB".format(cnt/(1024), sz/(1024))
            percentage = cnt/sz
            # A coarse or adjusted clock can report zero (or negative)
            # elapsed time for the first chunks; clamp it so the speed and
            # estimate below cannot divide by zero.
            t_elapsed = max(time.time() - t_start, 1e-6)
            t_speed = "{:^2.2f}".format((cnt/(1024))/t_elapsed)
            # Estimated total time at the average speed so far.
            ett = sz / (cnt / t_elapsed)
            if pb:
                do_pg_bar(loop_index, wheel, nb_of_total, t_speed, t_elapsed,
                          loop_index_l, percentage, ett)
            else:
                sys.stdout.write("Sent %d of %d bytes\r" % (cnt, sz))
                sys.stdout.flush()

    print('\n')
    print()
    resp = read_resp(ws)
    assert resp == 0


def get_file(ws, local_file, remote_file, file_size):
    """Download *remote_file* into *local_file* over the connection *ws*.

    A WEBREPL_GET_FILE request for the name (prefixed with SANDBOX) is sent.
    The content is then pulled block by block: a single zero byte is written
    to ask for the next block, which arrives as a little-endian 16-bit
    length followed by that many bytes; a zero length ends the transfer.

    Args:
        ws: connected, logged-in websocket object.
        local_file: path of the file to create or overwrite.
        remote_file: name of the file on the device.
        file_size: expected size in bytes, used only for the progress bar.

    Raises:
        AssertionError: if the device answers either the request or the
            completed transfer with a non-zero result code.
        OSError: if a data read returns nothing, or the local file cannot
            be written.
    """
    sz_r = file_size
    wheel = ['|', '/', '-', "\\"]
    src_fname = (SANDBOX + remote_file).encode("utf-8")
    rec = struct.pack(WEBREPL_REQ_S, b"WA", WEBREPL_GET_FILE,
                      0, 0, 0, len(src_fname), src_fname)
    debugmsg("%r %d" % (rec, len(rec)))
    ws.write(rec)
    # Read outside the assert so the response is consumed even when asserts
    # are stripped by "python -O".
    resp = read_resp(ws)
    assert resp == 0
    # The bar needs a known, non-zero total; otherwise fall back to the
    # plain byte counter instead of dividing by zero.
    show_bar = pb and bool(sz_r)
    t_start = time.time()
    print('\n')
    with open(local_file, "wb") as f:
        cnt = 0
        while True:
            # Ask for the next block.
            ws.write(b"\0")
            (sz,) = struct.unpack("<H", ws.read(2))
            if sz == 0:
                break
            while sz:
                buf = ws.read(sz)
                if not buf:
                    raise OSError()
                cnt += len(buf)
                f.write(buf)
                sz -= len(buf)
                if show_bar:
                    # Integer part: full cells; remainder: one of the 7
                    # partial-cell glyphs (index 0..6).
                    loop_index_f = (cnt/sz_r)*bar_size
                    loop_index = int(loop_index_f)
                    loop_index_l = int(round(loop_index_f-loop_index, 1)*6)
                    nb_of_total = "{:.2f}/{:.2f} KB".format(cnt/(1024), sz_r/(1024))
                    percentage = cnt / sz_r
                    # Clamp so a coarse or adjusted clock cannot cause a
                    # division by zero below.
                    t_elapsed = max(time.time() - t_start, 1e-6)
                    t_speed = "{:^2.2f}".format((cnt/(1024))/t_elapsed)
                    # Estimated total time at the average speed so far.
                    ett = sz_r / (cnt / t_elapsed)
                    do_pg_bar(loop_index, wheel, nb_of_total, t_speed,
                              t_elapsed, loop_index_l, percentage, ett)
                else:
                    sys.stdout.write("Received %d bytes\r" % cnt)
                    sys.stdout.flush()
    print()
    print('\n')
    resp = read_resp(ws)
    assert resp == 0


def help(rc=0):
    """Print command-line usage to standard output and exit with status *rc*.

    The program name shown is the last "/"-separated component of
    ``sys.argv[0]``.
    """
    exename = sys.argv[0].rsplit("/", 1)[-1]
    usage = (
        "%s - Perform remote file operations using MicroPython WebREPL protocol" % exename,
        "Arguments:",
        "  [-p password] <host>:<remote_file> <local_file> - Copy remote file to local file",
        "  [-p password] <local_file> <host>:<remote_file> - Copy local file to remote file",
        "Examples:",
        "  %s script.py 192.168.4.1:/another_name.py" % exename,
        "  %s script.py 192.168.4.1:/app/" % exename,
        "  %s -p password 192.168.4.1:/app/script.py ." % exename,
    )
    for line in usage:
        print(line)
    sys.exit(rc)


def error(msg):
    """Print *msg* to standard output and terminate with exit status 1."""
    print(msg)
    raise SystemExit(1)


def parse_remote(remote, ssl=False):
    """Split a ``host[:port]:path`` argument into ``(host, port, path)``.

    The path is everything after the last colon and defaults to "/" when
    empty. Without an explicit port, 8833 is used when *ssl* is truthy and
    8266 otherwise.
    """
    location, path = remote.rsplit(":", 1)
    path = path or "/"
    port = 8833 if ssl else 8266
    if ":" in location:
        location, explicit_port = location.split(":")
        port = int(explicit_port)
    return (location, port, path)


def main():
    """Command-line entry point: copy one file to or from a WebREPL device.

    Recognised options are ``-p <password>`` and ``-wss`` (use TLS); after
    they are removed exactly two positional arguments must remain, one of
    which names the remote side as ``host[:port]:path``. Without ``-p`` the
    password is asked for interactively. Usage errors end the process
    through help() or error().
    """
    # Work on a copy (without the program name) so sys.argv is not mutated.
    args = sys.argv[1:]

    passwd = None
    websec = False

    if '-p' in args:
        pos = args.index('-p')
        args.pop(pos)
        if pos >= len(args):
            # "-p" was the last argument: there is no password after it.
            help(1)
        passwd = args.pop(pos)

    if '-wss' in args:
        args.remove('-wss')
        websec = True

    # Checking what is left after the options are removed accepts every
    # option combination (including "-wss" without "-p") and rejects stray
    # extra arguments.
    if len(args) != 2:
        help(1)

    if not passwd:
        import getpass
        passwd = getpass.getpass()

    first, second = args
    if ":" in first and ":" in second:
        error("Operations on 2 remote files are not supported")
    if ":" not in first and ":" not in second:
        error("One remote file is required")

    if ":" in first:
        op = "get"
        host, port, src_file = parse_remote(first, ssl=websec)
        dst_file = second
        # Copying into a local directory keeps the remote base name.
        if os.path.isdir(dst_file):
            basename = src_file.rsplit("/", 1)[-1]
            dst_file += "/" + basename
    else:
        op = "put"
        host, port, dst_file = parse_remote(second, ssl=websec)
        src_file = first
        # A remote path ending in "/" is a directory: append the base name.
        if dst_file[-1] == "/":
            basename = src_file.rsplit("/", 1)[-1]
            dst_file += basename

    print("op:%s, host:%s, port:%d, passwd:%s." %
          (op, host, port, '*'*len(passwd)))
    print(src_file, "->", dst_file)

    dev = BASE_WS_DEVICE(host, passwd)
    size_file_to_get = None
    if op == 'get':
        # Ask the device for the file size up front; it is only used to
        # report the size and to drive the progress display.
        stat_cmd = "import uos; uos.stat('{}')[6]".format(src_file)
        if websec:
            dev.cmd(stat_cmd, silent=True, ssl=True)
        else:
            dev.cmd(stat_cmd, silent=True)
        size_file_to_get = dev.output

    s = socket.socket()
    # Close the socket on every exit path, including failures part-way
    # through the connection set-up or the transfer.
    try:
        ai = socket.getaddrinfo(host, port)
        addr = ai[0][4]
        if websec:
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            # NOTE: certificate and host name verification are disabled, so
            # the connection is encrypted but the peer is not authenticated.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            context.set_ciphers('ECDHE-ECDSA-AES128-CCM8')
            s = context.wrap_socket(s)
            s.connect(addr)
            wss_helper_host.client_handshake(s, ssl=True)
        else:
            s.connect(addr)
            websocket_helper.client_handshake(s)

        ws = websocket(s)

        login(ws, passwd)
        print("Remote WebREPL version:", get_ver(ws))

        # Set websocket to send data marked as "binary"
        ws.ioctl(9, 2)

        if op == "get":
            print('File size: {:.2f} KB'.format(size_file_to_get/1024))
            get_file(ws, dst_file, src_file, size_file_to_get)
        elif op == "put":
            put_file(ws, src_file, dst_file)
    finally:
        s.close()


# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
