#!/usr/bin/env python3
import asyncio
import concurrent.futures
import json
import sys
import singer
import singer.utils as singer_utils
from singer import metadata, metrics

import tap_salesforce.salesforce
from tap_salesforce.sync import (sync_stream, resume_syncing_bulk_query, get_stream_version)
from tap_salesforce.salesforce import Salesforce
from tap_salesforce.salesforce.exceptions import (
    TapSalesforceException, TapSalesforceQuotaExceededException)
from tap_salesforce.salesforce.credentials import (
    OAuthCredentials,
    PasswordCredentials,
    parse_credentials
)

LOGGER = singer.get_logger()

# Config keys that must always be present.
REQUIRED_CONFIG_KEYS = [
    'api_type',
    'select_fields_by_default',
]

# In addition to the required keys, the config must supply one complete set
# of credentials:
#   OAuth    -> client_id, client_secret, refresh_token
#   Password -> username, password, security_token
# The key names are taken from the credential classes' ``_fields``.
OAUTH_CONFIG_KEYS = OAuthCredentials._fields
PASSWORD_CONFIG_KEYS = PasswordCredentials._fields

# Module-wide configuration; main_impl() merges the user's config into it.
CONFIG = dict(
    refresh_token=None,
    client_id=None,
    client_secret=None,
    start_date=None,
)

# Objects that never get a replication key, so they are always synced
# FULL_TABLE.
FORCED_FULL_TABLE = {
    'BackgroundOperationResult',  # Does not support ordering by CreatedDate
}

def get_replication_key(sobject_name, fields):
    """Choose the replication key for a Salesforce object.

    Returns the first of ``SystemModstamp``, ``LastModifiedDate`` or
    ``CreatedDate`` found among the field names in ``fields``. ``LoginTime``
    is used only for the ``LoginHistory`` object. Objects listed in
    ``FORCED_FULL_TABLE``, or having none of these fields, get ``None``.

    :param sobject_name: Salesforce object API name
    :param fields: describe ``fields`` list (each item has a ``name`` key)
    :returns: the replication key name, or ``None``
    """
    if sobject_name in FORCED_FULL_TABLE:
        return None

    available = [field['name'] for field in fields]

    for candidate in ('SystemModstamp', 'LastModifiedDate', 'CreatedDate'):
        if candidate in available:
            return candidate

    if sobject_name == 'LoginHistory' and 'LoginTime' in available:
        return 'LoginTime'

    return None

def stream_is_selected(mdata):
    """Return the ``selected`` flag of the stream-level ``()`` metadata entry (False if absent)."""
    stream_level = mdata.get((), {})
    return stream_level.get('selected', False)

def build_state(raw_state, catalog):
    """Build the working sync state from the state supplied on the command line.

    Only the bookmarks this tap uses are carried over, stream by stream:

    * if a ``JobID`` bookmark is truthy, ``JobID``, ``BatchIDs`` and
      ``JobHighestBookmarkSeen`` are copied so the bulk job can be resumed;
    * INCREMENTAL streams keep ``version`` and their replication-key value,
      each only when it is not ``None``;
    * FULL_TABLE streams with no stored ``version`` get a ``version: None``
      bookmark.

    :param raw_state: state dict read from the state file
    :param catalog: catalog dict with a ``streams`` list
    :returns: a new state dict
    """
    state = {}

    for entry in catalog['streams']:
        stream_id = entry['tap_stream_id']
        stream_mdata = metadata.to_map(entry['metadata']).get((), {})
        method = stream_mdata.get('replication-method')
        version = singer.get_bookmark(raw_state, stream_id, 'version')

        # Keep the bookkeeping for an unfinished bulk job so it can be resumed.
        if singer.get_bookmark(raw_state, stream_id, 'JobID'):
            for job_key in ('JobID', 'BatchIDs', 'JobHighestBookmarkSeen'):
                previous_value = singer.get_bookmark(raw_state, stream_id, job_key)
                state = singer.write_bookmark(state, stream_id, job_key, previous_value)

        if method == 'INCREMENTAL':
            rep_key = stream_mdata.get('replication-key')
            rep_key_value = singer.get_bookmark(raw_state, stream_id, rep_key)
            if version is not None:
                state = singer.write_bookmark(state, stream_id, 'version', version)
            if rep_key_value is not None:
                state = singer.write_bookmark(state, stream_id, rep_key, rep_key_value)
        elif method == 'FULL_TABLE' and version is None:
            # NOTE(review): only a *missing* version is recorded (as None); an
            # existing FULL_TABLE version is not carried over -- confirm intended.
            state = singer.write_bookmark(state, stream_id, 'version', None)

    return state

def create_property_schema(field, mdata):
    """Build the JSON-schema property for one Salesforce field and record its metadata.

    The ``Id`` field is marked ``inclusion: automatic``; every other field is
    marked ``inclusion: available``. The property schema itself (plus any
    further metadata) comes from
    ``tap_salesforce.salesforce.field_to_property_schema``.

    :param field: one entry of a Salesforce describe ``fields`` list
        (must contain ``name``)
    :param mdata: compiled singer metadata map for the stream
    :returns: ``(property_schema, mdata)``
    """
    field_name = field['name']
    inclusion = 'automatic' if field_name == "Id" else 'available'
    mdata = metadata.write(mdata, ('properties', field_name), 'inclusion', inclusion)

    # Call through the explicitly imported ``tap_salesforce.salesforce``
    # module. This module never imports a bare ``salesforce`` name, so
    # referencing one only worked through an implicit import side effect.
    property_schema, mdata = tap_salesforce.salesforce.field_to_property_schema(field, mdata)

    return (property_schema, mdata)


# pylint: disable=too-many-branches,too-many-statements
def do_discover(sf):
    """Describe the Salesforce instance's objects and print a Singer catalog.

    Every object returned by ``sf.describe()`` is described, except
    blacklisted objects and ``*ChangeEvent`` objects. For each one, a JSON
    schema is built from its fields, together with singer metadata: field
    inclusion, unsupported fields, replication key/method and
    ``table-key-properties`` (``['Id']``).

    Objects without an ``Id`` field are skipped, as are ``__Tag`` objects that
    reference a custom-setting object. The resulting ``{'streams': [...]}``
    catalog is written as JSON to stdout.

    Objects are processed in sorted name order so the emitted catalog is
    stable from run to run.

    :param sf: a logged-in ``Salesforce`` client
    """
    global_description = sf.describe()

    objects_to_discover = {o['name'] for o in global_description['sobjects']}
    key_properties = ['Id']

    # The blacklists are fetched once here rather than being rebuilt for
    # every object and twice for every field.
    blacklisted_objects = sf.get_blacklisted_objects()
    blacklisted_fields = sf.get_blacklisted_fields()

    sf_custom_setting_objects = []
    object_to_tag_references = {}

    # For each SF Object describe it, loop its fields and build a schema.
    # Sorting makes the catalog order deterministic (set order is not).
    entries = []
    for sobject_name in sorted(objects_to_discover):

        # Skip blacklisted SF objects depending on the api_type in use
        # ChangeEvent objects are not queryable via Bulk or REST (undocumented)
        if sobject_name in blacklisted_objects or sobject_name.endswith("ChangeEvent"):
            continue

        sobject_description = sf.describe(sobject_name)

        # Cache customSetting and Tag objects to check for blacklisting after
        # all objects have been described
        if sobject_description.get("customSetting"):
            sf_custom_setting_objects.append(sobject_name)
        elif sobject_name.endswith("__Tag"):
            relationship_field = next(
                (f for f in sobject_description["fields"] if f.get("relationshipName") == "Item"),
                None)
            if relationship_field:
                # Map {"Object": "Object__Tag"}
                object_to_tag_references[relationship_field["referenceTo"][0]] = sobject_name

        fields = sobject_description['fields']
        replication_key = get_replication_key(sobject_name, fields)

        unsupported_fields = set()
        properties = {}
        mdata = metadata.new()

        found_id_field = False

        # Loop over the object's fields
        for f in fields:
            field_name = f['name']
            field_type = f['type']

            if field_name == "Id":
                found_id_field = True

            property_schema, mdata = create_property_schema(f, mdata)

            # Compound Address fields cannot be queried by the Bulk API
            if field_type == "address" and sf.api_type in (
                    tap_salesforce.salesforce.BULK_API_TYPE,
                    tap_salesforce.salesforce.BULK_V2_API_TYPE):
                unsupported_fields.add(
                    (field_name, 'cannot query compound address fields with bulk API'))

            # we haven't been able to observe any records with a json field, so we
            # are marking it as unavailable until we have an example to work with
            if field_type == "json":
                unsupported_fields.add(
                    (field_name, 'do not currently support json fields - please contact support'))

            # Blacklisted fields are dependent on the api_type being used
            field_pair = (sobject_name, field_name)
            if field_pair in blacklisted_fields:
                unsupported_fields.add((field_name, blacklisted_fields[field_pair]))

            inclusion = metadata.get(mdata, ('properties', field_name), 'inclusion')

            if sf.select_fields_by_default and inclusion != 'unsupported':
                mdata = metadata.write(mdata, ('properties', field_name), 'selected-by-default', True)

            properties[field_name] = property_schema

        if replication_key:
            mdata = metadata.write(mdata, ('properties', replication_key), 'inclusion', 'automatic')

        # There are cases where compound fields are referenced by the associated
        # subfields but are not actually present in the field list
        field_name_set = {f['name'] for f in fields}
        filtered_unsupported_fields = [f for f in unsupported_fields if f[0] in field_name_set]
        missing_unsupported_field_names = [f[0] for f in unsupported_fields if f[0] not in field_name_set]

        if missing_unsupported_field_names:
            LOGGER.info("Ignoring the following unsupported fields for object %s as they are missing from the field list: %s",
                        sobject_name,
                        ', '.join(sorted(missing_unsupported_field_names)))

        if filtered_unsupported_fields:
            LOGGER.info("Not syncing the following unsupported fields for object %s: %s",
                        sobject_name,
                        ', '.join(sorted([k for k, _ in filtered_unsupported_fields])))

        # Salesforce Objects are skipped when they do not have an Id field
        if not found_id_field:
            LOGGER.info(
                "Skipping Salesforce Object %s, as it has no Id field",
                sobject_name)
            continue

        # Any property added to unsupported_fields has its selection metadata
        # removed and is marked unsupported
        for prop, description in filtered_unsupported_fields:
            if metadata.get(mdata, ('properties', prop), 'selected-by-default'):
                # metadata.delete mutates mdata in place (no reassignment).
                metadata.delete(mdata, ('properties', prop), 'selected-by-default')

            mdata = metadata.write(
                mdata, ('properties', prop), 'unsupported-description', description)
            mdata = metadata.write(
                mdata, ('properties', prop), 'inclusion', 'unsupported')

        if replication_key:
            mdata = metadata.write(mdata, (), 'valid-replication-keys', [replication_key])
            mdata = metadata.write(mdata, (), 'replication-key', replication_key)
            mdata = metadata.write(mdata, (), 'replication-method', "INCREMENTAL")
        else:
            mdata = metadata.write(
                mdata,
                (),
                'forced-replication-method',
                {
                    'replication-method': 'FULL_TABLE',
                    'reason': 'No replication keys found from the Salesforce API'})

        mdata = metadata.write(mdata, (), 'table-key-properties', key_properties)

        schema = {
            'type': 'object',
            'additionalProperties': False,
            'properties': properties
        }

        entries.append({
            'stream': sobject_name,
            'tap_stream_id': sobject_name,
            'schema': schema,
            'metadata': metadata.to_list(mdata)
        })

    # For each custom setting field, remove its associated tag from entries
    # See Blacklisting.md for more information
    unsupported_tag_objects = [object_to_tag_references[f]
                               for f in sf_custom_setting_objects if f in object_to_tag_references]
    if unsupported_tag_objects:
        LOGGER.info(
            "Skipping the following Tag objects, Tags on Custom Settings Salesforce objects "
            "are not supported by the Bulk API: %s",
            unsupported_tag_objects)
        skipped = set(unsupported_tag_objects)
        entries = [e for e in entries if e['stream'] not in skipped]

    result = {'streams': entries}
    json.dump(result, sys.stdout, indent=4)

async def sync_catalog_entry(sf, catalog_entry, state):
    """Sync one catalog stream, resuming an unfinished bulk job if one is bookmarked.

    Unselected streams are skipped. Otherwise the current state and the
    stream's SCHEMA message are emitted, and then one of two paths runs:

    * If there is a ``JobID`` bookmark, ``resume_syncing_bulk_query`` is run.
      The job bookmarks are then dropped, and the highest bookmark seen is
      stored under the replication key (when the stream has one).
    * Otherwise ``sync_stream`` is run. It is preceded by an ACTIVATE_VERSION
      message when the stream has a replication key or no bookmark yet.

    The blocking sync functions run in the event loop's default executor.

    :param sf: logged-in ``Salesforce`` client
    :param catalog_entry: one stream dict from the catalog
    :param state: state dict shared with the caller; this stream's bookmarks
        in it are modified
    """
    tap_stream_id = catalog_entry['tap_stream_id']
    stream_name = tap_stream_id
    stream = catalog_entry['stream']
    stream_alias = catalog_entry.get('stream_alias')
    stream_version = get_stream_version(catalog_entry, state)
    activate_version_message = singer.ActivateVersionMessage(
        stream=(stream_alias or stream), version=stream_version)

    # Convert the catalog metadata once and reuse the stream-level entry.
    mdata = metadata.to_map(catalog_entry['metadata'])
    stream_mdata = mdata.get((), {})
    replication_key = stream_mdata.get('replication-key')

    if not stream_is_selected(mdata):
        LOGGER.info("%s: Skipping - not selected", stream_name)
        return

    LOGGER.info("%s: Starting", stream_name)

    singer.write_state(state)
    singer.write_schema(
        stream,
        catalog_entry['schema'],
        stream_mdata.get('table-key-properties'),
        replication_key,
        stream_alias)

    loop = asyncio.get_event_loop()

    job_id = singer.get_bookmark(state, tap_stream_id, 'JobID')
    if job_id:
        with metrics.record_counter(stream) as counter:
            LOGGER.info("Found JobID from previous Bulk Query. Resuming sync for job: %s", job_id)
            # Resuming a sync should clear out the remaining state once finished
            await loop.run_in_executor(None, resume_syncing_bulk_query, sf, catalog_entry, job_id, state, counter)
            LOGGER.info("Completed sync for %s", stream_name)
            stream_bookmarks = state.get('bookmarks', {}).get(tap_stream_id, {})
            stream_bookmarks.pop('JobID', None)
            stream_bookmarks.pop('BatchIDs', None)
            bookmark = stream_bookmarks.pop('JobHighestBookmarkSeen', None)
            # A stream without a replication key has nothing to bookmark;
            # writing anyway would store the value under a ``None`` key
            # (serialized as "null" in the emitted state).
            if replication_key:
                state = singer.write_bookmark(state, tap_stream_id, replication_key, bookmark)
            singer.write_state(state)
    else:
        state_msg_threshold = CONFIG.get('state_message_threshold', 1000)

        # Tables with a replication_key or an empty bookmark will emit an
        # activate_version at the beginning of their sync
        bookmark_is_empty = state.get('bookmarks', {}).get(tap_stream_id) is None

        if replication_key or bookmark_is_empty:
            singer.write_message(activate_version_message)
            state = singer.write_bookmark(state, tap_stream_id, 'version', stream_version)
        await loop.run_in_executor(None, sync_stream, sf, catalog_entry, state, state_msg_threshold)
        LOGGER.info("Completed sync for %s", stream_name)

def do_sync(sf, catalog, state):
    """Sync every stream in ``catalog`` concurrently, then emit the final state.

    Each catalog entry becomes a ``sync_catalog_entry`` coroutine. All of them
    are gathered on a dedicated event loop. That loop's default executor is a
    thread pool of ``CONFIG['max_workers']`` threads (default 8).

    :param sf: logged-in ``Salesforce`` client
    :param catalog: catalog dict with a ``streams`` list
    :param state: state dict shared by all stream syncs
    """
    LOGGER.info("Starting sync")

    max_workers = CONFIG.get('max_workers', 8)
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
    # Use a fresh loop instead of asyncio.get_event_loop(): calling that with
    # no running loop is deprecated on newer Pythons, and it could return a
    # loop that a previous call already closed.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.set_default_executor(executor)

    try:
        # Schedule one task for each catalog entry to be extracted
        # and run them concurrently.
        sync_tasks = (sync_catalog_entry(sf, catalog_entry, state)
                      for catalog_entry in catalog["streams"])
        loop.run_until_complete(asyncio.gather(*sync_tasks))
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()

    singer.write_state(state)
    LOGGER.info("Finished sync")

def main_impl():
    """Parse CLI arguments, log in to Salesforce, then run discovery or sync.

    Runs discovery when ``--discover`` is given. Otherwise, when a catalog or
    properties file is given, builds the state and runs a sync.

    In all cases, once a client exists, the ``finally`` block:

    * logs non-zero REST-request and Bulk-job counts at DEBUG level;
    * cancels ``sf.auth.login_timer`` if one is set.
    """
    args = singer_utils.parse_args(REQUIRED_CONFIG_KEYS)
    CONFIG.update(args.config)

    credentials = parse_credentials(CONFIG)
    sf = None
    try:
        sf = Salesforce(
            credentials=credentials,
            quota_percent_total=CONFIG.get('quota_percent_total'),
            quota_percent_per_run=CONFIG.get('quota_percent_per_run'),
            is_sandbox=CONFIG.get('is_sandbox'),
            select_fields_by_default=CONFIG.get('select_fields_by_default'),
            default_start_date=CONFIG.get('start_date'),
            api_type=CONFIG.get('api_type'))
        sf.login()

        if args.discover:
            do_discover(sf)
        elif args.properties or args.catalog:
            catalog = args.properties or args.catalog.to_dict()
            do_sync(sf, catalog, build_state(args.state, catalog))
    finally:
        if sf:
            usage_reports = (
                (sf.rest_requests_attempted,
                 "This job used %s REST requests towards the Salesforce quota."),
                (sf.jobs_completed,
                 "Replication used %s Bulk API jobs towards the Salesforce quota."),
            )
            for count, message in usage_reports:
                if count > 0:
                    LOGGER.debug(message, count)
            if sf.auth.login_timer:
                sf.auth.login_timer.cancel()


def main():
    """Run the tap and map failures to process exit codes.

    * ``TapSalesforceQuotaExceededException``: logged as critical, exit status 2.
    * ``TapSalesforceException``: logged as critical, exit status 1.
    * Any other exception: logged as critical, then re-raised unchanged.
    """
    try:
        main_impl()
    except TapSalesforceQuotaExceededException as e:
        LOGGER.critical(e)
        sys.exit(2)
    except TapSalesforceException as e:
        LOGGER.critical(e)
        sys.exit(1)
    except Exception as e:
        LOGGER.critical(e)
        # Bare ``raise`` re-raises the active exception with its original
        # traceback, without adding this frame as ``raise e`` would.
        raise
