#!/usr/bin/env python3

"""
<sphinx>
ssh-command
-----------

Calls a remote process over SSH and verifies its exit code and the keywords present in (or absent from) its output

Parameters:

- user (default: root)
- host
- port (default: 22)
- private_key
- password
- ssh_bin (default: ssh)
- sshpass_bin (default: sshpass)
- ssh_opts (example: -o StrictHostKeyChecking=no)
- known_hosts_file (default: ~/.ssh/known_hosts)
- command (default: uname -a)
- timeout (default: 15, unit: seconds)
- expected_keywords (Keywords expected to be in stdout/stderr. Separated by ";")
- unexpected_keywords (Keywords not expected to be present in stdout/stderr. Separated by ";")
- expected_exit_code (default: 0)
</sphinx>
"""


from typing import List
import os
import sys
import inspect


# Put the directory two levels above this script at the front of sys.path;
# presumably this lets the `infracheck` package import below resolve when the
# check is executed directly as a standalone script.
_script_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
path = _script_dir + '/../../'
sys.path.insert(0, path)

from infracheck.infracheck.checklib.ssh import SSH


class SshCommand(SSH):
    """Check that runs a command over SSH and validates its exit code and output."""

    def main(self, command: str, expected_keywords: List[str], expected_exit_code: int,
             unexpected_keywords: List[str]) -> tuple:
        """
        Execute ``command`` through the inherited ``execute`` and verify the result.

        :param command: command line handed to ``self.execute``
        :param expected_keywords: keywords that must all be present in the output;
                                  empty entries are ignored
        :param expected_exit_code: exit code the command is required to return
        :param unexpected_keywords: keywords that must not be present in the output;
                                    empty entries are ignored
        :return: ``(True, "OK")`` when every condition holds, otherwise
                 ``(False, <reason>)`` describing the first failed condition
        """
        out, exit_code = self.execute(command)

        # Convert once: both the comparison and the "%i" placeholder below need
        # an integer, and "%i" raises TypeError when given a string.
        exit_code = int(exit_code)

        # Only the last 256 characters of the output are echoed in messages.
        truncated_output = (str(out)[-256:]).strip()

        if exit_code != expected_exit_code:
            return False, "Unexpected exit code %i with output: %s" % (exit_code, truncated_output)

        for keyword in expected_keywords:
            # Empty keywords (e.g. produced by splitting an empty string) are skipped.
            if keyword and keyword not in out:
                return False, "Expected keyword '%s' to be present in output. Got: %s" % (
                    keyword, truncated_output
                )

        for keyword in unexpected_keywords:
            if keyword and keyword in out:
                return False, "Unexpected keyword '%s'. Should not be present in the output. Got: %s" % (
                    keyword, truncated_output
                )

        return True, "OK"


if __name__ == '__main__':
    # Script entry point: the check is configured entirely through environment
    # variables and reports its result via the message on stdout and the
    # process exit status (0 = check passed, 1 = check failed or misconfigured).
    if not os.getenv('HOST'):
        print("HOST is mandatory")
        # Use sys.exit rather than the exit() builtin: the latter is injected by
        # the `site` module and is not guaranteed to exist (e.g. `python -S`).
        sys.exit(1)

    app = SshCommand(
        user=os.getenv('USER', 'root'),
        host=os.getenv('HOST'),
        port=int(os.getenv('PORT', 22)),
        password=os.getenv('PASSWORD', ''),
        private_key=os.getenv('PRIVATE_KEY', ''),
        ssh_bin=os.getenv('SSH_BIN', 'ssh'),
        sshpass_bin=os.getenv('SSHPASS_BIN', 'sshpass'),
        ssh_opts=os.getenv('SSH_OPTS', ''),
        timeout=int(os.getenv('TIMEOUT', 15)),
        known_hosts_file=os.getenv('KNOWN_HOSTS_FILE', os.path.expanduser('~/.ssh/known_hosts'))
    )

    # Keyword lists are passed as ";"-separated strings; an unset variable
    # yields [''], which SshCommand.main ignores because it skips empty keywords.
    status, message = app.main(
        command=os.getenv('COMMAND', 'uname -a'),
        expected_keywords=os.getenv('EXPECTED_KEYWORDS', '').split(';'),
        expected_exit_code=int(os.getenv('EXPECTED_EXIT_CODE', 0)),
        unexpected_keywords=os.getenv('UNEXPECTED_KEYWORDS', '').split(';')
    )

    print(message)
    sys.exit(0 if status else 1)
