#!/usr/bin/env python3

import progressbar
import argparse
import pathlib
import logging
import os
import sys
from typing import List, BinaryIO

# Two different PyPI packages install a module named "progressbar". Only
# progressbar2 exposes the "streams" attribute used below, so probing for it
# tells the two apart before anything else relies on it.
try:
	progressbar.streams
except AttributeError:
	print("progressbar is installed but progressbar2 required.", file=sys.stderr)
	# sys.exit rather than the bare exit() builtin: exit() is injected by the
	# site module for interactive use and is not guaranteed to exist.
	sys.exit(1)

from DyldExtractor.extraction_context import ExtractionContext
from DyldExtractor.macho.macho_context import MachOContext
from DyldExtractor.dyld.dyld_context import DyldContext

from DyldExtractor.dyld.dyld_structs import (
	dyld_cache_image_info
)

from DyldExtractor.converter import (
	slide_info,
	macho_offset,
	linkedit_optimizer,
	stub_fixer,
	objc_fixer
)


class _DyldExtractorArgs(argparse.Namespace):
	"""Typed view of the options produced by the parser in _getArguments.

	The attributes are only annotated here, argparse assigns the values.
	Options that take a value and declare no default are None when they are
	not given on the command line.
	"""

	# Positional: path to the (main) DYLD cache file.
	dyld_path: pathlib.Path
	# -e/--extract: name of the framework to extract.
	extract: str
	# -o/--output: where the extracted framework is written.
	output: pathlib.Path
	# -l/--list-frameworks: list the frameworks in the cache.
	list_frameworks: bool
	# -f/--filter: term used to filter the listing.
	filter: str
	# -a/--addresses: print addresses alongside the listed paths.
	addresses: bool
	# -b/--basenames: print only the basename of each framework.
	basenames: bool
	# --lookup: address to look up, kept as the raw command line string.
	lookup: str
	# -v/--verbosity: one of 0, 1, 2 or 3.
	verbosity: int


def _getArguments():
	"""Parse the program's command line arguments.

	Returns:
		A _DyldExtractorArgs instance holding the parsed options.
	"""

	parser = argparse.ArgumentParser()
	parser.add_argument(
		"dyld_path",
		type=pathlib.Path,
		help="A path to the target DYLD cache. If it is a split cache, use the path for the main cache (the one without a file type)."  # noqa
	)
	parser.add_argument(
		"-e", "--extract",
		help="The name of the framework to extract. This can be longer for frameworks like UIKit, for example \"UIKit.framework/UIKit\""  # noqa
	)
	parser.add_argument(
		"-o", "--output",
		# Convert to a Path so the value matches the type declared on
		# _DyldExtractorArgs and the default output path built in main().
		type=pathlib.Path,
		help="Specify the output path for the extracted framework. By default it extracts to the binaries folder."  # noqa
	)
	parser.add_argument(
		"-l", "--list-frameworks", action="store_true",
		help="List all frameworks in the cache."
	)
	parser.add_argument(
		"-f", "--filter",
		help="Filter out frameworks when listing them."
	)
	parser.add_argument(
		"-a", "--addresses", action="store_true",
		help="List addresses along with framework paths. Only applies when --list-frameworks is specified."
	)
	parser.add_argument(
		"-b", "--basenames", action="store_true",
		help="Print only the basenames of each framework. Only applies when --list-frameworks is specified."
	)
	parser.add_argument(
		"--lookup",
		help="Find the library that an address lives in. E.g. dyldex --lookup 0x18008e9f8 dyld_shared_cache_arm64e."
	)
	parser.add_argument(
		"-v", "--verbosity", type=int, choices=[0, 1, 2, 3], default=1,
		help="Increase verbosity, Option 1 is the default. | 0 = None | 1 = Critical Error and Warnings | 2 = 1 + Info | 3 = 2 + debug |"  # noqa
	)

	# Pass an instance, not the class itself: argparse stores every parsed
	# value with setattr, which would otherwise mutate the class and leak
	# values from one parse into the next.
	return parser.parse_args(namespace=_DyldExtractorArgs())


def _extractImage(
	dyldFilePath: pathlib.Path,
	dyldCtx: DyldContext,
	image: dyld_cache_image_info,
	outputPath: str
) -> None:
	"""Run the converters on a single cache image and write it to disk.

	The converters are applied in essentially the reverse order of Apple's
	SharedCacheBuilder.

	Args:
		dyldFilePath: Path of the cache file, handed to
			DyldContext.addSubCaches.
		dyldCtx: Context of the opened cache.
		image: Image info entry of the image to extract.
		outputPath: Path of the file the extracted image is written to.
	"""

	log = logging.getLogger()

	progress = progressbar.ProgressBar(
		prefix="{variables.unit} >> {variables.status} :: [",
		variables={"unit": "--", "status": "--"},
		widgets=[progressbar.widgets.AnimatedMarker(), "]"],
		redirect_stdout=True
	)

	openedSubCaches: List[BinaryIO] = []
	try:
		# Register the sub caches, if any. The returned file objects are
		# closed in the finally clause below.
		openedSubCaches = dyldCtx.addSubCaches(dyldFilePath)

		# Get a writable copy of the MachOContext
		imageOffset, owningCtx = dyldCtx.convertAddr(image.address)
		machoCtx = MachOContext(owningCtx.fileObject, imageOffset, True)

		# Hand the remaining cache files to the MachO context as well
		if dyldCtx.hasSubCaches():
			cacheMappings = dyldCtx.mappings
			primaryMapping = next(
				entry[0] for entry in cacheMappings if entry[1] == owningCtx
			)
			machoCtx.addSubfiles(
				primaryMapping,
				((m, ctx.makeCopy(copyMode=True)) for m, ctx in cacheMappings)
			)

		extractionCtx = ExtractionContext(dyldCtx, machoCtx, progress, log)

		# The order of these calls matters, do not rearrange them.
		slide_info.processSlideInfo(extractionCtx)
		linkedit_optimizer.optimizeLinkedit(extractionCtx)
		stub_fixer.fixStubs(extractionCtx)
		objc_fixer.fixObjC(extractionCtx)

		writeProcedures = macho_offset.optimizeOffsets(extractionCtx)

		# Write the MachO file, one write procedure at a time
		with open(outputPath, "wb") as outFile:
			progress.update(unit="Extractor", status="Writing file")

			for step in writeProcedures:
				outFile.seek(step.writeOffset)
				chunk = step.fileCtx.getBytes(step.readOffset, step.size)
				outFile.write(chunk)

		progress.update(unit="Extractor", status="Done")

	finally:
		for handle in openedSubCaches:
			handle.close()


def _filterImages(imagePaths: List[str], filterTerm: str):
	"""Select the paths that contain filterTerm, ignoring case.

	Returns a new list holding the matching paths in their original order.
	"""

	needle = filterTerm.lower()
	return [candidate for candidate in imagePaths if needle in candidate.lower()]


def _findImageForAddress(imageMap, address):
	"""Find the image whose start address is the closest one at or below address.

	Only start addresses are compared, so an address past the end of the last
	image is still attributed to that image.

	Args:
		imageMap: Mapping of image path to an image info object that has an
			"address" attribute.
		address: The address to look up.

	Returns:
		The path of the matching image, or None when address lies before
		every image or imageMap is empty.
	"""

	match = None
	for path in sorted(imageMap, key=lambda p: imageMap[p].address):
		if address < imageMap[path].address:
			break
		match = path
	return match


def main():
	"""Entry point of the command line tool.

	Parses the arguments, configures logging, opens the cache and then runs
	the first requested action in this order: --lookup, --list-frameworks,
	--extract. Nothing is done when none of them is given.

	Exits with status 1 when --lookup is given an unparsable address, the
	cache has no images, or the address lies before the first image.
	"""

	args = _getArguments()

	# Configure Logging
	# 100 is above every standard level, so nothing gets logged at all.
	logLevels = {
		0: 100,
		1: logging.WARNING,
		2: logging.INFO,
		3: logging.DEBUG
	}
	level = logLevels.get(args.verbosity, logging.WARNING)

	# needed for logging compatibility
	progressbar.streams.wrap_stderr()  # type:ignore

	logging.basicConfig(
		format="{asctime}:{msecs:3.0f} [{levelname:^9}] {filename}:{lineno:d} : {message}",  # noqa
		datefmt="%H:%M:%S",
		style="{",
		level=level
	)

	with open(args.dyld_path, "rb") as f:
		dyldCtx = DyldContext(f)

		# enumerate images, create a map of paths and images
		imageMap = {}
		for imageData in dyldCtx.images:
			path = dyldCtx.readString(imageData.pathFileOffset)
			path = path[0:-1]  # remove null terminator
			imageMap[path.decode("utf-8")] = imageData

		# Find the image that an address lives in
		if args.lookup:
			try:
				# base 0 accepts decimal as well as 0x/0o/0b prefixed values
				lookupAddr = int(args.lookup, 0)
			except ValueError:
				print(f"Error: invalid address \"{args.lookup}\"", file=sys.stderr)
				sys.exit(1)

			if not imageMap:
				print("Error: cache contains no images!", file=sys.stderr)
				sys.exit(1)

			foundPath = _findImageForAddress(imageMap, lookupAddr)
			if foundPath is None:
				print("Error: address before first image!", file=sys.stderr)
				sys.exit(1)

			print(os.path.basename(foundPath) if args.basenames else foundPath)
			return

		# list images option
		if args.list_frameworks:
			imagePaths = imageMap.keys()

			# filter if needed, _filterImages does the case folding itself
			if args.filter:
				imagePaths = _filterImages(imagePaths, args.filter.strip())

			# sort the paths so they're displayed in VM address order
			sortedPaths = sorted(imagePaths, key=lambda path: imageMap[path].address)

			print("Listing Images\n--------------")
			for fullpath in sortedPaths:
				path = os.path.basename(fullpath) if args.basenames else fullpath
				if args.addresses:
					print(f"{hex(imageMap[fullpath].address)} : {path}")
				else:
					print(path)

			return

		# extract image option
		if args.extract:
			extractionTarget = args.extract.strip()
			targetPaths = _filterImages(imageMap.keys(), extractionTarget)
			if not targetPaths:
				print(f"Unable to find image \"{extractionTarget}\"")
				return

			outputPath = args.output
			if outputPath is None:
				outputPath = pathlib.Path("binaries/" + extractionTarget)
				os.makedirs(outputPath.parent, exist_ok=True)

			# Only the first matching image is extracted
			print(f"Extracting {targetPaths[0]}")
			_extractImage(args.dyld_path, dyldCtx, imageMap[targetPaths[0]], outputPath)
			return


# Run the tool only when this file is executed as a script.
if __name__ == "__main__":
	main()
