from pathlib import Path
from typing import Any, Callable, Dict, FrozenSet, Generator, List, Optional, Set, Tuple

from checker21.core import Project
from checker21.management import AnonymousProjectCommand
from checker21.norminette.fix_machine import NorminetteFixMachine
from checker21.utils.norminette import NorminetteCheckStatus, Norminette, NorminetteError, NorminetteFileCheckResult


class Command(AnonymousProjectCommand):
	"""Management command that runs norminette-related tasks.

	The positional ``command`` argument selects a handler ``handle_<name>``.
	Names in ``basic_subcommands`` run without a project; names in
	``subcommands`` run inside the resolved project directory with the cached
	norminette state loaded from ``<project temp folder>/<state_filename>``.
	"""
	help = 'Runs norminette related tasks'
	# Name of the norminette state cache file inside the project's temp folder.
	state_filename: str = "norminette.json"

	# Assigned in `handle` before any subcommand handler is invoked.
	_norminette: Norminette

	# Subcommands dispatched as `handle_<name>(**options)` without a project.
	basic_subcommands: Set[str] = {
		"version",
	}
	basic_subcommands_alt_names: Dict[str, str] = {
		"v": "version",
	}
	# Subcommands dispatched as `handle_<name>(project, **options)`.
	subcommands: Set[str] = {
		"check", "errors", "all", "clear", "fix", "stats",
	}
	subcommands_alt_names: Dict[str, str] = {
		"re": "all",
	}

	# Error codes repaired per error by `NorminetteFixMachine.fix_norm_error` (regexp pass).
	_REGEXP_FIXABLE_CODES: FrozenSet[str] = frozenset({
		"INVALID_HEADER",
		"NO_ARGS_VOID",
		"SPACE_BEFORE_FUNC",
		"CONSECUTIVE_SPC",
		"BRACE_SHOULD_EOL",
		"BRACE_NEWLINE",
		"SPACE_REPLACE_TAB",
		"SPACE_AFTER_KW",
		"SPC_AFTER_PAR",
		"SPC_AFTER_OPERATOR",
		"NO_SPC_BFR_OPR",
		"TAB_INSTEAD_SPC",
		"TAB_REPLACE_SPACE",
		"NO_SPC_AFR_PAR",
		"SPC_BFR_OPERATOR",
		"NL_AFTER_VAR_DECL",
		"SPC_AFTER_POINTER",
		"SPC_BEFORE_NL",
	})
	# Error codes handled in the regexp pass that also require a
	# `run_with_norminette` pass on the file before the per-error fixes.
	_REGEXP_WITH_NORMINETTE_CODES: FrozenSet[str] = frozenset({
		"MISALIGNED_FUNC_DECL",
		"MISALIGNED_VAR_DECL",
	})
	# Error codes that trigger the norminette-driven pass (`run_with_norminette`).
	_NORMINETTE_FIXABLE_CODES: FrozenSet[str] = frozenset({
		"RETURN_PARENTHESIS",
	})
	# Error codes that trigger whole-file AST refactoring (`run_ast_refactoring`).
	_AST_FIXABLE_CODES: FrozenSet[str] = frozenset({
		"PREPROC_START_LINE",
		"PREPROC_BAD_INDENT",
		"EMPTY_LINE_EOF",
		"EMPTY_LINE_FUNCTION",
		"CONSECUTIVE_NEWLINES",
		"EMPTY_LINE_FILE_START",
		"SPACE_AFTER_KW",
		"SPC_BFR_POINTER",
		"SPC_AFTER_POINTER",
		"NO_SPC_AFR_OPR",
		"SPC_AFTER_PAR",
		"NO_SPC_AFR_PAR",
		"NO_SPC_BFR_PAR",
		"SPC_BFR_PAR",
		"NO_SPC_BFR_OPR",
		"SPC_BFR_OPERATOR",
		"SPC_AFTER_OPERATOR",
		"TOO_MANY_INSTR",
		"EOL_OPERATOR",
		"TOO_MANY_TAB",
		"TOO_FEW_TAB",
		"MIXED_SPACE_TAB",
		"COMMA_START_LINE",
		"SPC_BEFORE_NL",
		"NEWLINE_PRECEDES_FUNC",
		"BRACE_NEWLINE",
	})

	def add_arguments(self, parser):
		"""Register the subcommand positional plus `--user` / `--email` header options."""
		super().add_arguments(parser)
		parser.add_argument(
			'subcommand',
			help='What norminette should do. By default check files by norm.',
			metavar='command',
			nargs='?',
			default='',
		)
		parser.add_argument(
			'--user',
			help='A username to insert into autogenerated headers',
			metavar='user',
			nargs='?',
			default='',
		)
		parser.add_argument(
			'--email',
			help='An email to insert into autogenerated headers',
			metavar='email',
			nargs='?',
			default='',
		)

	def get_basic_subcommand(self, subcommand_name: str) -> Optional[Callable[..., None]]:
		"""Return the handler of a project-independent subcommand, or None.

		Matching is case-insensitive; aliases in `basic_subcommands_alt_names`
		are resolved first.
		"""
		subcommand_name = subcommand_name.lower()
		subcommand_name = self.basic_subcommands_alt_names.get(subcommand_name, subcommand_name)
		if subcommand_name in self.basic_subcommands:
			return getattr(self, f"handle_{subcommand_name}")
		return None

	def get_subcommand(self, subcommand_name: str) -> Optional[Callable[..., None]]:
		"""Return the handler of a project subcommand, or None if unknown.

		An empty name selects `handle_check`. Matching is case-insensitive;
		aliases in `subcommands_alt_names` are resolved first.
		"""
		if not subcommand_name:
			# by default return check handler
			return self.handle_check
		subcommand_name = subcommand_name.lower()
		subcommand_name = self.subcommands_alt_names.get(subcommand_name, subcommand_name)
		if subcommand_name in self.subcommands:
			return getattr(self, f"handle_{subcommand_name}")
		return None

	def get_user(self, options) -> str:
		"""Return the `--user` option, or an empty string when not given."""
		# TODO parse username from env
		return options.get("user") or ''

	def get_email(self, user: str, options) -> str:
		"""Return the `--email` option, falling back to a school address built from `user`."""
		email = options.get("email") or ''
		if not email and user:
			email = f"{user}@student.21-school.ru"
		return email

	def handle(self, *args, **options) -> None:
		"""Dispatch to the selected subcommand handler.

		Basic subcommands get a fresh `Norminette()`; project subcommands run
		inside `project.path` with the norminette state loaded from the cache.
		"""
		subcommand_name: str = options.pop("subcommand", "")
		basic_subcommand = self.get_basic_subcommand(subcommand_name)
		if basic_subcommand:
			self._norminette = Norminette()
			basic_subcommand(**options)
			return

		subcommand = self.get_subcommand(subcommand_name)
		if not subcommand:
			self.stderr.write(f'Unknown norminette command: "{subcommand_name}"!')
			return

		project_path = self._resolve_project_path(options)
		if not project_path:
			return
		temp_folder = self._resolve_project_temp_path(project_path)
		project = Project(project_path, temp_folder)
		with project.path:
			self._norminette = self.load_norminette(project)
			subcommand(project, **options)

	@property
	def norminette(self):
		"""The `Norminette` instance prepared by `handle`."""
		return self._norminette

	def load_norminette(self, project: Project):
		"""Load the norminette state cached in the project's temp folder."""
		return Norminette.load(project.temp_folder / self.state_filename)

	def handle_version(self, **options):
		"""Print the detected norminette version, or an error if it is not found."""
		if self.norminette.version is None:
			self.stdout.write(self.style.ERROR("Norminette is not found!"))
			return
		self.stdout.write(self.style.INFO(f"Using norminette {self.norminette.version}"))

	def handle_check(
			self,
			project: Project,
			*,
			only_errors: bool = False,
			only_new: bool = False,
			silent: bool = False,
			**options
	) -> None:
		"""Run norminette on the project, save the state and print the results.

		:param only_errors: print only files whose status is not OK
		:param only_new: print only the result of this run instead of the whole cached state
		:param silent: update and save the state without printing per-file results
		"""
		self.handle_version(**options)
		if not self.norminette.version:
			return

		result = self.norminette.check_project(project)
		self.norminette.save()
		if not only_new:
			result = self.norminette.state.result

		if not result:
			self.stdout.write(self.style.INFO("Nothing has changed!"))
			return

		if silent:
			return

		if only_errors:
			result = {key: info for key, info in result.items() if info["status"] != NorminetteCheckStatus.OK}
		self.print_result(result)

	def handle_errors(self, project: Project, **options) -> None:
		"""Like `check`, but print only files that are not OK."""
		self.handle_check(project, only_errors=True, **options)

	def print_result(self, result) -> None:
		"""Print per-file results coloured by status, plus errors and warnings."""
		if not result:
			self.stdout.write(self.style.SUCCESS("OK"))
			return

		for info in result.values():
			status = info["status"]

			if status == NorminetteCheckStatus.OK:
				self.stdout.write(self.style.SUCCESS(info["line"]))

			elif status == NorminetteCheckStatus.NOT_VALID:
				self.stdout.write(self.style.WARNING(info["line"]))

			elif status == NorminetteCheckStatus.ERROR:
				self.stdout.write(self.style.ERROR(info["line"]))
				for error in info["errors"]:
					self.stdout.write(error)

			if "warnings" in info:
				for warning in info["warnings"]:
					self.stdout.write(self.style.WARNING(warning))

	def handle_clear(self, project: Project, **options) -> None:
		"""Delete the cached norminette state, on disk and in memory."""
		try:
			self.norminette.state.path.unlink()
		except FileNotFoundError:
			pass
		# Reload so the in-memory state is dropped too. Otherwise a following
		# check (e.g. in `handle_all`) would work from the stale state and
		# `save()` it straight back to disk, undoing the clear.
		self._norminette = self.load_norminette(project)
		self.stdout.write(self.style.INFO("The norminette cache has been cleared!"))

	def handle_all(self, project: Project, **options) -> None:
		"""Clear the cache and re-check the whole project."""
		self.handle_clear(project)
		self.handle_check(project)

	def handle_fix(self, project: Project, **options) -> None:
		"""Try to automatically fix norm errors in the project's files.

		Runs three passes in order: AST refactoring, a norminette-driven pass,
		then regexp fixes. The first two re-check the project silently after
		modifying files, so each pass works from fresh errors. Prints the error
		count before and after.
		"""
		user = self.get_user(options)
		email = self.get_email(user, options)

		fix_machine = NorminetteFixMachine(self.stdout, self.stderr, self.style)
		fix_machine.set_user_email(user, email)

		if self.norminette.state.result:
			self.stdout.write(self.style.INFO("Trying to fix cached errors..."))
		else:
			self.stdout.write(self.style.INFO("Collecting errors by norminette check..."))
			self.handle_check(project, silent=True)

		total_errors = self._count_errors(self.norminette.state.result)

		ast_fix_count = self._handle_fix_by_ast(project, fix_machine)
		norminette_fix_count, norminette_try_count = self._handle_fix_by_norminette(project, fix_machine)
		regexp_fix_count, regexp_try_count = self._handle_fix_by_regexp(fix_machine)

		# NOTE: AST fixes are not counted as attempts; only the norminette and
		# regexp passes decide whether a final re-check is needed.
		total_fix_count = ast_fix_count + norminette_fix_count + regexp_fix_count
		try_fix_count = norminette_try_count + regexp_try_count

		if try_fix_count > 0:
			self.handle_check(project, only_new=True)
		new_total_errors = self._count_errors(self.norminette.state.result)

		self.stdout.write(self.style.INFO(f"Norm errors: {total_errors} -> {new_total_errors}"))
		if try_fix_count == 0:
			self.stdout.write(self.style.INFO("There is no any errors found that could be fixed by this plugin"))
			return

		# self.stdout.write(self.style.INFO(f"Fixed norm errors: {total_fix_count} / {try_fix_count}"))

	def _count_errors(self, result: Dict[str, NorminetteFileCheckResult]) -> int:
		"""Return the total number of error lines across all file results."""
		return sum(len(info.get('errors') or ()) for info in result.values())

	def _iter_parsed_errors(self, errors: List[str]) -> Generator[NorminetteError, None, None]:
		"""Yield parsed `NorminetteError`s, skipping lines that fail to parse."""
		for error in errors:
			_error = NorminetteError.parse(error)
			if _error:
				yield _error

	def _handle_fix_by_regexp(self, fix_machine: NorminetteFixMachine) -> Tuple[int, int]:
		"""Apply per-error regexp fixes to every file with fixable errors.

		Files with misaligned declarations additionally get a
		`run_with_norminette` pass before the per-error fixes.

		:return: (total_fix_count, try_fix_count)
		"""
		try_fix_count = 0
		total_fix_count = 0

		for file, info in self.norminette.state.result.items():
			if info["status"] != NorminetteCheckStatus.ERROR:
				continue

			errors_to_fix: List[NorminetteError] = []
			use_norminette_machines = False
			for error in self._iter_parsed_errors(info["errors"]):
				if error.code in self._REGEXP_FIXABLE_CODES:
					errors_to_fix.append(error)
				elif error.code in self._REGEXP_WITH_NORMINETTE_CODES:
					errors_to_fix.append(error)
					use_norminette_machines = True

			if not errors_to_fix:
				continue

			path = Path(file)
			if not path.exists():
				continue

			self.stdout.write(self.style.INFO(f'Trying to fix errors in {file} by regexp'))
			fix_machine.load_file(path)
			if use_norminette_machines:
				fix_machine.run_with_norminette()

			for error in errors_to_fix:
				fix_machine.fix_norm_error(error)

			fix_machine.save()
			try_fix_count += len(errors_to_fix)
			total_fix_count += fix_machine.fix_count

		return total_fix_count, try_fix_count

	def _handle_fix_by_norminette(self, project: Project, fix_machine: NorminetteFixMachine) -> Tuple[int, int]:
		"""Run the norminette-driven fix pass on files with matching errors.

		Re-checks the project silently if any file was processed.

		:return: (total_fix_count, try_fix_count)
		"""
		try_fix_count = 0
		total_fix_count = 0
		fixed_files = 0

		for file, info in self.norminette.state.result.items():
			if info["status"] != NorminetteCheckStatus.ERROR:
				continue

			_try_to_fix = sum(
				1 for error in self._iter_parsed_errors(info["errors"])
				if error.code in self._NORMINETTE_FIXABLE_CODES
			)
			if _try_to_fix == 0:
				continue

			path = Path(file)
			if not path.exists():
				continue

			self.stdout.write(self.style.INFO(f'Trying to fix errors in {file} by norminette'))
			fix_machine.load_file(path)
			fix_machine.run_with_norminette()
			fix_machine.save()

			fixed_files += 1
			try_fix_count += _try_to_fix
			total_fix_count += fix_machine.fix_count

		if fixed_files:
			self.handle_check(project, silent=True)
		# Order matches the docstring and the caller's unpacking in `handle_fix`.
		return total_fix_count, try_fix_count

	def _handle_fix_by_ast(self, project: Project, fix_machine: NorminetteFixMachine) -> int:
		"""Run AST refactoring on every file that has at least one AST-fixable error.

		:return: total_fix_count, measured as the drop in total error count
			after a silent re-check (0 if no file was processed)
		"""
		fixed_files = 0

		result = self.norminette.state.result
		errors_count = self._count_errors(result)
		for file, info in result.items():
			if info["status"] != NorminetteCheckStatus.ERROR:
				continue

			has_ast_error = any(
				error.code in self._AST_FIXABLE_CODES
				for error in self._iter_parsed_errors(info["errors"])
			)
			if not has_ast_error:
				continue

			path = Path(file)
			if not path.exists():
				continue

			self.stdout.write(self.style.INFO(f'Trying to fix errors in {file} by AST'))
			fix_machine.load_file(path)
			fix_machine.run_ast_refactoring()
			fix_machine.save()

			fixed_files += 1

		if not fixed_files:
			return 0
		self.handle_check(project, silent=True)
		return errors_count - self._count_errors(self.norminette.state.result)

	def handle_stats(self, project: Project, **options) -> None:
		"""Print file and error counts from the cached norminette state."""
		result = self.norminette.state.result
		files_count = 0
		files_error = 0
		files_ok = 0
		total_errors = 0
		for info in result.values():
			files_count += 1

			status = info["status"]
			if status == NorminetteCheckStatus.ERROR:
				files_error += 1
				total_errors += len(info["errors"])
			elif status == NorminetteCheckStatus.OK:
				files_ok += 1

		self.stdout.write(f"Files checked by norm: {files_count}")
		self.stdout.write(self.style.SUCCESS(f"Files ok: {files_ok}"))
		style = self.style.ERROR if files_error else self.style.SUCCESS
		self.stdout.write(style(f"Files with error: {files_error}"))
		style = self.style.ERROR if total_errors else self.style.SUCCESS
		self.stdout.write(style(f"Norm errors: {total_errors}"))