#!/usr/bin/python3
# Copyright (C) 2019 Evgeny Golyshev <eugulixes@gmail.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""BSC server. BSC stands for Build Status Codes.
The server is very simple and supposed to handle one client at a time.
"""

import logging
import os
import socket
import sys
from argparse import ArgumentParser

from daemon import DaemonContext, pidfile
from redis.exceptions import RedisError

from pieman import codes, util

# Maximum number of bytes read from the client in a single recv() call.
BUFSIZE = 4 * 1024

# Daemon name; also used to build the default socket and pid file names.
DAEMON_NAME = 'bscd'

LOGGER = logging.getLogger(DAEMON_NAME)


def test_write_perm(dir_name):
    """Tests if the current process has enough permissions to create files
    in the specified directory.

    Creating a directory entry requires both write and search (execute)
    permission on the directory, so both are checked.

    :param dir_name: directory to check. An empty value (which is what
        ``os.path.dirname`` returns for a bare file name) means the current
        working directory.

    If the check fails, a fatal message is logged and the process exits with
    status 1.
    """

    dir_name = dir_name or './'
    if not os.access(dir_name, os.W_OK | os.X_OK):
        LOGGER.fatal('%s is not writable for the process. Make sure that the '
                     'directory exists and the process has enough permissions '
                     'to create files there.', dir_name)
        sys.exit(1)


# The helper's name matches pytest's default ``test_*`` collection pattern;
# opt it out so test runners importing this module never collect it.
test_write_perm.__test__ = False


class BSCServer:
    """Class implementing the build status codes server.

    Listens on a Unix domain socket, serves one client at a time and
    publishes every known status code it receives to a Redis channel. The
    special ``EXIT`` request shuts the server down.
    """

    def __init__(self, unix_socket_name, channel_name, redis_host='127.0.0.1',
                 redis_port=6379):
        """Prepares the server; the socket itself is opened by run().

        :param unix_socket_name: path of the Unix domain socket to listen on.
        :param channel_name: Redis channel the status codes are published to.
        :param redis_host: Redis server host.
        :param redis_port: Redis server port.

        Whatever ``util.connect_to_redis`` raises is propagated.
        """
        self._channel_name = channel_name
        self._conn = None  # socket of the currently connected client
        self._redis_conn = util.connect_to_redis(redis_host, redis_port)
        self._server = None  # listening socket; set by run()
        self._unix_socket_name = unix_socket_name

    #
    # Private methods
    #

    def _accept(self):
        """Blocks until a client connects and makes it the current client. """

        self._conn, _ = self._server.accept()

        LOGGER.debug('Client connected')

    def _close_conn(self):
        """Closes the connection to the current client, if there is one. """

        if self._conn is not None:
            self._conn.close()
            self._conn = None

    def _close_server(self):
        """Closes the client connection and the listening socket and removes
        the Unix domain socket file.

        Safe to call more than once; calls after the first are no-ops.
        """

        if self._server is None:
            return

        self._close_conn()
        self._server.close()
        self._server = None

        try:
            os.unlink(self._unix_socket_name)
        except FileNotFoundError:
            pass

    def _run_loop(self):
        """Runs the loop which handles requests from the client.

        A request equal to one of the ``*_CODE`` values from ``codes`` is
        published to the Redis channel, ``b'EXIT'`` shuts the server down and
        anything else is logged and ignored. The loop only ends by an
        exception (including the SystemExit raised by exit_server).
        """

        handlers = {getattr(codes, name): self.send_code_to_channel
                    for name in dir(codes) if name.endswith('_CODE')}
        handlers[b'EXIT'] = self.exit_server

        self._accept()

        while True:
            try:
                request = self._conn.recv(BUFSIZE)
            except ConnectionResetError:
                # The client vanished abruptly; handle it like a disconnect.
                request = b''

            if not request:
                LOGGER.debug('Client disconnected')

                # Release the old connection before waiting for the next
                # client, otherwise every reconnect leaks a descriptor.
                self._close_conn()
                self._accept()

                # Not to pass an empty request further, skip the current
                # iteration
                continue

            # NOTE(review): recv() on a stream socket may return several codes
            # at once or only part of one; this assumes the client writes one
            # code at a time and reads line up with writes — confirm.
            handler = handlers.get(request)
            if handler is None:
                # The server keeps running, so this is an error, not fatal.
                LOGGER.error('Unknown request %s', request)
                continue

            handler(request)

    #
    # Public methods
    #

    def exit_server(self, _code):
        """Shuts down the server in response to the EXIT request.

        :param _code: the request that triggered the call (unused).
        """

        LOGGER.info('Shutting down the server since the EXIT request was sent')
        self._close_server()
        sys.exit(0)

    def run(self):
        """Runs the server.

        The Unix domain socket file is removed on every exit path: the EXIT
        request, Ctrl-C (exit status 1) or an unexpected exception.
        """

        server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            server.bind(self._unix_socket_name)
        except OSError:
            # bind() did not create the socket file, so there is nothing of
            # ours to unlink; only release the descriptor.
            server.close()
            raise

        self._server = server

        try:
            self._server.listen(1)
            self._run_loop()
        except KeyboardInterrupt:
            LOGGER.info('Shutting down the server since the signal was '
                        'generated by Ctrl-C')
            sys.exit(1)
        finally:
            self._close_server()

    def send_code_to_channel(self, code):
        """Sends the specified build code to the Redis channel. """
        LOGGER.info('Handling %s', code)
        self._redis_conn.publish(self._channel_name, code)


def main():
    """The main entry point.

    Parses the command line, checks that the socket, log and pid file
    locations are usable, creates the server and runs it either in the
    foreground or, with ``--daemonize``, as a daemon. Exits with status 1 if
    a check fails or the server cannot be created because of a RedisError.
    """

    parser = ArgumentParser()
    parser.add_argument('-a', '--unix-socket-name',
                        default='/var/run/{}.sock'.format(DAEMON_NAME),
                        help='Unix domain socket file name',
                        metavar='UNIX_SOCKET_NAME')
    parser.add_argument('-d', '--daemonize', action='store_true',
                        help='daemonize the server')
    parser.add_argument('-H', '--redis-host', default='127.0.0.1',
                        help='Redis server host', metavar='REDIS_HOST')
    parser.add_argument('-l', '--log-file-prefix', default='',
                        help='path prefix for log files',
                        metavar='LOG_FILE_PREFIX')
    parser.add_argument('-n', '--channel-name', default='build_status_codes',
                        help='channel which is used for publishing build '
                             'status codes to',
                        metavar='CHANNEL_NAME')
    parser.add_argument('-p', '--pid',
                        default='/var/run/{}.pid'.format(DAEMON_NAME),
                        help='server pid file name',
                        metavar='PID_FILE_NAME')
    parser.add_argument('-P', '--redis-port', default=6379,
                        help='Redis server port', metavar='REDIS_PORT',
                        type=int)
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='turn on verbose mode')

    args = parser.parse_args()

    util.init_logger(LOGGER, logging.DEBUG if args.verbose else logging.INFO,
                     args.log_file_prefix)

    # Any existing file at the socket path is refused, including a socket
    # left behind by a server that did not shut down cleanly.
    if os.path.exists(args.unix_socket_name):
        LOGGER.fatal('Address already in use')
        sys.exit(1)

    test_write_perm(os.path.dirname(args.unix_socket_name))

    if args.log_file_prefix:
        test_write_perm(os.path.dirname(args.log_file_prefix))
        LOGGER.debug('All log records will be forwarded to %s',
                     args.log_file_prefix)
    else:
        LOGGER.debug('No log file will be created')

    try:
        bscd = BSCServer(args.unix_socket_name, args.channel_name,
                         redis_host=args.redis_host,
                         redis_port=args.redis_port)
    except RedisError:
        LOGGER.fatal('Could not connect to Redis')
        sys.exit(1)

    if not args.daemonize:
        bscd.run()
        return

    test_write_perm(os.path.dirname(args.pid))

    LOGGER.debug('Server will be daemonized')

    # Preserve the fds of the logger's stream handlers; DaemonContext closes
    # the open descriptors that are not listed in files_preserve.
    preserved_fds = [
        handler.stream.fileno()  # pylint: disable=no-member
        for handler in LOGGER.handlers
        if hasattr(handler, 'stream')
    ]

    # NOTE(review): bscd already holds the Redis client created above. If
    # util.connect_to_redis opens its socket eagerly, DaemonContext closes
    # that descriptor as well — confirm the client (re)connects lazily.
    with DaemonContext(files_preserve=preserved_fds,
                       pidfile=pidfile.TimeoutPIDLockFile(args.pid),
                       working_directory=os.getcwd()):
        bscd.run()


if __name__ == '__main__':
    # Start the server only when the file is executed as a script, not when
    # it is imported as a module.
    main()
