from rdflib.namespace import RDF
from six import text_type
from sqlalchemy.sql import expression, functions

from rdflib_sqlalchemy.constants import (
    ASSERTED_TYPE_PARTITION,
    ASSERTED_NON_TYPE_PARTITION,
    ASSERTED_LITERAL_PARTITION,
    QUOTED_PARTITION,
    CONTEXT_SELECT,
    COUNT_SELECT,
    FULL_TRIPLE_PARTITIONS,
    TRIPLE_SELECT,
)


def query_analysis(query, store, connection):
    """
    Helper function.

    Runs ``EXPLAIN`` on a dispatched SQL statement and tallies which key
    the plan reports, for the purpose of analyzing index usage.

    The first row of the ``EXPLAIN`` result is unpacked as eight columns.
    A counter in ``store.queryOptMarks`` (created on demand) is incremented
    under ``(key, table)``. When the plan reports no key, the join type must
    be ``"ALL"`` and a ``("FULL SCAN", table)`` counter is incremented too.

    Args:
        query (str): SQL text; it is appended to ``"explain "``.
        store: Object that carries the ``queryOptMarks`` dict attribute.
        connection: Object whose ``execute`` returns a result with ``fetchall``.

    """
    plan_row = connection.execute("explain " + query).fetchall()[0]
    (table, join_type, _possible_keys, used_key,
     _key_len, _compared_col, _rows_examined, _extra) = plan_row

    # The full-scan marker (if any) is counted before the per-key marker.
    markers = []
    if not used_key:
        assert join_type == "ALL"
        markers.append(("FULL SCAN", table))
    markers.append((used_key, table))

    if not hasattr(store, "queryOptMarks"):
        store.queryOptMarks = {}
    marks = store.queryOptMarks
    for marker in markers:
        marks[marker] = marks.get(marker, 0) + 1


def _with_where(select_clause, where_clause):
    """Return ``select_clause`` filtered by ``where_clause``; unchanged when it is ``None``."""
    if where_clause is None:
        return select_clause
    return select_clause.where(where_clause)


def union_select(select_components, distinct=False, select_type=TRIPLE_SELECT):
    """
    Helper function for building union all select statement.

    Args:
        select_components (iterable of tuples): One entry per partition table,
            as ``(table, where_clause, table_type)``. ``table`` is the table
            object (its ``.c`` columns / ``.select()`` are used),
            ``where_clause`` is the filter expression or ``None`` for no
            filter, and ``table_type`` is one of the partition constants
            from `rdflib_sqlalchemy.constants`.
        distinct (bool): Whether to eliminate duplicate results. Ignored for
            `.COUNT_SELECT`, which always uses ``UNION ALL``.
        select_type (int): From `rdflib_sqlalchemy.constants`. Either `.COUNT_SELECT`,
            `.CONTEXT_SELECT`, `.TRIPLE_SELECT`

    Returns:
        The ``UNION`` (when ``distinct`` applies) or ``UNION ALL`` of the
        per-table selects. For `.TRIPLE_SELECT` the result is ordered by
        ``subject``, ``predicate``, ``object``.

    Raises:
        ValueError: If a component's table type is not recognized for the
            requested ``select_type``.

    """
    selects = []
    for table, where_clause, table_type in select_components:
        if select_type == COUNT_SELECT:
            c = table.c
            if table_type == ASSERTED_TYPE_PARTITION:
                cols = [c.member, c.klass]
            elif table_type in (ASSERTED_LITERAL_PARTITION, ASSERTED_NON_TYPE_PARTITION, QUOTED_PARTITION):
                cols = [c.subject, c.predicate, c.object]
            else:
                raise ValueError('Unrecognized table type {}'.format(table_type))
            # Count distinct rows of this partition via an aliased sub-select.
            # A ``None`` where clause must not be passed to ``.where()``: it
            # would be rendered as ``WHERE NULL`` and match nothing.
            inner_select = (
                _with_where(expression.select(*cols), where_clause)
                .distinct()
                .select_from(table)
                .alias()
            )
            select_clause = expression.select(
                expression.func.count().label('aCount')
            ).select_from(inner_select)
        elif select_type == CONTEXT_SELECT:
            select_clause = _with_where(expression.select(table.c.context), where_clause)
        elif table_type in FULL_TRIPLE_PARTITIONS:
            select_clause = _with_where(table.select(), where_clause)
        elif table_type == ASSERTED_TYPE_PARTITION:
            # The type table has no predicate / language / datatype columns;
            # synthesize them so its rows line up with the other partitions.
            select_clause = _with_where(
                expression.select(
                    table.c.id.label("id"),
                    table.c.member.label("subject"),
                    expression.literal(text_type(RDF.type)).label("predicate"),
                    table.c.klass.label("object"),
                    table.c.context.label("context"),
                    table.c.termComb.label("termcomb"),
                    expression.literal_column("NULL").label("objlanguage"),
                    expression.literal_column("NULL").label("objdatatype"),
                ),
                where_clause,
            )
        elif table_type == ASSERTED_NON_TYPE_PARTITION:
            # Pad with NULL language / datatype columns to match the union shape.
            all_table_columns = list(table.columns) + [
                expression.literal_column("NULL").label("objlanguage"),
                expression.literal_column("NULL").label("objdatatype"),
            ]
            select_clause = _with_where(
                expression.select(*all_table_columns).select_from(table),
                where_clause,
            )
        else:
            # Previously fell through with ``select_clause`` unset (or stale
            # from the prior component); fail loudly instead.
            raise ValueError('Unrecognized table type {}'.format(table_type))
        selects.append(select_clause)

    order_statement = []
    if select_type == TRIPLE_SELECT:
        order_statement = [
            expression.literal_column("subject"),
            expression.literal_column("predicate"),
            expression.literal_column("object"),
        ]
    if distinct and select_type != COUNT_SELECT:
        return expression.union(*selects).order_by(*order_statement)
    return expression.union_all(*selects).order_by(*order_statement)
