#!/usr/bin/python3
# ovirt-imageio
# Copyright (C) 2018 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

"""
nbd client example

Show how to use nbd module to upload and download images to/from nbd server.

Usage
-----

Start qemu-nbd with a destination image in any format supported by
qemu-img.  To access raw images, you can also use nbdkit or nbd-server.

    $ qemu-nbd \
        --socket /tmp/nbd.sock \
        --format qcow2 \
        --export-name=export \
        --persistent \
        --cache=none \
        --aio=native \
        --discard=unmap \
        --detect-zeroes=unmap \
        image.qcow2

Upload raw file to qcow2 image via qemu-nbd:

    $ ./nbd-client -e export upload file.raw nbd:unix:/tmp/nbd.sock

Download qcow2 image to raw file via qemu-nbd:

    $ ./nbd-client -e export download nbd:unix:/tmp/nbd.sock file.raw
"""

import argparse
import errno
import io
import json
import logging
import os
import subprocess

from urllib.parse import urlparse

from ovirt_imageio import nbd
from ovirt_imageio import nbdutil
from ovirt_imageio import client


def upload(args):
    """
    Upload a local file to the export served at args.nbd_url.

    The file extents reported by _map() drive the transfer: zero extents
    are zeroed on the server and data extents are copied from the file in
    steps of at most args.block_size bytes. The server is flushed once all
    extents were processed.

    Raises Exception if the export is smaller than the source file.
    """
    src_size = os.path.getsize(args.filename)
    with io.open(args.filename, "rb") as reader, \
            nbd.open(urlparse(args.nbd_url)) as server:
        if server.export_size < src_size:
            raise Exception("Destination size {} is smaller than source file "
                            "size {}".format(server.export_size, src_size))

        # No progress bar when running in silent mode.
        progress = None if args.silent else client.ProgressBar(src_size)

        for is_zero, offset, count in _map(args.filename):
            if is_zero:
                _zero_range(server, offset, count, progress)
                continue
            _copy_range(
                server, reader, offset, count, progress, args.block_size)

        server.flush()

        if progress:
            progress.close()


def download(args):
    """
    Download the export served at args.nbd_url into the file args.filename.

    The destination is truncated to the export size, then the extents
    reported by nbdutil.extents() are processed in order: zero extents are
    skipped with a seek, leaving a hole, and data extents are read from the
    server and written to the file. Finally the file is flushed and synced.

    EINVAL from truncate and fsync is ignored so /dev/null can be used as
    the destination for testing read throughput. Any other error is raised.
    """
    with nbd.open(urlparse(args.nbd_url)) as src, \
            io.open(args.filename, "wb") as dst:

        pb = None
        if not args.silent:
            pb = client.ProgressBar(src.export_size)

        # NBD has a limit of 32 MiB for request length. We let the user
        # choose a lower step size for tuning the copy.
        max_step = min(src.maximum_block_size, args.block_size)

        # Single reusable buffer; the view is sliced per request without
        # copying.
        buf = bytearray(max_step)
        buf_view = memoryview(buf)

        # Start by truncating the file to the right size. Ignoring EINVAL
        # allows copying to /dev/null for testing read throughput.
        logging.debug("truncate length=%s", src.export_size)
        try:
            dst.truncate(src.export_size)
        except EnvironmentError as e:
            if e.errno != errno.EINVAL:
                raise

        # Iterate over all extents in the image.
        offset = 0
        for ext in nbdutil.extents(src):
            if ext.zero:
                # Seek over zeroes, creating a hole.
                logging.debug("zero offset=%s length=%s", offset, ext.length)
                offset += ext.length
                dst.seek(offset)
                if pb:
                    pb.update(ext.length)
            else:
                # Copy data from source for data extents. Note that extents
                # length is 64 bits, but NBD supports up to 32 MiB per
                # request, so we must split the work to multiple calls.
                todo = ext.length
                while todo:
                    step = min(todo, max_step)
                    view = buf_view[:step]

                    src.readinto(offset, view)
                    logging.debug("write offset=%s length=%s", offset, step)
                    dst.write(view)

                    todo -= step
                    offset += step
                    if pb:
                        pb.update(step)

        # Finally wait until the data reaches the underlying storage.
        # The file object is buffered, so data from the last writes may
        # still be held in Python's buffer; flush it to the file descriptor
        # first, otherwise os.fsync() would not cover it.
        logging.debug("flush")
        dst.flush()

        # Ignoring EINVAL so we can copy to /dev/null for testing read
        # throughput.
        try:
            os.fsync(dst.fileno())
        except OSError as e:
            if e.errno != errno.EINVAL:
                raise

        if pb:
            pb.close()


def copy(args):
    """
    Copy the export served at args.src_url to the export served at
    args.dst_url.

    The actual transfer is delegated to nbdutil.copy(), passing the block
    size and queue depth selected on the command line. Progress is reported
    unless args.silent is set.
    """
    with nbd.open(urlparse(args.src_url)) as source:
        with nbd.open(urlparse(args.dst_url)) as target:
            # No progress bar when running in silent mode.
            if args.silent:
                bar = None
            else:
                bar = client.ProgressBar(source.export_size)

            options = {
                "block_size": args.block_size,
                "queue_depth": args.queue_depth,
                "progress": bar,
            }
            nbdutil.copy(source, target, **options)

            if bar:
                bar.close()


def _map(path):
    """
    Generate (zero, start, length) tuples for the extents of the raw image
    at path, as reported by "qemu-img map" in json format.

    This is a generator; qemu-img runs when the first item is requested.
    """
    logging.debug("getting extents")
    cmd = ["qemu-img", "map", "--format", "raw", "--output", "json", path]
    raw = subprocess.check_output(cmd)
    for extent in json.loads(raw.decode("utf-8")):
        yield extent["zero"], extent["start"], extent["length"]


def _zero_range(dst, start, length, pb):
    """
    Zero length bytes on dst starting at offset start.

    The range is split into requests of at most nbdutil.MAX_ZERO bytes.
    If pb is set, it is updated after every request.
    """
    offset = start
    remaining = length
    while remaining:
        count = min(remaining, nbdutil.MAX_ZERO)
        dst.zero(offset, count)
        offset += count
        remaining -= count
        if pb:
            pb.update(count)


def _copy_range(dst, src, start, length, pb, block_size):
    """
    Copy length bytes from file src to dst, starting at offset start in
    both.

    Data is read in steps of at most block_size bytes, limited by
    dst.maximum_block_size. If pb is set, it is updated after every write.

    Raises Exception if src ends before length bytes were read.
    """
    step_limit = min(dst.maximum_block_size, block_size)
    src.seek(start)

    offset, remaining = start, length
    while remaining:
        data = src.read(min(remaining, step_limit))
        if not data:
            raise Exception("Unexpected end of file, expecting {} bytes"
                            .format(remaining))

        # A read may return less than requested; advance by what we got.
        count = len(data)
        logging.debug("read offset=%s, length=%s", offset, count)
        dst.write(offset, data)

        offset += count
        remaining -= count
        if pb:
            pb.update(count)


def kib(s):
    """
    Convert a size in KiB, as given on the command line, to bytes.

    Used as an argparse type. Raises ValueError if s is not an integer and
    argparse.ArgumentTypeError if the size is not positive; argparse
    reports both as usage errors.
    """
    value = int(s)
    # A zero step never makes progress when copying and a negative one is
    # not a valid buffer size, so reject both when parsing arguments.
    if value <= 0:
        raise argparse.ArgumentTypeError(
            "size must be a positive number of KiB: {!r}".format(s))
    return value * 1024


# Command line interface: global options followed by one of the upload,
# download, or copy commands. Every command stores the function implementing
# it in args.command.
parser = argparse.ArgumentParser(description="nbd example")
parser.add_argument(
    "-b", "--block-size",
    # 4M is significantly faster than 8M on download, and about the same on
    # upload.
    default=4 * 1024**2,
    type=kib,
    help="block size in KiB")
parser.add_argument(
    "-v", "--verbose",
    action="store_true",
    help="Be more verbose")
parser.add_argument(
    "-s", "--silent",
    action="store_true",
    help="Disable progress")
parser.add_argument(
    "-q", "--queue-depth",
    type=int,
    default=4,
    help="Number of inflight I/O requests (default 4)")

commands = parser.add_subparsers(title="commands")

upload_parser = commands.add_parser(
    "upload",
    help="upload image data to nbd server")
upload_parser.set_defaults(command=upload)
upload_parser.add_argument(
    "filename",
    help="filename to copy")
upload_parser.add_argument(
    "nbd_url",
    help="NBD URL (nbd:unix:/sock or nbd:localhost:10809)")

download_parser = commands.add_parser(
    "download",
    help="download image data from nbd server")
download_parser.set_defaults(command=download)
download_parser.add_argument(
    "nbd_url",
    help="NBD URL (nbd:unix:/sock or nbd:localhost:10809)")
download_parser.add_argument(
    "filename",
    help="filename to write")

copy_parser = commands.add_parser(
    "copy",
    help="copy image from nbd server to another")
copy_parser.set_defaults(command=copy)
copy_parser.add_argument(
    "src_url",
    help="source server NBD URL")
copy_parser.add_argument(
    "dst_url",
    help="destination server NBD URL")

args = parser.parse_args()

# Sub-commands are optional by default in Python 3, so args.command is
# missing when no command was given. Report a usage error instead of
# failing later with AttributeError.
if not hasattr(args, "command"):
    parser.error("a command is required (upload, download, or copy)")

logging.basicConfig(
    level=logging.DEBUG if args.verbose else logging.WARNING,
    format=("%(asctime)s %(levelname)-7s (%(threadName)s) [%(name)s] "
            "%(message)s"))

# Run the function selected by the command.
args.command(args)
