# Copyright DST Group. Licensed under the MIT license.
from datetime import datetime
from ipaddress import IPv4Address


# Uses the Metasploit module auxiliary/scanner/ssh/ssh_login.
# Set RHOSTS to the target; optionally set VERBOSE to false. RPORT can also be set, but defaults to 22.
# Set USERPASS_FILE (the author used "/usr/share/metasploit_framework/data/wordlists/root_userpass.txt", after first
# appending the credentials "user user" as the last line of that file).
# NOTE: emu_execute below passes USERPASS_FILE='/usr/share/wordlists/top100_userpass_msf.txt' instead.
# This yields a non-TTY session that must first be interacted with via "sessions -i 1" (the session ID may not be 1).
from time import strptime

from csle_cyborg.shared.actions.msf_actions_folder.msf_action import lo, lo_subnet
from csle_cyborg.shared.actions.msf_actions_folder.remote_code_execution_folder.remote_code_execution import RemoteCodeExecution
from csle_cyborg.simulator.session import SessionType
from csle_cyborg.shared.enums import ProcessType
from csle_cyborg.shared.enums import OperatingSystemType
from csle_cyborg.shared.observation import Observation
from csle_cyborg.simulator.state import State


class SSHLoginExploit(RemoteCodeExecution):
    """Brute-force SSH login exploit that opens an "msf shell" session on the target host."""

    def __init__(self, ip_address: IPv4Address, agent: str, session: int, port: int):
        """
        :param ip_address: address of the host to attack.
        :param agent: name of the agent performing the action.
        :param session: id of the agent's session used to launch the attack
            (``sim_execute`` requires it to be an active MSF server session).
        :param port: port on which the target's SSH service is expected to listen.
        """
        super().__init__(session=session, agent=agent)
        self.target = ip_address
        self.port = port

    def sim_execute(self, state: State) -> Observation:
        """Simulate a brute-force SSH login against ``self.target``.

        The action succeeds only if all of the following hold:

        - the acting session exists and is an active MSF server session;
        - the target is reachable from it, either over loopback or through a
          subnet whose NACL permits ``self.port``;
        - the target host runs an SSH process with a connection on
          ``self.port`` bound to 0.0.0.0 or to the target address;
        - the target host has at least one bruteforceable user.

        On success a new "msf shell" session is added to ``state`` for the last
        bruteforceable user found. Matching process and connection records are
        added on both the target host and the server host.

        :param state: simulator state; mutated on success.
        :return: an Observation whose success flag reflects the outcome.
        """
        obs = Observation()
        obs.set_success(False)
        if self.session not in state.sessions[self.agent]:
            return obs
        session = state.sessions[self.agent][self.session]

        if session.session_type != SessionType.MSF_SERVER or not session.active:
            return obs

        if self.target == lo:
            # Loopback attack: the acting session is itself the server side of the connection.
            server_session = session
            server_interface = next(
                (i for i in state.hosts[session.host].interfaces if i.ip_address == lo), None)
            if server_interface is None:
                return obs
        else:
            target_subnet = None
            for subnet in state.subnets.values():
                if self.target in subnet.ip_addresses:
                    target_subnet = subnet
                    break

            server_session, server_interface = self.get_local_source_interface(
                local_session=session, remote_address=self.target, state=state)
            if server_interface is None:
                return obs

            if not self.test_nacl(port=self.port, target_subnet=target_subnet,
                                  originating_subnet=state.subnets[server_interface.subnet]):
                return obs

        server_address = server_interface.ip_address
        if server_address is None:
            return obs

        if self.target == lo:
            target_host = state.hosts[session.host]
        else:
            target_host = state.hosts[state.ip_addresses[self.target]]

        # Find the SSH process with a connection on self.port. Within the first SSH
        # process that has any match, the last matching connection is kept.
        # NOTE(review): a non-SSH process listening on this port is not reported in the observation.
        ssh_proc = None
        listen_conn = None
        for proc in target_host.processes:
            if proc.process_type != ProcessType.SSH:
                continue
            for conn in proc.connections:
                if conn['local_port'] == self.port:
                    ssh_proc = proc
                    listen_conn = conn
            if ssh_proc is not None:
                break

        if ssh_proc is None or listen_conn["local_address"] not in (IPv4Address("0.0.0.0"), self.target):
            return obs

        bruteforceable = [u for u in target_host.users if u.bruteforceable]
        if not bruteforceable:
            return obs
        # As in the original loop, the last bruteforceable user on the host is the one used.
        user = bruteforceable[-1]

        obs.set_success(True)
        obs.add_process(hostid=str(self.target), local_address=self.target, local_port=self.port, status="open",
                        process_type="ssh", app_protocol='ssh')

        # Spawn an sshd child of the listening SSH process, then re-label the new
        # session's process as a bash shell running under that sshd.
        user_ssh = target_host.add_process(name="sshd", ppid=ssh_proc.pid, path=ssh_proc.path,
                                           user=user, process_type="ssh")

        new_session = state.add_session(host=target_host.hostname, agent=self.agent,
                                        user=user.username, session_type="msf shell",
                                        parent=server_session.ident)
        process = target_host.get_process(new_session.pid)
        process.name = "bash"
        process.ppid = user_ssh.pid
        process.path = "/bin/"
        process.user = user

        server_host = state.hosts[server_session.host]
        remote_port = server_host.get_ephemeral_port()
        # Target-side view of the new connection.
        process.connections.append({"local_port": self.port,
                                    "Application Protocol": "tcp",
                                    "remote_address": server_address,
                                    "remote_port": remote_port,
                                    "local_address": self.target})
        # Server-side view of the same connection.
        server_host.get_process(server_session.pid).connections.append({
            'local_port': remote_port,
            "Application Protocol": "ssh",
            "local_address": server_address,
            "remote_address": self.target,
            "remote_port": self.port,
        })

        if session != server_session:
            # The ephemeral port is omitted from the observation when the connection
            # originates from a session other than the acting one.
            remote_port = None
        obs.add_process(hostid=str(server_address), local_address=server_address, remote_address=str(self.target),
                        local_port=remote_port, remote_port=self.port)
        obs.add_process(hostid=str(self.target), local_address=str(self.target), remote_address=server_address,
                        local_port=self.port, remote_port=remote_port)
        obs.add_session_info(hostid=str(self.target), username=user.username, session_id=new_session.ident,
                             session_type="msf shell", agent=self.agent)
        if target_host.os_type == OperatingSystemType.LINUX:
            obs.add_user_info(hostid=str(self.target), username=user.username, password=user.password,
                              uid=user.uid)
            obs.add_system_info(hostid=str(self.target), hostname=target_host.hostname,
                                architecture=target_host.architecture, os_kernel=target_host.kernel,
                                os_type=target_host.os_type, os_distribution=target_host.distribution)
        else:
            obs.add_user_info(hostid=str(self.target), username=user.username, password=user.password)
        return obs

    def emu_execute(self, session_handler) -> Observation:
        """Run Metasploit's ``auxiliary/scanner/ssh/ssh_login`` against the target and parse its output.

        Only an ``MSFSessionHandler`` (exact type) is supported; any other
        handler yields an unsuccessful observation.

        - ``[+] ... Success: ...`` lines contribute user and SSH-process
          information, plus system information when ``id`` output is present.
        - ``[*] Command shell session ...`` lines mark the action as successful
          and contribute session and connection information.

        :param session_handler: handler used to execute the Metasploit module.
        :return: the parsed Observation; the raw module output is attached as well.
        :raises Exception: re-raised unchanged if the output cannot be parsed,
            after logging the output via ``session_handler._log_debug``.
        """
        obs = Observation()
        stop_on_success = True  # for the speed up during testing
        # Imported here rather than at module level, matching the original code.
        from csle_cyborg.CybORG import MSFSessionHandler
        if type(session_handler) is not MSFSessionHandler:
            obs.set_success(False)
            return obs
        output = session_handler.execute_module(mtype='auxiliary', mname='scanner/ssh/ssh_login',
                                                opts={'RHOSTS': str(self.target),
                                                      "USERPASS_FILE": '/usr/share/wordlists/top100_userpass_msf.txt',
                                                      "STOP_ON_SUCCESS": stop_on_success})
        obs.add_raw_obs(output)
        obs.set_success(False)
        username = None  # cheat to allow us to know the username of the new session
        try:
            for line in output.split('\n'):
                # Example success line: "[+] 10.0.2.164:22 - Success: 'pi:raspberry' 'uid=1001(pi) gid=1001(pi) groups=1001(pi) Linux pretend-pi 4.15.0-1057-aws #59-Ubuntu SMP Wed Dec 4 10:02:00 UTC 2019 x86_64 x86_64 x86_64 GNU/Linux '"
                # another example: [+] Success: 'pi:raspberry' 'uid=1001(pi) gid=1001(pi) groups=1001(pi) Linux ip-10-0-31-50 4.15.0-1057-aws #59-Ubuntu SMP Wed Dec 4 10:02:00 UTC 2019 x86_64 x86_64 x86_64 GNU/Linux '
                # another example: [+] Success: 'pi:raspberry' 'Could not chdir to home directory /home/pi: No such file or directory uid=1001(pi) gid=1001(pi) groups=1001(pi) Could not chdir to home directory /home/pi: No such file or directory Linux ip-10-0-10-199 4.15.0-1057-aws #59-Ubuntu SMP Wed Dec 4 10:02:00 UTC 2019 x86_64 x86_64 x86_64 GNU/Linux '
                if '[+]' in line:
                    if ':' in line.split('Success: ')[0]:
                        ip_address, port = line.split('Success: ')[0].split(' ')[1].split(':')
                    else:
                        ip_address = self.target
                        port = self.port
                    line = line.split('Success: ')[-1]
                    split = line.split(' ')
                    if 'Could not chdir to home directory' in line:
                        # Drop both "Could not chdir ..." warnings so the id/uname fields
                        # land at the same indices as in the normal output layout.
                        split = [split[0]] + split[13:16] + split[28:]
                    username, password = split[0].replace('\'', '').split(':')
                    if 'id: command not found' not in line:
                        # Field layout: [creds, uid, gid, groups, os, hostname, kernel, build, ...]
                        uid = split[1].replace('\'', '').split('=')[1].split('(')[0]
                        os_type = split[4]
                        hostname = split[5]
                        os_d = split[7].split('-')[1]
                        arch = split[15]

                        obs.add_user_info(hostid=str(self.target), username=username, uid=int(uid),
                                          password=password)
                        obs.add_process(hostid=str(self.target), local_port=port, local_address=ip_address,
                                        app_protocol='ssh', status='open', process_type="ssh")
                        obs.add_system_info(hostid=str(self.target), hostname=hostname, os_type=os_type,
                                            os_distribution=os_d, architecture=arch)
                    else:
                        obs.add_user_info(hostid=str(self.target), username=username, password=password)
                        obs.add_process(hostid=str(self.target), local_port=port, local_address=ip_address,
                                        app_protocol='ssh', status='open', process_type="ssh")
                if '[*]' in line and "Command shell session" in line:
                    obs.set_success(True)
                    split = line.split(' ')
                    session_id = int(split[4])
                    local_address, local_port = split[8][:-1].split(':')
                    if '-' in split[6]:
                        temp = split[6].replace('(', '').split(':')[0]
                        _origin, remote_address = temp.split('-')
                        remote_port = None
                    else:
                        remote_address, remote_port = split[6].replace('(', '').split(':')

                    obs.add_session_info(hostid=str(self.target), username=username, session_id=session_id,
                                         agent=self.agent, session_type='msf shell')
                    obs.add_process(hostid=str(self.target), local_port=local_port, remote_port=remote_port,
                                    local_address=local_address, remote_address=remote_address)
                    obs.add_process(hostid=str(remote_address), local_address=remote_address,
                                    local_port=remote_port, remote_port=local_port,
                                    remote_address=local_address)
        except Exception:
            session_handler._log_debug(f'Error occured in parsing of output: {output}')
            raise

        return obs

    def __str__(self):
        return super().__str__() + f", Target: {self.target}:{self.port}"