#!/usr/bin/env python3
"""Python netcat implementation."""

from __future__ import print_function
from builtins import input
from abc import ABCMeta, abstractmethod
from subprocess import Popen, PIPE
import argparse
import atexit
import inspect
import os
import re
import socket
import subprocess
import sys
import threading
import time

# -------------------------------------------------------------------------------------------------
# GLOBALS
# -------------------------------------------------------------------------------------------------

APPNAME = "pwncat"  # Program name, used in the version banner and CLI error messages
APPREPO = "https://github.com/cytopia/pwncat"  # Upstream repository shown by get_version()
VERSION = "0.0.5-alpha"  # Release version shown by get_version()


# -------------------------------------------------------------------------------------------------
# ABSTRACT CLASS: AbstractNetcatModule
# -------------------------------------------------------------------------------------------------
class AbstractNetcatModule(ABCMeta("AbstractNetcatModuleBase", (object,), {})):
    """
    Abstract base class for netcat modules.

    This is a skeleton that defines how the modules for Netcat should look like.

    The input_generator should constantly yield data received from some sort of input
    which could be user input or the output of a shell command.

    The output_callback applies some sort of action on the data handed to it,
    e.g. print it to stdout or write it to the stdin of a command.
    """

    # NOTE: The base class is created by calling ABCMeta directly. A ``__metaclass__``
    # class attribute is only honoured by Python 2 and silently ignored by Python 3,
    # which left the abstract methods below unenforced there. This form works on both.

    @abstractmethod
    def __init__(self, logger, encoder, options={}):
        """Set specific options for this module."""
        pass

    @abstractmethod
    def input_generator(self):
        """Implement a generator function which constantly yields data from some input."""
        # Unreachable yield: makes this method a generator function like its implementations.
        while False:
            yield None

    @abstractmethod
    def output_callback(self, data):
        """Implement a callback which processes the input into some output."""
        print(data)


# -------------------------------------------------------------------------------------------------
# ABSTRACT CLASS: AbstractSocket
# -------------------------------------------------------------------------------------------------
class AbstractSocket(object):
    """Abstract base class providing TCP and UDP socket functionality.

    Subclasses (server or client) set up the sockets. This class provides the shared
    helpers, send()/receive() and the re-accept / re-connect logic used when the peer
    goes away.

    NOTE(review): create_socket() only creates AF_INET sockets, so IPv6 is not supported yet.
    """

    sock = None  # server binding socket (until accept())
    conn = None  # client/server communication socket

    # The instance role must be 'server' or 'client' and
    # is used to determine how to reconnect broken connections.
    # Either listen again (tcp-only) or re-connect to upstream.
    role = None  # Must be 'server' or 'client'

    # Default options. These are only defaults: every instance works on its own copy
    # (see __init__), because values such as the "reconn" counter change at runtime.
    options = {
        "udp": False,  # Is TCP or UDP server?
        "bufsize": 1024,  # Receive buffer size
        "backlog": 0,  # Listen backlog
        "nodns": False,  # Do not resolve hostnames
        "reinit": False,  # False (never), True (indefinite) or int for how many times to reinit
        "reconn": False,  # False (never), True (indefinite) or int for how many times to reconnect
        "reinit_robin": [],  # Ports to round-robin during failed init phase
        "reconn_robin": [],  # Ports to round-robin during failed phase (after 1st succ init)
        "reinit_wait": 0,  # Time in seconds to wait between reinits
        "reconn_wait": 0,  # Time in seconds to wait between reconnects
        "udp_ping_intvl": False,  # Interval in sec for UDP client to ping server
    }

    # In case the server is running in UDP mode, it must wait for the client to send a
    # datagram in order to learn its addr and port, so that it can send data back to it.
    udp_client_addr = None
    udp_client_port = None

    # For client role only: address and port of the remote server (used by connect()).
    remote_addr = None
    remote_port = None

    # ------------------------------------------------------------------------------
    # Constructor / Destructor
    # ------------------------------------------------------------------------------
    def __init__(self, logger, encoder, role, options=None):
        """Constructor.

        Args:
            logger: Logger used for all log output.
            encoder: Object providing encode() (str -> bytes) and decode() (bytes -> str).
            role: 'server' or 'client'; decides how a broken connection is recovered.
            options: Optional dict overriding keys of the class-level ``options`` defaults.
                Keys that are not known defaults are ignored.
        """
        assert type(self) is not AbstractSocket, "AbstractSocket cannot be instantiated directly."
        assert role in ["server", "client"], "The role must be 'server' or 'client'."

        self.log = logger
        self.enc = encoder
        self.role = role

        # Per-instance copy of the defaults (lists copied too). The shared class dict must
        # never be mutated, otherwise options and the runtime "reconn" counter would leak
        # between instances.
        self.options = {
            key: list(value) if isinstance(value, list) else value
            for key, value in type(self).options.items()
        }
        overrides = options if options is not None else {}
        for key in list(self.options):
            if key in overrides:
                self.options[key] = overrides[key]

        # Register destructor
        atexit.register(self.__exit__)

    def __exit__(self):
        """Destructor."""
        if self.conn is not None:
            self.log.trace("Closing 'conn' socket")
            self.__close_socket(self.conn)
            self.conn = None
        if self.sock is not None:
            self.log.trace("Closing 'sock' socket")
            self.__close_socket(self.sock)
            self.sock = None

    # ------------------------------------------------------------------------------
    # Private Functions
    # ------------------------------------------------------------------------------
    def __close_socket(self, sock):
        """Close a socket (best effort: shutdown errors on dead sockets are ignored)."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except (ValueError, OSError, socket.error):
            pass
        sock.close()

    def __reconnect(self):
        """Replace the client socket with a fresh one and try to connect once.

        Returns:
            bool: True if the new connection was established, False otherwise.
        """
        if self.conn is not None:
            self.__close_socket(self.conn)
        # For clients 'conn' and 'sock' are normally the same object; close it only once.
        if self.sock is not None and self.sock is not self.conn:
            self.__close_socket(self.sock)
        self.create_socket()
        self.conn = self.sock
        return self.connect()

    def __reaccept_from_client(self):
        """Accept a new client after the current client has left (TCP server only).

        Returns:
            bool: True if a new client was accepted, False if re-accepting is disabled or
            the "reconn" budget is used up.
        """
        # Only for server
        assert self.role == "server", "Only the role 'server' can accept connections."
        # Do not re-accept for UDP
        assert not self.options["udp"], "This should have been caught during arg check."

        reconn = self.options["reconn"]
        if isinstance(reconn, bool) and not reconn:
            # [NO] Never re-accept
            self.log.info("No automatic re-accept specified. Shutting down.")
            return False
        if isinstance(reconn, bool):
            # [YES] Always re-accept indefinitely
            self.log.info(
                "Re-accepting in {} sec (indefinitely)".format(self.options["reconn_wait"])
            )
        elif reconn > 0:
            # [YES] Re-accept x many times
            self.log.info(
                "Re-accepting in {} sec ({} more times left)".format(
                    self.options["reconn_wait"], reconn
                )
            )
            self.options["reconn"] -= 1
        else:
            # [NO] Re-accept count is used up
            self.log.info("Re-accept count is used up. Shutting down.")
            return False

        # Release the socket of the client that just left before waiting for the next one.
        if self.conn is not None:
            self.__close_socket(self.conn)
        time.sleep(self.options["reconn_wait"])
        self.accept()
        return True

    def __reconnect_to_server(self):
        """Re-connect to the remote server after it hung up (TCP client only).

        Retries in a loop (not recursively) until a connection succeeds or the "reconn"
        budget is exhausted.

        Returns:
            bool: True once connected again, False if reconnecting is disabled or used up.
        """
        # Only for Clients
        assert self.role == "client", "Only the role 'client' can re-connect."
        # Do not re-connect with UDP
        assert not self.options["udp"], "This should have been caught during arg check."

        while True:
            reconn = self.options["reconn"]
            if isinstance(reconn, bool) and not reconn:
                # [NO] Never re-connect
                self.log.info("No automatic reconnect specified. Shutting down.")
                return False
            if isinstance(reconn, bool):
                # [YES] Always re-connect indefinitely
                self.log.info(
                    "Reconnecting in {} sec (indefinitely)".format(self.options["reconn_wait"])
                )
            elif reconn > 0:
                # [YES] Re-connect x many times
                self.log.info(
                    "Reconnecting in {} sec ({} more times left)".format(
                        self.options["reconn_wait"], reconn
                    )
                )
                self.options["reconn"] -= 1
            else:
                # [NO] Re-connect count is used up
                self.log.info("Reconnect count is used up. Shutting down.")
                return False

            time.sleep(self.options["reconn_wait"])
            if self.__reconnect():
                return True

    # ------------------------------------------------------------------------------
    # Helper Functions
    # ------------------------------------------------------------------------------
    def gethostbyname(self, host):
        """Translate hostname into an IPv4 address (returned as-is with nodns); exit 1 on error."""
        if self.options["nodns"]:
            return host
        try:
            self.log.debug("Resolving hostname: {}".format(host))
            addr = socket.gethostbyname(host)
            self.log.debug("Resolved hostname:  {}".format(addr))
            return addr
        except socket.gaierror as error:
            self.log.error("Resolve Error: {}".format(error))
            sys.exit(1)

    def create_socket(self):
        """Create an IPv4 TCP or UDP socket in self.sock; exit 1 on error."""
        try:
            if self.options["udp"]:
                self.log.debug("Creating UDP socket")
                self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            else:
                self.log.debug("Creating TCP socket")
                self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        except socket.error as error:
            self.log.error("Failed to create the socket: {}".format(error))
            sys.exit(1)

    def bind(self, addr, port):
        """Bind the socket to an address; exit 1 on error."""
        try:
            self.log.debug("Binding socket to {}:{}".format(addr, port))
            self.sock.bind((addr, port))
        except (OverflowError, OSError, socket.error) as error:
            self.log.error("Binding socket to {}:{} failed: {}".format(addr, port, error))
            sys.exit(1)

    def listen(self):
        """Listen for connections made to the socket; exit 1 on error."""
        try:
            self.log.debug("Listening with backlog={}".format(self.options["backlog"]))
            self.sock.listen(self.options["backlog"])
        except socket.error as error:
            self.log.error("Listening failed: {}".format(error))
            sys.exit(1)

    def accept(self):
        """Block until a TCP client connects and store its socket in self.conn; exit 1 on error."""
        try:
            self.log.debug("Waiting for TCP client")
            self.conn, client = self.sock.accept()
            addr, port = client
            self.log.info("Client connected from {}:{}".format(addr, port))
        except (socket.gaierror, socket.error) as error:
            self.log.error("Accept failed: {}".format(error))
            sys.exit(1)

    def connect(self):
        """Connect to remote_addr:remote_port (TCP-only). Return True on success, else False."""
        try:
            self.log.debug("Connecting to {}:{}".format(self.remote_addr, self.remote_port))
            self.sock.connect((self.remote_addr, self.remote_port))
            return True
        except socket.error as error:
            self.log.error(
                "Connecting to {}:{} failed: {}".format(self.remote_addr, self.remote_port, error)
            )
            return False

    # ------------------------------------------------------------------------------
    # Send / Receive Functions
    # ------------------------------------------------------------------------------
    def send(self, data):
        """Send a string completely over the communication socket.

        For a UDP server this first blocks until receive() has learned the client's
        address. Errors are logged and abort this send (the calling thread keeps running).
        """
        # In case of sending data back to an udp client we need to wait
        # until the client has first connected and told us its addr/port
        if self.options["udp"] and self.udp_client_addr is None and self.udp_client_port is None:
            self.log.info("Waiting for UDP client to connect")
            while self.udp_client_addr is None and self.udp_client_port is None:
                # Sleep instead of spinning a CPU core while receive() waits for a datagram.
                time.sleep(0.1)

        # Measure the size of the encoded bytes, which is what actually goes on the wire.
        data = self.enc.encode(data)
        size = len(data)
        total = 0
        # Loop until all bytes have been sent
        while total < size:
            if self.conn is None:
                self.log.error("Socket is gone")
                return
            try:
                self.log.trace("Trying to send: {}".format(data))
                if self.options["udp"]:
                    sent = self.conn.sendto(data, (self.udp_client_addr, self.udp_client_port))
                else:
                    sent = self.conn.send(data)
            except (OSError, socket.error) as error:
                self.log.error("Send Error: {}".format(error))
                # exit the thread gracefully (program shuts down fine)
                return
            total += sent
            # Drop only the bytes sent by *this* call ('total' is cumulative).
            data = data[sent:]

    def receive(self):
        """Generate received data endlessly by yielding it.

        Yields:
            Decoded data as returned by the encoder's decode().

        When a TCP peer disconnects, the server re-accepts / the client reconnects
        according to the "reconn" option; otherwise the process exits with status 0.
        A receive error exits the process with status 1.
        """
        while True:
            if self.conn is None:
                self.log.error("Socket is gone")
                return
            try:
                self.log.trace(
                    "Waiting to receive data (bufsize={})...".format(self.options["bufsize"])
                )
                (byte, addr) = self.conn.recvfrom(self.options["bufsize"])
            except socket.error as error:
                self.log.error("Receive Error: {}".format(error))
                sys.exit(1)

            data = self.enc.decode(byte)
            self.log.trace("Received: {}".format(data))

            if self.options["udp"]:
                # Every datagram tells us where to send data back to (see send()).
                # Announce a new/changed peer at info level, repeats only at debug level
                # to avoid spamming the log on every datagram.
                is_new_peer = (self.udp_client_addr, self.udp_client_port) != addr
                self.udp_client_addr, self.udp_client_port = addr
                message = "Client connected: {}:{}".format(
                    self.udp_client_addr, self.udp_client_port
                )
                if is_new_peer:
                    self.log.info(message)
                else:
                    self.log.debug(message)
                if not data:
                    # A zero-length datagram is a valid UDP payload, not a disconnect.
                    continue

            if not data:
                # Upstream (server or client) is gone. Do we reconnect or quit?
                self.log.warning("Upstream connection is gone")
                if self.role == "server":
                    if not self.__reaccept_from_client():
                        sys.exit(0)
                elif not self.__reconnect_to_server():
                    sys.exit(0)
                # Connection is back: do not hand the empty EOF marker to the consumer.
                continue

            yield data


# -------------------------------------------------------------------------------------------------
# CLASS: StringEncoder
# -------------------------------------------------------------------------------------------------
class StringEncoder(object):
    """
    Takes care about Python 2/3 string encoding/decoding.

    This allows to parse all string/byte values internally between all
    classes or functions as strings to keep full Python 2/3 compat.
    """

    # https://stackoverflow.com/questions/606191/27527728#27527728
    # cp437 maps every byte value 0-255 to a character, so arbitrary binary data
    # survives a decode/encode round trip unchanged.
    codec = "cp437"

    def __init__(self, logger):
        """Store the logger and detect whether we are running on Python 3."""
        self.log = logger
        self.py3 = sys.version_info >= (3, 0)

    def encode(self, data):
        """Convert a string into bytes using ``codec`` (Python 3 only; Python 2 is a no-op)."""
        if self.py3:
            self.log.trace("Encoding: {}".format(data))
            data = data.encode(self.codec)
            self.log.trace("Encoded: {}".format(data))
        return data

    def decode(self, data):
        """Convert bytes into a string using ``codec`` (Python 3 only; Python 2 is a no-op)."""
        if self.py3:
            self.log.trace("Decoding: {}".format(data))
            data = data.decode(self.codec)
            self.log.trace("Decoded: {}".format(data))
        return data


# -------------------------------------------------------------------------------------------------
# CLASS: Logger
# -------------------------------------------------------------------------------------------------
class Logger(object):
    """Logger class writing leveled messages to stderr."""

    # ------------------------------------------------------------------------------
    # Constructor / Destructor
    # ------------------------------------------------------------------------------
    def __init__(self, verbosity=1):
        """Constructor.

        verbosity == 0: Log errors
        verbosity == 1: Log errors, warnings
        verbosity == 2: Log errors, warnings, info
        verbosity == 3: Log errors, warnings, info, debug, class/func names
        verbosity == 4: Log errors, warnings, info, debug, trace, class/func names
        """
        assert verbosity >= 0, "Verbosity cannot be less than 0."
        self.verbosity = verbosity

    # ------------------------------------------------------------------------------
    # Private Functions
    # ------------------------------------------------------------------------------
    def _caller(self):
        """Return 'Class.func' (or just 'func') of the code that called the public log method.

        Must be called directly from one of the public log methods: the frame two levels
        up is the actual caller. Uses frame access instead of inspect.stack(), which would
        read source context for every frame on each log call.
        """
        frame = inspect.currentframe()
        try:
            # frame: _caller -> f_back: error/warning/... -> f_back: the real caller
            caller = frame.f_back.f_back if frame is not None else None
            if caller is None:
                # Frame introspection unavailable (non-CPython) or unexpected call depth.
                return "?"
            fnc = caller.f_code.co_name
            obj = caller.f_locals.get("self")
            if obj is None:
                # Plain function (no bound 'self'): show the function name only.
                return fnc
            return "%s.%s" % (obj.__class__.__name__, fnc)
        finally:
            # Break the frame reference cycle as recommended by the inspect docs.
            del frame

    def _write(self, label, message, caller=None):
        """Print one log line to stderr, optionally prefixed with the caller name."""
        if caller is None:
            print("%s %s" % (label, repr(message)), file=sys.stderr)
        else:
            print("%s %s(): %s" % (label, caller, repr(message)), file=sys.stderr)

    # ------------------------------------------------------------------------------
    # Public Functions
    # ------------------------------------------------------------------------------
    def error(self, message):
        """Log error messages (always shown; with caller name at verbosity >= 3)."""
        if self.verbosity > 2:
            self._write("[ERROR]", message, self._caller())
        else:
            self._write("[ERROR]", message)

    def warning(self, message):
        """Log warning messages (verbosity >= 1; with caller name at verbosity >= 3)."""
        if self.verbosity > 2:
            self._write("[WARN] ", message, self._caller())
        elif self.verbosity > 0:
            self._write("[WARN] ", message)

    def info(self, message):
        """Log info messages (verbosity >= 2; with caller name at verbosity >= 3)."""
        if self.verbosity > 2:
            self._write("[INFO] ", message, self._caller())
        elif self.verbosity > 1:
            self._write("[INFO] ", message)

    def debug(self, message):
        """Log debug messages (verbosity >= 3, always with caller name)."""
        if self.verbosity > 2:
            self._write("[DEBUG]", message, self._caller())

    def trace(self, message):
        """Log trace messages (verbosity >= 4, always with caller name)."""
        if self.verbosity > 3:
            self._write("[TRACE]", message, self._caller())


# -------------------------------------------------------------------------------------------------
# CLASS: Runner
# -------------------------------------------------------------------------------------------------
class Runner(object):
    """Runner class that takes care about putting everything into threads."""

    # ------------------------------------------------------------------------------
    # Constructor / Destructor
    # ------------------------------------------------------------------------------
    def __init__(self, logger):
        """Constructor."""
        self.log = logger

    # ------------------------------------------------------------------------------
    # Public Functions
    # ------------------------------------------------------------------------------
    def set_recv_generator(self, func):
        """Set generator func which constantly receives network data."""
        self.recv_generator = func

    def set_input_generator(self, func):
        """Set generator func which constantly receives input (shell output/user input)."""
        self.input_generator = func

    def set_send_callback(self, func):
        """Set the callback for sending data to a socket."""
        self.send_callback = func

    def set_output_callback(self, func):
        """Set the callback for outputting data to stdin/stdout."""
        self.output_callback = func

    def set_time_action(self, intvl, func, *args, **kwargs):
        """Set a function that should be called every ``intvl`` seconds."""
        self.timed_action_intvl = intvl
        self.timed_action_func = func
        self.timed_action_args = args
        self.timed_action_kwargs = kwargs

    def run(self):
        """Run threaded NetCat.

        Starts a receiver thread (recv_generator -> output_callback), a sender thread
        (input_generator -> send_callback) and, if set_time_action() was called, a timer
        thread. Blocks until any of these threads ends, then exits the process with 0.
        """
        assert hasattr(self, "recv_generator"), "Error, recv_generator not set"
        assert hasattr(self, "input_generator"), "Error, input_generator not set"
        assert hasattr(self, "send_callback"), "Error, send_callback not set"
        assert hasattr(self, "output_callback"), "Error, output_callback not set"

        def receiver():
            """Process every item of recv_generator() with output_callback().

            The thread ends once the generator is exhausted. Restarting an exhausted
            generator would typically hit the same end condition again and spin.
            """
            self.log.trace("[Thread] receiver loop")
            for data in self.recv_generator():
                self.log.trace("[Thread] receiver received: {}".format(data))
                self.output_callback(data)

        def sender():
            """Process every item of input_generator() with send_callback().

            The thread ends once the generator is exhausted (see receiver()).
            """
            self.log.trace("[Thread] sender")
            for data in self.input_generator():
                self.log.trace("[Thread] sender received: {}".format(data))
                self.send_callback(data)

        def timer():
            """Execute periodic tasks by an optional provided time_action."""
            self.log.debug(
                "Ready for timed action every {} seconds".format(self.timed_action_intvl)
            )
            time_last = int(time.time())
            while True:
                time_now = int(time.time())
                # '>=' so the action runs every intvl seconds, not every intvl + 1.
                if time_now >= time_last + self.timed_action_intvl:
                    self.log.debug("[{}] Executing timed function".format(time_now))
                    self.timed_action_func(*self.timed_action_args, **self.timed_action_kwargs)
                    time_last = time_now  # Reset previous time
                time.sleep(1)

        # Start sending and receiving threads
        tr = threading.Thread(target=receiver)
        ts = threading.Thread(target=sender)
        # If the main thread kills, this thread will be killed too.
        tr.daemon = True
        ts.daemon = True
        # Start threads
        tr.start()
        self.log.trace("Receiving thread started")
        ts.start()
        self.log.trace("Sending thread started")
        threads = [tr, ts]

        if hasattr(self, "timed_action_intvl"):
            tt = threading.Thread(target=timer)
            tt.daemon = True
            tt.start()
            self.log.trace("Timer thread started")
            threads.append(tt)

        # Cleanup the main program: poll (without burning a CPU core) until any
        # worker thread has finished, then shut down.
        while all(thread.is_alive() for thread in threads):
            time.sleep(0.1)
        sys.exit(0)


# -------------------------------------------------------------------------------------------------
# CLASS: NetcatModuleOutput (Module for: user-input -> send -> receive -> output)
# -------------------------------------------------------------------------------------------------
class NetcatModuleOutput(AbstractNetcatModule):
    """Chat module: forward typed lines to the peer and print what the peer sends."""

    linefeed = "\n"  # terminator appended to every line the user types

    # ------------------------------------------------------------------------------
    # Constructor / Destructor
    # ------------------------------------------------------------------------------
    def __init__(self, logger, encoder, options={}):
        """Keep logger/encoder; options['linefeed'] (if given) replaces the default terminator."""
        self.log = logger
        self.enc = encoder
        try:
            self.linefeed = options["linefeed"]
        except KeyError:
            pass

    # ------------------------------------------------------------------------------
    # Public Functions
    # ------------------------------------------------------------------------------
    def input_generator(self):
        """Yield every line read from input(), re-terminated with the configured linefeed."""
        # input() never returns None, so this iterates until input() raises (e.g. EOFError).
        for line in iter(input, None):
            yield line + self.linefeed

    def output_callback(self, data):
        """Print received data after removing at most one trailing CRLF, LF or CR."""
        # CRLF must be tested before its single-character parts.
        for ending in ("\r\n", "\n", "\r"):
            if data.endswith(ending):
                data = data[: -len(ending)]
                break
        print(data)


# -------------------------------------------------------------------------------------------------
# CLASS: NetcatModuleCommand (Module for user-input -> send -> execute -> send-back -> output)
# -------------------------------------------------------------------------------------------------
class NetcatModuleCommand(AbstractNetcatModule):
    """Command module: write received data to a local program and yield its output."""

    executable = None

    # ------------------------------------------------------------------------------
    # Constructor / Destructor
    # ------------------------------------------------------------------------------
    def __init__(self, logger, encoder, options={}):
        """Spawn options['executable'] with piped stdin/stdout (stderr merged into stdout)."""
        self.log = logger
        self.enc = encoder
        assert "executable" in options
        program = options["executable"]
        self.log.debug("Setting '{}' as executable".format(program))
        self.executable = program

        # Start the executable right away so it is ready to receive commands.
        self.p = Popen(
            self.executable,
            stdin=PIPE,
            stdout=PIPE,
            stderr=subprocess.STDOUT,
            shell=False,
            env=dict(os.environ),
        )
        # Make sure the child is killed when the interpreter exits.
        atexit.register(self.__exit__)

    def __exit__(self):
        """Destructor: kill the spawned executable."""
        proc = self.p
        self.log.trace("Killing executable: {} with pid {}".format(self.executable, proc.pid))
        proc.kill()

    # ------------------------------------------------------------------------------
    # Public Functions
    # ------------------------------------------------------------------------------
    def input_generator(self):
        """Yield the program's output line by line; stop when it produces no more output."""
        while True:
            self.log.trace("Reading command output")
            # Reading whole lines performs much better than reading byte by byte.
            line = self.enc.decode(self.p.stdout.readline())
            self.log.trace("Command output: {}".format(line))
            if not line:
                self.log.error("No program output received")
                return
            yield line

    def output_callback(self, data):
        """Write received data to the program's stdin and flush it immediately."""
        raw = self.enc.encode(data)
        self.log.trace("Appending to stdin: {}".format(raw))
        stdin = self.p.stdin
        stdin.write(raw)
        stdin.flush()


# -------------------------------------------------------------------------------------------------
# CLASS: NetcatServer
# -------------------------------------------------------------------------------------------------
class NetcatServer(AbstractSocket):
    """Netcat server: bind to host:port, then accept a TCP client or wait for UDP datagrams."""

    def __init__(self, logger, encoder, host, port, options={}):
        """Construct a listening server (blocks in accept() for TCP)."""
        super(NetcatServer, self).__init__(logger, encoder, "server", options)

        addr = self.gethostbyname(host)
        self.create_socket()
        self.bind(addr, port)

        if not self.options["udp"]:
            self.listen()
            self.log.info("Listening on {}:{} (TCP)".format(addr, port))
            self.accept()
            return

        # UDP has no accept(): the bound socket itself is the communication socket.
        self.conn = self.sock
        self.log.info("Waiting on {}:{} (UDP)".format(addr, port))


# -------------------------------------------------------------------------------------------------
# CLASS: NetcatClient
# -------------------------------------------------------------------------------------------------
class NetcatClient(AbstractSocket):
    """Netcat client: talk to host:port over TCP (connect) or UDP (sendto)."""

    def __init__(self, logger, encoder, host, port, options={}):
        """Construct a connecting client; exit with status 1 if the TCP connect fails."""
        super(NetcatClient, self).__init__(logger, encoder, "client", options)

        addr = self.gethostbyname(host)
        self.create_socket()
        # A client always communicates through the socket it created itself.
        self.conn = self.sock

        if not self.options["udp"]:
            self.remote_addr, self.remote_port = addr, port
            if not self.connect():
                sys.exit(1)
            return

        # UDP is connectionless: remember the target for sendto() in send().
        self.udp_client_addr, self.udp_client_port = addr, port


# -------------------------------------------------------------------------------------------------
# COMMAND LINE ARGUMENTS
# -------------------------------------------------------------------------------------------------


def get_version():
    """Return the one-line version banner (program, version, repository URL, author)."""
    fields = {"prog": APPNAME, "version": VERSION, "url": APPREPO, "author": "cytopia"}
    return "%(prog)s: Version %(version)s (%(url)s) by %(author)s" % fields


def _args_check_port(value):
    """Check arguments for invalid port number."""
    min_port = 1
    max_port = 65535
    intvalue = int(value)

    if intvalue < min_port or intvalue > max_port:
        raise argparse.ArgumentTypeError("%s is an invalid port number" % value)
    return intvalue


def _args_check_forwards(value):
    """Check forward argument (-L/-R) for correct pattern."""
    match = re.search(r"(.+):(.+)", value)
    if match is None or len(match.groups()) != 2:
        raise argparse.ArgumentTypeError("%s is not a valid 'addr:port' format." % value)
    _args_check_port(match.group(2))
    return value


def _args_check_reinit(value):
    """Check reinit argument for correct value."""
    intvalue = int(value)
    if intvalue < 0:
        raise argparse.ArgumentTypeError("must be equal or greater than 0." % value)
    return intvalue


def _args_check_reconn(value):
    """Check reconn argument for correct value."""
    intvalue = int(value)
    if intvalue < 0:
        raise argparse.ArgumentTypeError("must be equal or greater than 0." % value)
    return intvalue


def _args_check_robin_ports(value):
    """Check reinit-robin argument for comma separated string or range.

    Accepted forms are a comma separated list ("80,81,82") or an inclusive range
    ("80-100"). Every port is validated with _args_check_port. Both forms return
    ints; previously the comma form returned strings while the range form
    returned ints.

    :param value: raw command-line string
    :returns: list of ports as ints
    :raises argparse.ArgumentTypeError: on malformed input, an invalid port, or a
        range whose left side is greater than its right side
    """
    mcomma = re.search(r"^[0-9]+(,([0-9]+))*$", value)
    if mcomma:
        return [_args_check_port(port) for port in mcomma.group(0).split(",")]

    mrange = re.search(r"^([0-9]+)-([0-9]+)$", value)
    if mrange:
        first, last = int(mrange.group(1)), int(mrange.group(2))
        if first > last:
            raise argparse.ArgumentTypeError(
                "Left side of range must be smaller or equal than right side."
            )
        return [_args_check_port(port) for port in range(first, last + 1)]

    raise argparse.ArgumentTypeError("%s is not a valid port specifier" % value)


def _args_check_mutually_exclusive(parser, args):
    """Check mutually exclusive arguments."""

    # [MODE] --listen
    if args.listen and (args.zero or args.local or args.remote):
        parser.print_usage()
        print(
            "%s: error: -l/--listen mutually exclusive with -z/-zero, -L/--local or -R/--remote"
            % (APPNAME),
            file=sys.stderr,
        )
        sys.exit(1)

    # [MODE] --zero
    if args.zero and (args.listen or args.local or args.remote):
        parser.print_usage()
        print(
            "%s: error: -z/--zero mutually exclusive with -l/--listen, -L/--local or -R/--remote"
            % (APPNAME),
            file=sys.stderr,
        )
        sys.exit(1)

    # [MODE --local
    if args.local and (args.listen or args.zero or args.remote):
        parser.print_usage()
        print(
            "%s: error: -L/--local mutually exclusive with -l/--listen, -z/--zero or -R/--remote"
            % (APPNAME),
            file=sys.stderr,
        )
        sys.exit(1)

    # [MODE] --remote
    if args.remote and (args.listen or args.zero or args.local):
        parser.print_usage()
        print(
            "%s: error: -R/--remote mutually exclusive with -l/--listen, -z/--zero or -L/--local"
            % (APPNAME),
            file=sys.stderr,
        )
        sys.exit(1)

    # [MODULE] --exec
    if args.cmd and (args.local or args.remote or args.zero):
        parser.print_usage()
        print(
            "%s: error: -e/--exec mutually exclusive with -L/--local, -R/-remote or -z/--zero"
            % (APPNAME),
            file=sys.stderr,
        )
        sys.exit(1)

    # [OPTIONS] --udp
    if args.udp and args.zero:
        parser.print_usage()
        print(
            "%s: error: -u/--udp mutually exclusive with -z/--zero" % (APPNAME), file=sys.stderr,
        )
        sys.exit(1)

    # [ADVANCED] --reinit
    if args.reinit is not False and (args.udp or args.local or args.remote):
        parser.print_usage()
        print(
            "%s: error: --reinit mutually exclusive with -u/--udp, -L/--local or -R/--remote"
            % (APPNAME),
            file=sys.stderr,
        )
        sys.exit(1)

    # [ADVANCED] --reconn
    if args.reconn is not False and (args.udp or args.local or args.remote or args.zero):
        parser.print_usage()
        print(
            "%s: error: --reinit mutually excl. with -u/--udp, -L/--local -R/--remote or -z/--zero"
            % (APPNAME),
            file=sys.stderr,
        )
        sys.exit(1)

    # [ADVANCED] --udp-ping-interval
    if args.udp_ping_intvl and not args.udp:
        parser.print_usage()
        print(
            "%s: error: --udp-ping-intvl mutually exclusive with -u/--udp" % (APPNAME),
            file=sys.stderr,
        )
        sys.exit(1)


def get_args():
    """Retrieve command line arguments.

    Build the argument parser, parse sys.argv, validate option combinations via
    _args_check_mutually_exclusive, and exit with status 1 (usage printed) when
    a not yet implemented feature is requested.

    :returns: the parsed argparse.Namespace
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        add_help=False,
        usage="""%(prog)s [-Cnuv] [-e cmd] hostname port
       %(prog)s [-Cnuv] [-e cmd] -l [hostname] port
       %(prog)s [-Cnuv] -z hostname port
       %(prog)s [-Cnuv] -L addr:port [hostname] port
       %(prog)s [-Cnuv] -R addr:port hostname port
       %(prog)s -V, --version
       %(prog)s -h, --help
       """
        % ({"prog": APPNAME}),
        description="""
Enhanced and compatible Netcat implementation written in Python (2 and 3) with
connect, zero-i/o, listen and forward modes and techniques to detect and evade
firewalls and intrusion prevention systems.""",
    )

    positional = parser.add_argument_group("positional arguments")
    mode = parser.add_argument_group("mode arguments")
    optional = parser.add_argument_group("optional arguments")
    advanced = parser.add_argument_group("advanced arguments")
    misc = parser.add_argument_group("misc arguments")

    positional.add_argument(
        "hostname", nargs="?", type=str, help="Address to listen, forward or connect to"
    )
    positional.add_argument(
        "port", type=_args_check_port, help="Port to listen, forward or connect to"
    )

    mode.add_argument(
        "-l",
        "--listen",
        action="store_true",
        default=False,
        help="""[Listen mode]:
Start server and listen for incoming connections.

""",
    )
    mode.add_argument(
        "-z",
        "--zero",
        action="store_true",
        default=False,
        help="""[Zero-I/O mode]:
Connect to a remote endpoint and report status only.

""",
    )
    # The help text mirrors main(): a server is bound to the -L addr:port and
    # every connection is forwarded by a client to hostname/port.
    mode.add_argument(
        "-L",
        "--local",
        metavar="addr:port",
        default=False,
        type=_args_check_forwards,
        help="""[Local forward mode]:
Specify local <addr>:<port> on which %(prog)s will
listen for incoming connections. All traffic received
there is forwarded to the remote endpoint specified
by hostname and port.

"""
        % ({"prog": APPNAME}),
    )
    mode.add_argument(
        "-R",
        "--remote",
        metavar="addr:port",
        default=False,
        type=_args_check_forwards,
        help="""[Remote forward mode]:
Specify local <addr>:<port> from which traffic should be
forwarded. %(prog)s will connect remotely
(specified by hostname and port) and forward all
traffic from the specified value for -R/--remote.
"""
        % ({"prog": APPNAME}),
    )

    optional.add_argument(
        "-e",
        "--exec",
        metavar="cmd",
        dest="cmd",
        default=False,
        type=str,
        help="Execute shell command. Only for connect or listen mode.",
    )
    optional.add_argument(
        "-C",
        "--crlf",
        action="store_true",
        default=False,
        help="Send CRLF line-endings in connect mode (default: LF)",
    )
    optional.add_argument(
        "-n", "--nodns", action="store_true", default=False, help="Do not resolve DNS.",
    )
    optional.add_argument("-u", "--udp", action="store_true", default=False, help="UDP mode")
    optional.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help="""Be verbose and print info to stderr. Use -v, -vv, -vvv
or -vvvv for more verbosity. The server performance will
decrease drastically if you use more than three -v.""",
    )
    # --reinit / --reconn default to False (not 0) so that an explicit 0
    # ("endlessly") can be told apart from "not given".
    advanced.add_argument(
        "--reinit",
        metavar="x",
        default=False,
        type=_args_check_reinit,
        help="""Listen mode (TCP only):
If the server is unable to bind or accept clients, it
will re-initialize itself x many times before giving up.
Use 0 to re-initialize endlessly. (default: don't).

Connect mode (TCP only):
If the client is unable to connect to a remote endpoint,
it will try again x many times before giving up.
Use 0 to retry endlessly. (default: don't)

Zero-I/O mode (TCP only):
Same as connect mode.

""",
    )
    advanced.add_argument(
        "--reconn",
        metavar="x",
        default=False,
        type=_args_check_reconn,
        help="""Listen mode (TCP only):
If the client has hung up, the server will re-accept a
new client x many times before quitting. Use 0 to accept
endlessly. (default: quit after a client has hung up)

Connect mode (TCP only):
If the remote server is gone, the client will re-connect
to it x many times before giving up. Use 0 to reconnect
endlessly. (default: don't)
This might be handy for reverse shells ;-)

""",
    )
    advanced.add_argument(
        "--reinit-robin",
        metavar="port",
        default=[],
        type=_args_check_robin_ports,
        help="""Connect mode (TCP only):
If the client does multiple initial connections to a
remote endpoint (via --reinit), this option instructs it
to also "round-robin" different ports to connect to. It
will stop iterating after first successful connection
and stick with it or quit if --reinit limit is reached.
Use comma separated string: 80,81,82 or a range 80-100.
Set --reinit to at least the number of ports to probe +1
Set --reinit-wait to 0
This helps to evade EGRESS firewalls for reverse shells
Use with -z/--zero to probe outbound allowed ports.
Ensure to have enough listeners at the remote endpoint.

""",
    )
    advanced.add_argument(
        "--reconn-robin",
        metavar="port",
        default=[],
        type=_args_check_robin_ports,
        help="""Connect mode (TCP only):
If the remote endpoint is gone after initial successful
connection, and the client is set to reconnect with
(--reconn), it will connect back by "round-robin" to
different ports. It will stop after --reconn limit has
been reached.
Set --reconn to at least the number of ports to probe +1
Set --reconn-wait to 0
This helps your reverse shell to evade intrusion
prevention systems that will cut your connection and
block the outbound port.

""",
    )
    advanced.add_argument(
        "--reinit-wait",
        metavar="s",
        default=1,
        type=int,
        help="Wait x seconds between re-inits. (default: 1)\n\n",
    )
    advanced.add_argument(
        "--reconn-wait",
        metavar="s",
        default=1,
        type=int,
        help="Wait x seconds between re-connects. (default: 1)\n\n",
    )
    advanced.add_argument(
        "--udp-ping-intvl",
        metavar="s",
        default=False,
        type=int,
        help="""Connect mode (UDP only):
As UDP is stateless, a client must first connect to a
server before the server can communicate with it.
If you listen on UDP and wait for a reverse UDP client
or reverse UDP shell, you can only talk to it after it
has sent you some initial data, as UDP does not have a
"connect" state like TCP.
This option instructs the UDP client to send a single
null byte every s seconds. By not only doing it once,
but in intervals, you can also maintain a connection
if you restart your listening server.
""",
    )
    misc.add_argument("-h", "--help", action="help", help="Show this help message and exit")
    misc.add_argument(
        "-V",
        "--version",
        action="version",
        version=get_version(),
        help="Show version information and exit",
    )

    # Retrieve arguments
    args = parser.parse_args()

    # Check mutually exclusive arguments
    _args_check_mutually_exclusive(parser, args)

    # TODO: Exit on unimplemented features (only -R/--remote; -L/--local is implemented)
    if args.remote:
        parser.print_usage()
        print(
            "%s: error: -R/--remote is not yet implemented" % (APPNAME), file=sys.stderr,
        )
        sys.exit(1)

    return args


# -------------------------------------------------------------------------------------------------
# MAIN ENTRYPOINT
# -------------------------------------------------------------------------------------------------
def main():
    """Run the program.

    Parse the command line, assemble the network options, logger, encoder and
    I/O module, then run exactly one mode: local port-forward (-L/--local),
    listen (-l/--listen) or connect (default).
    """
    args = get_args()
    host = args.hostname if args.hostname is not None else "0.0.0.0"
    port = args.port

    # Set netcat options.
    # reinit/reconn: an explicit 0 means "endlessly" and is mapped to True. type()
    # is used instead of isinstance() on purpose: both default to False, and since
    # bool subclasses int, isinstance(False, int) and False == 0 would both hold
    # and turn the default into "endless".
    net_opts = {
        "udp": args.udp,
        "bufsize": 1024,
        "backlog": 0,
        "nodns": args.nodns,
        "reinit": True if (type(args.reinit) is int and args.reinit == 0) else args.reinit,
        "reconn": True if (type(args.reconn) is int and args.reconn == 0) else args.reconn,
        "reinit_robin": args.reinit_robin,
        "reconn_robin": args.reconn_robin,
        "reinit_wait": args.reinit_wait,
        "reconn_wait": args.reconn_wait,
        "udp_ping_intvl": args.udp_ping_intvl,  # only for udp client and only for rev-shell (0:off)
    }
    # Initialize logger
    logger = Logger(args.verbose)
    # Initialize encoder
    encoder = StringEncoder(logger)
    # Use command module (-e/--exec)
    if args.cmd:
        module_opts = {"executable": args.cmd}
        mod = NetcatModuleCommand(logger, encoder, module_opts)
    # Use output module
    else:
        module_opts = {"linefeed": "\r\n" if args.crlf else "\n"}
        mod = NetcatModuleOutput(logger, encoder, module_opts)

    # The modes below are mutually exclusive (enforced by _args_check_mutually_exclusive),
    # so they form one if/elif/else chain: once a local forward ends, the program
    # must not fall through into connect mode.
    # NOTE(review): -z/--zero has no branch of its own and is not put into net_opts;
    # it runs through the connect branch. Confirm zero-I/O is handled elsewhere.

    # Run local port-forward
    # -> listen on the -L addr:port and forward traffic to hostname/port (connect)
    if args.local:
        # Create listen and client instances
        lhost = args.local.split(":")[0]
        lport = int(args.local.split(":")[1])
        net_srv = NetcatServer(logger, encoder, lhost, lport, net_opts)
        net_cli = NetcatClient(logger, encoder, host, port, net_opts)

        # Create Runner (the set_* funcs below are brainfuck and took me 1 hour to figure it out)
        run = Runner(logger)

        # [srv] User-Client connects here, sends data and the Server takes it as input
        run.set_recv_generator(net_srv.receive)
        # [cli] Runner parses data from Server on to Proxy-Client, which sends/connect it further
        run.set_output_callback(net_cli.send)
        # [cli] Proxy-Client waits for response and receives data back
        run.set_input_generator(net_cli.receive)
        # [srv] Runner parses data from Proxy-Client onto Server, which sends/back to User-Client
        run.set_send_callback(net_srv.send)

        # And finally run
        run.run()
    # Run server
    elif args.listen:
        net = NetcatServer(logger, encoder, host, port, net_opts)
        run = Runner(logger)
        run.set_recv_generator(net.receive)
        run.set_input_generator(mod.input_generator)
        run.set_send_callback(net.send)
        run.set_output_callback(mod.output_callback)
        run.run()
    # Run client
    else:
        net = NetcatClient(logger, encoder, host, port, net_opts)
        run = Runner(logger)
        run.set_recv_generator(net.receive)
        run.set_input_generator(mod.input_generator)
        run.set_send_callback(net.send)
        run.set_output_callback(mod.output_callback)
        # UDP keep-alive: send a single null byte every udp_ping_intvl seconds.
        if type(args.udp_ping_intvl) is int and args.udp_ping_intvl > 0:
            run.set_time_action(args.udp_ping_intvl, net.send, "\x00")
        run.run()


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C: end the current terminal line and leave with status 1,
        # without printing a traceback.
        print()
        raise SystemExit(1)
