#!/usr/bin/env python

import os
import sys
import json
import magic
import hashlib
import logging
import pkgutil
import argparse
from glob import glob
from functools import partial
from multiprocessing import Pool
from malduck.extractor import ExtractManager, ExtractorModules

# Module-level logger; handlers and level are configured later by MWCFG.logging().
log = logging.getLogger(__name__)

# Script metadata, shown in the argparse epilog and the --version output.
__author__  = 'c3rb3ru5'
__version__ = '1.0.0'

def extract_config_worker(modules, file_path):
    """Extract malware configuration and basic metadata from one file.

    Defined at module level so it can be pickled and dispatched to
    ``multiprocessing.Pool`` workers.

    Args:
        modules: Path to the extractor modules directory, passed to
            ``ExtractorModules(modules_path=...)``.
        file_path: Path of the file to analyse.

    Returns:
        A dict with the file name, libmagic type/mime descriptions, the
        MD5/SHA1/SHA256 hex digests and the extracted ``configs``, or
        ``None`` when the file is empty.

    Raises:
        OSError: If the file cannot be stat'ed or read.
    """
    # Empty files carry nothing to extract; callers filter out the None.
    if os.stat(file_path).st_size == 0:
        return None

    # The context manager guarantees the handle is closed even if read() fails.
    with open(file_path, 'rb') as handle:
        data = handle.read()

    # Keep the loaded modules in their own local instead of rebinding the
    # ``modules`` parameter to an object of a different type.
    extractor_modules = ExtractorModules(modules_path=modules)
    ext = ExtractManager(modules=extractor_modules)
    ext.push_file(file_path)

    return {
        'name': file_path,
        'type': magic.from_buffer(data),
        'mime': magic.from_buffer(data, mime=True),
        'md5': hashlib.md5(data).hexdigest(),
        'sha1': hashlib.sha1(data).hexdigest(),
        'sha256': hashlib.sha256(data).hexdigest(),
        'configs': ext.config
    }

class MWCFG:

    """
    A Malware Configuration Extraction Utility for MalDuck

    Command-line front end: parses arguments, configures logging, then
    either lists the available extractor modules or runs the extractor
    over a single file or every file in a directory and prints the results.
    """

    def __init__(self):
        """Create the utility; ``parser`` and ``args`` are set by ``arguments()``."""

    def arguments(self):
        """Build the argument parser and parse ``sys.argv`` into ``self.args``."""
        self.parser = argparse.ArgumentParser(
            prog=f'mwcfg v{__version__}',
            description='A Modular Malware Configuration Extraction Utility for MalDuck',
            epilog=f'Author: {__author__}'
        )
        self.parser.add_argument(
            '--version',
            action='version',
            version=f'v{__version__}'
        )
        # Not marked required: --list-modules works without an input path,
        # so main() validates its presence instead.
        self.parser.add_argument(
            '-i',
            '--input',
            type=str,
            default=None,
            help='Input File or Directory',
            required=False
        )
        self.parser.add_argument(
            '-m',
            '--modules',
            type=str,
            default=None,
            required=True,
            help='Modules'
        )
        self.parser.add_argument(
            '--list-modules',
            action='store_true',
            default=False,
            required=False,
            help='List Modules'
        )
        self.parser.add_argument(
            '-d',
            '--debug',
            action='store_true',
            required=False,
            default=False,
            help='Debug'
        )
        self.parser.add_argument(
            '-p',
            '--pretty',
            action='store_true',
            default=False,
            required=False,
            help='Pretty Print Configs'
        )
        # Passed as the process count of the multiprocessing pool.
        self.parser.add_argument(
            '-t',
            '--threads',
            default=1,
            type=int,
            required=False,
            help='Threads'
        )
        self.parser.add_argument(
            '-r',
            '--recursive',
            action='store_true',
            default=False,
            required=False,
            help='Recursive'
        )
        self.parser.add_argument(
            '-l',
            '--log',
            type=str,
            default=None,
            help='Log to File'
        )
        self.args = self.parser.parse_args()

    def print_results(self, results):
        """Print extraction results to stdout.

        Falsy entries (the ``None`` returned for empty files) are dropped
        first. With ``--pretty`` the list is printed as indented JSON;
        otherwise its ``str()`` form is printed.

        Args:
            results: Iterable of per-file result dicts and/or ``None``.
        """
        results = list(filter(None, results))
        log.debug('%s', results)
        if self.args.pretty:
            print(json.dumps(results, indent=4))
        else:
            # NOTE(review): this emits the Python repr of the list, not JSON.
            # Kept as-is for output compatibility; confirm what consumers expect.
            print(str(results))

    def extract_configs(self):
        """Run extraction on ``--input`` and print the collected results.

        A single file is processed in-process. For a directory, every
        regular file matched by ``<input>/**`` is processed by a pool of
        ``--threads`` worker processes; subdirectories are only descended
        into when ``--recursive`` is set.
        """
        target = self.args.input
        results = []
        if os.path.isfile(target):
            results.append(extract_config_worker(self.args.modules, target))
        elif os.path.isdir(target):
            candidates = glob(f'{target}/**', recursive=self.args.recursive)
            files = [path for path in candidates if os.path.isfile(path)]
            worker = partial(extract_config_worker, self.args.modules)
            # map() blocks until every result is back, so it is safe for the
            # context manager to tear the pool down on exit; this avoids
            # leaking worker processes.
            with Pool(processes=self.args.threads) as pool:
                results = pool.map(worker, files)
        else:
            log.error('input path is not a file or directory: %s', target)
        self.print_results(results)

    def logging(self):
        """Configure root logging: ERROR by default, DEBUG with ``--debug``.

        Records go to the file given by ``--log`` when set, otherwise to the
        default stream handler.
        """
        level = logging.DEBUG if self.args.debug else logging.ERROR
        logging.basicConfig(
            filename=self.args.log,
            level=level,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%dT%H:%M:%S%z'
        )

    def print_modules(self):
        """Print the name of each module found in the ``--modules`` directory."""
        for _, name, _ in pkgutil.iter_modules([self.args.modules]):
            print(name)

    def main(self):
        """Entry point: parse arguments, set up logging, dispatch the action."""
        self.arguments()
        self.logging()
        if self.args.list_modules:
            self.print_modules()
            sys.exit(0)
        if self.args.input is None:
            log.error('input file required')
            sys.exit(1)
        self.extract_configs()

# Run the CLI only when executed as a script. Equality is required here:
# a substring test (`in`) would also fire when this file is imported under
# a name such as "main".
if __name__ == '__main__':
    try:
        mwcfg = MWCFG()
        mwcfg.main()
    except KeyboardInterrupt:
        log.warning('interrupted, exiting...')
        sys.exit(0)
    except Exception as error:
        # Top-level boundary: report the failure and exit non-zero.
        log.error(error)
        sys.exit(1)
