#!/usr/bin/env python3

import argparse
import errno
import io
import logging
import multiprocessing
import pathlib
import signal
import sys
import progressbar

from typing import (
	List,
	BinaryIO,
	Tuple
)

from DyldExtractor.converter import (
	linkedit_optimizer,
	macho_offset,
	objc_fixer,
	slide_info,
	stub_fixer
)

from DyldExtractor.dyld.dyld_context import DyldContext
from DyldExtractor.extraction_context import ExtractionContext
from DyldExtractor.macho.macho_context import MachOContext

# check dependencies
# An explicit comparison is used instead of `assert` because assertions are
# removed when Python runs with -O, which would silently skip this check.
if sys.version_info < (3, 9, 5):
	print("Python 3.9.5 or greater is required", file=sys.stderr)
	sys.exit(1)

# The `streams` attribute is used to tell progressbar2 apart from the older
# progressbar package, which is imported under the same module name.
# sys.exit is used rather than the `exit` helper, which is only injected by
# the `site` module and is not guaranteed to exist.
if not hasattr(progressbar, "streams"):
	print("progressbar is installed but progressbar2 required.", file=sys.stderr)
	sys.exit(1)


class _DyldExtractorArgs(argparse.Namespace):
	"""
	Namespace used to receive the parsed command line arguments.

	It only declares the expected attributes and their types; the values
	are filled in by the argument parser.
	"""

	# path to the target DYLD cache
	dyld_path: pathlib.Path
	# output path for the extracted frameworks
	output: pathlib.Path
	# number of jobs to run simultaneously
	jobs: int
	# verbosity level, one of 0, 1, 2 or 3
	verbosity: int


def _createArgParser() -> argparse.ArgumentParser:
	"""
	Build the command line parser for the extractor.

	Returns:
		A parser that takes the cache path as its positional argument and
		the optional output, jobs and verbosity switches.
	"""

	# Each entry is (names or flags, keyword options for add_argument).
	argumentSpecs = (
		(
			("dyld_path",),
			{
				"type": pathlib.Path,
				"help": "A path to the target DYLD cache.",
			},
		),
		(
			("-o", "--output"),
			{
				"type": pathlib.Path,
				"help": "Specify the output path for the extracted frameworks. By default it extracts to './binaries/'.",  # noqa
			},
		),
		(
			("-j", "--jobs"),
			{
				"type": int,
				"default": multiprocessing.cpu_count(),
				"help": "Number of jobs to run simultaneously.",
			},
		),
		(
			("-v", "--verbosity"),
			{
				"type": int,
				"choices": [0, 1, 2, 3],
				"default": 1,
				"help": "Increase verbosity, Option 1 is the default. | 0 = None | 1 = Critical Error and Warnings | 2 = 1 + Info | 3 = 2 + debug |",  # noqa
			},
		),
	)

	argParser = argparse.ArgumentParser(
		description="Extract all images from a Dyld Shared Cache."
	)
	for names, options in argumentSpecs:
		argParser.add_argument(*names, **options)

	return argParser


class _DummyProgressBar():
	"""
	Stand-in progress bar whose update call does nothing.
	"""

	@staticmethod
	def update(*_args, **_kwargs):
		"""Accept any arguments, ignore them and return None."""
		return None


def _workerInitializer():
	"""
	Start-up hook for pool workers.

	Sets SIGINT to be ignored in the calling process, so that a
	KeyboardInterrupt is left to the main process, which can then
	stop everything.
	"""
	interruptSignal = signal.SIGINT
	signal.signal(interruptSignal, signal.SIG_IGN)


def _extractImage(
	dyldPath: pathlib.Path,
	outputDir: pathlib.Path,
	imageIndex: int,
	imagePath: str,
	loggingLevel: int
) -> str:
	"""
	Extract a single image from the cache and write it below outputDir.

	Everything logged while processing the image is captured in memory and
	handed back to the caller instead of being printed here.

	Args:
		dyldPath: Path to the dyld shared cache file.
		outputDir: Directory the extracted image is written under.
		imageIndex: Index of the image in the cache's image list.
		imagePath: Path of the image as stored in the cache.
		loggingLevel: Level applied to the per-image logger.

	Returns:
		The captured log text, which is empty when nothing was logged.

	Raises:
		OSError: If the cache file cannot be opened, or for any OSError
			raised during extraction other than EMFILE (which is logged).
	"""
	# Change imagePath to a relative path. Every leading slash is removed
	# because joining with a path that is still absolute would discard
	# outputDir. Unlike indexing the first character, this also copes with
	# an empty path.
	outputPath = outputDir / imagePath.lstrip("/")

	# setup logging
	logger = logging.getLogger(f"Worker: {outputPath}")

	loggingStream = io.StringIO()
	handler = logging.StreamHandler(loggingStream)
	handler.setFormatter(
		logging.Formatter(
			fmt="{asctime}:{msecs:03.0f} [{levelname:^9}] {filename}:{lineno:d} : {message}",  # noqa
			datefmt="%H:%M:%S",
			style="{",
		)
	)
	logger.addHandler(handler)
	logger.setLevel(loggingLevel)

	try:
		# Process the image
		with open(dyldPath, "rb") as f:
			subCacheFiles: List[BinaryIO] = []
			try:
				dyldCtx = DyldContext(f)

				# add sub caches if there are any
				subCacheFiles = dyldCtx.addSubCaches(dyldPath)

				machoOffset, context = dyldCtx.convertAddr(
					dyldCtx.images[imageIndex].address
				)
				machoCtx = MachOContext(context.fileObject, machoOffset, True)

				# Add sub caches if necessary
				if dyldCtx.hasSubCaches():
					mappings = dyldCtx.mappings
					mainFileMap = next(
						(mapping[0] for mapping in mappings if mapping[1] == context)
					)
					machoCtx.addSubfiles(
						mainFileMap,
						((m, ctx.makeCopy(copyMode=True)) for m, ctx in mappings)
					)

				extractionCtx = ExtractionContext(
					dyldCtx,
					machoCtx,
					_DummyProgressBar(),
					logger
				)

				# Run the converters in the same order as before.
				slide_info.processSlideInfo(extractionCtx)
				linkedit_optimizer.optimizeLinkedit(extractionCtx)
				stub_fixer.fixStubs(extractionCtx)
				objc_fixer.fixObjC(extractionCtx)

				writeProcedures = macho_offset.optimizeOffsets(extractionCtx)

				# write the file
				outputPath.parent.mkdir(parents=True, exist_ok=True)
				with open(outputPath, "wb") as outFile:
					for procedure in writeProcedures:
						outFile.seek(procedure.writeOffset)
						outFile.write(
							procedure.fileCtx.getBytes(procedure.readOffset, procedure.size)
						)

			except OSError as e:
				# Only running out of file descriptors is reported through
				# the log; every other OSError goes back to the caller.
				if e.errno != errno.EMFILE:
					raise
				logger.error("Too many files open, you may need to increase your FD limit.")  # noqa

			except Exception as e:
				# Any other failure is recorded in the captured log output.
				logger.exception(e)

			finally:
				for file in subCacheFiles:
					file.close()

		handler.flush()
		return loggingStream.getvalue()

	finally:
		# logging keeps loggers registered by name for the lifetime of the
		# process, so the handler has to be detached explicitly. Otherwise
		# each processed image leaves a handler and its buffer behind, also
		# when an exception propagates out of this function.
		logger.removeHandler(handler)
		handler.close()
		loggingStream.close()


def _main() -> None:
	"""
	Command line entry point: extract every image of a dyld shared cache.

	Parses the arguments, reads the image paths from the cache, submits
	one extraction job per image to a process pool, and prints progress
	together with any log output the jobs returned.
	"""
	argParser = _createArgParser()
	args = argParser.parse_args(namespace=_DyldExtractorArgs())

	# Make the output dir
	if args.output is None:
		outputDir = pathlib.Path("binaries")
	else:
		outputDir = pathlib.Path(args.output)

	outputDir.mkdir(parents=True, exist_ok=True)

	# Translate the verbosity switch into a logging level. For verbosity 0
	# the level is set so high that nothing is logged.
	loggingLevel = {
		0: 100,
		1: logging.WARNING,
		2: logging.INFO,
		3: logging.DEBUG,
	}[args.verbosity]

	# create a list of image paths
	imagePaths: List[str] = []
	with open(args.dyld_path, "rb") as f:
		dyldCtx = DyldContext(f)

		# Check magic
		if dyldCtx.header.magic[0:4] != b"dyld":
			print("Cache's magic does not start with 'dyld', most likely given a file that's not a cache or the file is broken.")  # noqa
			return

		for image in dyldCtx.images:
			# The last byte of the string read is dropped before decoding
			# (presumably the terminating NUL - TODO confirm).
			imagePath = dyldCtx.readString(
				image.pathFileOffset
			)[0:-1].decode("utf-8")
			imagePaths.append(imagePath)

	with multiprocessing.Pool(args.jobs, initializer=_workerInitializer) as pool:
		# Create a job for each image
		jobs: List[Tuple[str, multiprocessing.pool.AsyncResult]] = []
		jobsComplete = 0
		for i, imagePath in enumerate(imagePaths):
			# The index should correspond with its index in the DSC
			extractionArgs = (args.dyld_path, outputDir, i, imagePath, loggingLevel)
			jobs.append((imagePath, pool.apply_async(_extractImage, extractionArgs)))

		# setup a progress bar
		progressBar = progressbar.ProgressBar(
			max_value=len(jobs),
			redirect_stdout=True
		)

		# Record potential logging output for each job
		jobOutputs: List[str] = []

		# wait for all jobs, reporting each one as soon as it is done
		while jobs:
			finishedAny = False

			# Walk backwards so popping an entry does not shift the indices
			# that are still to be visited.
			for i in reversed(range(len(jobs))):
				imagePath, job = jobs[i]
				if not job.ready():
					continue

				jobs.pop(i)
				finishedAny = True

				imageName = imagePath.split("/")[-1]
				print(f"Processed: {imageName}")

				# get() re-raises any exception the job raised
				jobOutput = job.get()
				if jobOutput:
					summary = f"----- {imageName} -----\n{jobOutput}--------------------\n"
					jobOutputs.append(summary)
					print(summary)

				jobsComplete += 1
				progressBar.update(jobsComplete)

			if jobs and not finishedAny:
				# Nothing finished during this pass: block for a moment on
				# one pending job instead of spinning on ready(), which
				# would keep a CPU core busy for no benefit.
				jobs[0][1].wait(0.1)

		# close the pool and cleanup
		pool.close()
		pool.join()
		progressBar.update(jobsComplete, force=True)

		# reprint any job output
		print("\n\n----- Summary -----")
		print("".join(jobOutputs))
		print("-------------------\n")


# Run the extraction only when this file is executed as a script, so that
# merely importing the module never starts one.
# NOTE(review): multiprocessing's spawn start method re-imports the main
# module in every worker, which makes this guard required there - see the
# multiprocessing programming guidelines.
if __name__ == "__main__":
	_main()
