# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import os
import time
import boto3
import textractcaller as tc
import textractmanifest as tm

from datetime import datetime
from botocore.config import Config
from typing import List

# Module-level logger; lambda_handler sets its level from LOG_LEVEL on each invocation.
logger = logging.getLogger(__name__)
# Handler version string, logged at the start of every invocation.
version = "0.0.1"
# boto3 clients are created once at import time and shared by all code in this module.
s3 = boto3.client('s3')
step_functions_client = boto3.client(service_name='stepfunctions')
sqs = boto3.client('sqs')

# Textract client configured with max_attempts=0 (no botocore-level retries on
# the initial request). NOTE(review): presumably retries are left to SQS
# redelivery, since lambda_handler raises on throttling errors — confirm.
config = Config(retries={'max_attempts': 0, 'mode': 'standard'})
textract = boto3.client("textract", config=config)


def convert_manifest_queries_config_to_caller(
        queries_config: List[tm.Query]) -> tc.QueriesConfig:
    """Build a Textract caller ``QueriesConfig`` from manifest queries.

    Args:
        queries_config: Manifest queries; each entry's ``text``, ``alias``
            and ``pages`` attributes are copied. May be empty or ``None``.

    Returns:
        A ``tc.QueriesConfig`` holding one ``tc.Query`` per manifest query,
        or one with an empty query list when nothing was passed in.
    """
    caller_queries = []
    # A falsy input (None or empty list) simply yields no queries.
    for manifest_query in queries_config or []:
        caller_queries.append(
            tc.Query(text=manifest_query.text,
                     alias=manifest_query.alias,
                     pages=manifest_query.pages))
    return tc.QueriesConfig(queries=caller_queries)


def convert_manifest_features_to_caller(
        features: List[str]) -> List[tc.Textract_Features]:
    """Look up each manifest feature name in ``tc.Textract_Features``.

    Args:
        features: Feature names used as keys into ``tc.Textract_Features``.
            May be empty or ``None``.

    Returns:
        The matching ``tc.Textract_Features`` values in input order, or an
        empty list when no features were given.
    """
    caller_features = []
    # A falsy input (None or empty list) simply yields no features.
    for feature_name in features or []:
        caller_features.append(tc.Textract_Features[feature_name])
    return caller_features


def _discard_message(reason, message, sqs_queue_url, receipt_handle):
    """Log that the task callback was rejected and delete the SQS message.

    Used when Step Functions refuses the callback (for example an invalid or
    timed-out token), so the message is removed instead of being retried.
    """
    logger.error(f"{reason} for message: {message} ")
    sqs.delete_message(QueueUrl=sqs_queue_url, ReceiptHandle=receipt_handle)


def _fail_task(token, error, cause, sqs_queue_url, receipt_handle):
    """Log a non-retryable failure, delete the SQS message and fail the task.

    ``error`` and ``cause`` must both be plain strings; they are passed
    unchanged to ``send_task_failure``.
    """
    logger.error(cause)
    sqs.delete_message(QueueUrl=sqs_queue_url, ReceiptHandle=receipt_handle)
    step_functions_client.send_task_failure(taskToken=token,
                                            error=error,
                                            cause=cause)


def lambda_handler(event, _):
    """Process SQS records by calling a synchronous Textract API.

    For every record the manifest in ``Payload`` is loaded, the Textract API
    selected by the ``TEXTRACT_API`` environment variable (``ANALYZEID``,
    ``EXPENSE`` or anything else for the generic call) is invoked, the JSON
    response is written to S3 and the Step Functions task identified by the
    message ``Token`` is completed with the S3 path of the output.

    Non-retryable Textract errors delete the SQS message and fail the task.
    Retryable Textract errors (throughput, throttling, limits, internal server
    error) raise so that the message is not acknowledged.

    Args:
        event: Lambda event with a ``Records`` list of SQS records.
        _: Lambda context (unused).

    Raises:
        Exception: if ``SQS_QUEUE_URL`` is not set, or when a record failed
            with a retryable Textract error.
        ValueError: if ``S3_OUTPUT_BUCKET``/``S3_OUTPUT_PREFIX`` are missing,
            a record is not from ``aws:sqs``, or a message has no ``Payload``.
    """
    log_level = os.environ.get('LOG_LEVEL', 'INFO')
    logger.setLevel(log_level)
    logger.info(f"version: {version}")
    logger.info(json.dumps(event))

    s3_output_bucket = os.environ.get('S3_OUTPUT_BUCKET')
    s3_output_prefix = os.environ.get('S3_OUTPUT_PREFIX')
    textract_api = os.environ.get('TEXTRACT_API', 'GENERIC')

    sqs_queue_url = os.environ.get('SQS_QUEUE_URL', None)
    if not sqs_queue_url:
        raise Exception("no SQS_QUEUE_URL set")

    if not s3_output_bucket or not s3_output_prefix:
        raise ValueError(
            f"no s3_output_bucket: {s3_output_bucket} or s3_output_prefix: {s3_output_prefix} defined."
        )
    logger.debug(f"LOG_LEVEL: {log_level} \n \
                S3_OUTPUT_BUCKET: {s3_output_bucket} \n \
                S3_OUTPUT_PREFIX: {s3_output_prefix} \n \
                TEXTRACT_API: {textract_api} \n \
                SQS_QUEUE_URL: {sqs_queue_url}")

    processing_status: bool = True
    files_with_failures: List[str] = list()
    for record in event['Records']:
        if record.get("eventSource") != "aws:sqs":
            raise ValueError("Unsupported eventSource in record")

        message = json.loads(record["body"])
        token = message['Token']
        execution_id = message['ExecutionId']

        if "Payload" not in message:
            raise ValueError("Need Payload with manifest to process message.")

        receipt_handle = record["receiptHandle"]

        manifest: tm.IDPManifest = tm.IDPManifestSchema().load(
            message["Payload"]['manifest'])  #type: ignore

        number_of_pages: int = 0
        if 'numberOfPages' in message["Payload"]:
            number_of_pages = int(message['Payload']['numberOfPages'])

        if number_of_pages > 1:
            logger.error("more than 1 page")
            sqs.delete_message(QueueUrl=sqs_queue_url,
                               ReceiptHandle=receipt_handle)
            step_functions_client.send_task_failure(
                taskToken=token,
                error='TooManyPagesForSync',
                cause=
                f'Document with > 1 page was send to a Sync Textract API endpoint. (number pages: {number_of_pages})'
            )
            # The task has already been failed and the message deleted, so
            # this record must not be sent to Textract as well.
            continue

        # Bound before the try block so every exception handler below can
        # report it, including on the AnalyzeID path.
        s3_path = None
        try:
            s3_path = manifest.s3_path
            start_time = round(time.time() * 1000)
            logger.info(
                f"textract_sync_{textract_api}_number_of_pages_send_to_process: {number_of_pages}"
            )
            if textract_api == 'ANALYZEID':
                # NOTE(review): unlike the other branches, the module-level
                # `textract` client is not passed to this call - confirm
                # whether that is intended.
                logger.debug("calling AnalyzeID")
                if manifest.document_pages:
                    # Used for error reporting only when no s3_path is set.
                    s3_path = s3_path or manifest.document_pages[0]
                    textract_response = tc.call_textract_analyzeid(
                        document_pages=manifest.document_pages)  #type: ignore
                    s3_filename, _ = os.path.splitext(
                        os.path.basename(manifest.document_pages[0]))
                elif manifest.s3_path:
                    textract_response = tc.call_textract_analyzeid(
                        document_pages=[manifest.s3_path])  #type: ignore
                    s3_filename, _ = os.path.splitext(
                        os.path.basename(manifest.s3_path))
                else:
                    raise ValueError(
                        "no document_pages, no s3_path in manifest for call to AnalyzeID"
                    )
            elif textract_api == 'EXPENSE':
                logger.info(f"s3_path: {s3_path} \n \
                            token: {token} \n \
                            execution_id: {execution_id}")
                s3_filename, _ = os.path.splitext(os.path.basename(s3_path))
                logger.debug("calling Expense")
                textract_response = tc.call_textract_expense(
                    input_document=s3_path,
                    boto3_textract_client=textract,
                )
            else:
                logger.info(f"s3_path: {s3_path} \n \
                            token: {token} \n \
                            execution_id: {execution_id}")
                s3_filename, _ = os.path.splitext(os.path.basename(s3_path))

                features = convert_manifest_features_to_caller(
                    manifest.textract_features)
                queries_config = convert_manifest_queries_config_to_caller(
                    manifest.queries_config)
                logger.debug(f"before call_textract\n \
                    input_document: {s3_path} \n \
                    features: {features}\n \
                    queries_config: {queries_config}")
                textract_response = tc.call_textract(
                    input_document=s3_path,
                    boto3_textract_client=textract,
                    features=features,
                    queries_config=queries_config)

            call_duration = round(time.time() * 1000) - start_time
            # The timestamp in the key keeps repeated runs for the same file
            # from overwriting each other.
            output_bucket_key = (
                f"{s3_output_prefix}/{s3_filename}"
                f"{datetime.utcnow().isoformat()}/{s3_filename}.json")
            s3.put_object(
                Body=json.dumps(textract_response, indent=4).encode('UTF-8'),
                Bucket=s3_output_bucket,
                Key=output_bucket_key)
            logger.info(
                f"textract_sync_{textract_api}_call_duration_in_ms: {call_duration}"
            )
            logger.info(
                f"textract_sync_{textract_api}_number_of_pages_processed: {number_of_pages}"
            )
            try:
                step_functions_client.send_task_success(
                    taskToken=token,
                    output=json.dumps({
                        "TextractOutputJsonPath":
                        f"s3://{s3_output_bucket}/{output_bucket_key}"
                    }))
            except step_functions_client.exceptions.InvalidToken:
                _discard_message("InvalidToken", message, sqs_queue_url,
                                 receipt_handle)
            except step_functions_client.exceptions.TaskDoesNotExist:
                _discard_message("TaskDoesNotExist", message, sqs_queue_url,
                                 receipt_handle)
            except step_functions_client.exceptions.TaskTimedOut:
                _discard_message("TaskTimedOut", message, sqs_queue_url,
                                 receipt_handle)
            except step_functions_client.exceptions.InvalidOutput:
                # Not sure if to delete here or not, could be a bug in the code that a hot fix could solve, but don't want to retry infinite, which can cause run-away-cost. For now, delete
                _discard_message("InvalidOutput", message, sqs_queue_url,
                                 receipt_handle)

        # These exceptions cannot succeed on retry: delete the message and
        # fail the Step Functions task.
        except textract.exceptions.InvalidS3ObjectException:
            _fail_task(token, "InvalidS3ObjectException",
                       f"InvalidS3ObjectException for object: {s3_path}",
                       sqs_queue_url, receipt_handle)
        except textract.exceptions.InvalidParameterException:
            _fail_task(
                token, "InvalidParameterException",
                f"textract.exceptions.InvalidParameterException: for manifest: {manifest}",
                sqs_queue_url, receipt_handle)
        except textract.exceptions.InvalidKMSKeyException:
            _fail_task(
                token, "InvalidKMSKeyException",
                f"textract.exceptions.InvalidKMSKeyException: for manifest: {manifest}",
                sqs_queue_url, receipt_handle)
        except textract.exceptions.UnsupportedDocumentException:
            _fail_task(
                token, "UnsupportedDocumentException",
                f"textract.exceptions.UnsupportedDocumentException: for manifest: {manifest}",
                sqs_queue_url, receipt_handle)
        except textract.exceptions.DocumentTooLargeException:
            _fail_task(
                token, "DocumentTooLargeException",
                f"textract.exceptions.DocumentTooLargeException: for manifest: {manifest}",
                sqs_queue_url, receipt_handle)
        except textract.exceptions.BadDocumentException:
            _fail_task(
                token, "BadDocumentException",
                f"textract.exceptions.BadDocumentException: for manifest: {manifest}",
                sqs_queue_url, receipt_handle)
        except textract.exceptions.AccessDeniedException:
            _fail_task(
                token, "AccessDeniedException",
                f"textract.exceptions.AccessDeniedException: for manifest: {manifest}",
                sqs_queue_url, receipt_handle)
        except textract.exceptions.IdempotentParameterMismatchException:
            _fail_task(
                token, "IdempotentParameterMismatchException",
                f"textract.exceptions.IdempotentParameterMismatchException: for manifest: {manifest}",
                sqs_queue_url, receipt_handle)
        # these Exceptions we can retry, so we put them back on the queue
        except textract.exceptions.ProvisionedThroughputExceededException:
            logger.error(
                "textract.exceptions.ProvisionedThroughputExceededException")
            processing_status = False
            files_with_failures.append(s3_path)
        except textract.exceptions.InternalServerError:
            logger.error("textract.exceptions.InternalServerError")
            processing_status = False
            files_with_failures.append(s3_path)
        except textract.exceptions.ThrottlingException:
            logger.error("textract.exceptions.ThrottlingException")
            processing_status = False
            files_with_failures.append(s3_path)
        except textract.exceptions.LimitExceededException:
            logger.error("textract.exceptions.LimitExceededException")
            processing_status = False
            files_with_failures.append(s3_path)
        except Exception as e:
            # Record the traceback; otherwise the only trace of this failure
            # is the cause string sent to Step Functions.
            logger.exception("not_handled_exception")
            error = "not_handled_exception"
            cause = str(e)
            try:
                step_functions_client.send_task_failure(taskToken=token,
                                                        error=error,
                                                        cause=cause)
            except step_functions_client.exceptions.InvalidToken:
                _discard_message("InvalidToken", message, sqs_queue_url,
                                 receipt_handle)
            except step_functions_client.exceptions.TaskDoesNotExist:
                _discard_message("TaskDoesNotExist", message, sqs_queue_url,
                                 receipt_handle)
            except step_functions_client.exceptions.TaskTimedOut:
                _discard_message("TaskTimedOut", message, sqs_queue_url,
                                 receipt_handle)
        # NOTE(review): this check runs inside the record loop, so the first
        # retryable failure aborts the rest of the batch - confirm whether it
        # should instead run once after all records were attempted.
        if not processing_status:
            raise Exception(
                f"files with failures: {files_with_failures}")
