from collections import OrderedDict
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union

from google.cloud.bigquery import SchemaField

from bigquery_frame import BigQueryBuilder, DataFrame
from bigquery_frame import functions as f
from bigquery_frame.auth import get_bq_client
from bigquery_frame.column import Column, StringOrColumn, cols_to_str
from bigquery_frame.dataframe import strip_margin
from bigquery_frame.transformations_impl.flatten import flatten_schema
from bigquery_frame.utils import quote, str_to_col

# A tree of (possibly nested) column names: inner nodes map a field name to its children,
# and leaves hold either the BigQuery type the column must be cast to, or None when no cast is needed.
OrderedTree = Union[str, None, Dict[str, "OrderedTree"]]

BIGQUERY_TYPE_ALIASES = {
    "INT64": "INT64",
    "INT": "INT64",
    "SMALLINT": "INT64",
    "INTEGER": "INT64",
    "BIGINT": "INT64",
    "TINYINT": "INT64",
    "BYTEINT": "INT64",
    "NUMERIC": "NUMERIC",
    "DECIMAL": "NUMERIC",
    "FLOAT": "FLOAT64",
    "BIGNUMERIC": "BIGNUMERIC",
    "BIGDECIMAL": "BIGNUMERIC",
    # Legacy type name used in table schemas (SchemaField.field_type) for BOOL columns.
    # Without this alias, BOOLEAN would not match the "BOOL" entry of the conversion rules
    # and could never be widened (e.g. to STRING).
    "BOOLEAN": "BOOL",
}
"""This dict gives for each type the most commonly used alias.
Legacy type names such as INTEGER, FLOAT or BOOLEAN are mapped to their standard SQL equivalent.
Source: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#numeric_types
"""

BIGQUERY_TYPE_CONVERSIONS = dict(
    BOOL="STRING",
    INT64="NUMERIC",
    NUMERIC="BIGNUMERIC",
    BIGNUMERIC="FLOAT64",
    FLOAT64="STRING",
    DATE="DATETIME",
    DATETIME="TIMESTAMP",
    TIMESTAMP="STRING",
    TIME="STRING",
    BYTES="STRING",
    GEOGRAPHY="STRING",
)
"""Map each type to the smallest supertype it can be converted to.
Following the successive conversions (e.g. INT64 -> NUMERIC -> BIGNUMERIC -> FLOAT64 -> STRING)
enumerates every wider type of a given type; these chains must not contain a loop.
Source: https://cloud.google.com/bigquery/docs/reference/standard-sql/conversion_rules
"""


def _resolve_type_alias(tpe: str) -> str:
    """Normalize a BigQuery type name to its most commonly used alias.

    Type names that are not listed in ``BIGQUERY_TYPE_ALIASES`` are returned unchanged.

    Based on: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types

    :param tpe: a BigQuery type name (e.g. "INTEGER")
    :return: the canonical alias of this type (e.g. "INT64"), or ``tpe`` itself when it has no known alias
    """
    try:
        return BIGQUERY_TYPE_ALIASES[tpe]
    except KeyError:
        return tpe


def _list_wider_types(tpe: str) -> Generator[str, None, None]:
    """Return a generator of all wider types into which the given type can be cast.

    The chain starts with the type itself (after alias resolution), then follows
    ``BIGQUERY_TYPE_CONVERSIONS`` until reaching a type that has no wider type.

    Based on: https://cloud.google.com/bigquery/docs/reference/standard-sql/conversion_rules

    >>> list(_list_wider_types("INT64"))
    ['INT64', 'NUMERIC', 'BIGNUMERIC', 'FLOAT64', 'STRING']
    >>> list(_list_wider_types("INTEGER"))
    ['INT64', 'NUMERIC', 'BIGNUMERIC', 'FLOAT64', 'STRING']
    >>> list(_list_wider_types("STRING"))
    ['STRING']
    >>> list(_list_wider_types("FOO"))
    ['FOO']

    :param tpe: input type
    :return: a generator of wider types, from the input type itself to the widest one
    :raises AssertionError: if BIGQUERY_TYPE_CONVERSIONS contains a loop
    """
    current_type: Optional[str] = _resolve_type_alias(tpe)
    # Types already produced: seeing one again means BIGQUERY_TYPE_CONVERSIONS contains a loop,
    # which would otherwise make this generator infinite.
    seen_types = set()
    while current_type is not None:
        assert_true(
            current_type not in seen_types,
            "This should not happen. Please check that BIGQUERY_TYPE_CONVERSIONS does not contain a loop.",
        )
        seen_types.add(current_type)
        yield current_type
        current_type = BIGQUERY_TYPE_CONVERSIONS.get(current_type)


def find_wider_type_for_two(t1: str, t2: str) -> Optional[str]:
    """Find the smallest common type into which the two given types can both be cast.
    Returns None if no such type exists.

    The wider types of ``t2`` are scanned from the narrowest to the widest, and the first one
    that is also a wider type of ``t1`` is returned.

    Based on: https://cloud.google.com/bigquery/docs/reference/standard-sql/conversion_rules

    >>> find_wider_type_for_two("INT64", "DECIMAL")
    'NUMERIC'
    >>> find_wider_type_for_two("TINYINT", "INTEGER")
    'INT64'
    >>> find_wider_type_for_two("DECIMAL", "FLOAT64")
    'FLOAT64'
    >>> find_wider_type_for_two("DATE", "TIMESTAMP")
    'TIMESTAMP'
    >>> find_wider_type_for_two("DATE", "TIME")
    'STRING'
    >>> find_wider_type_for_two("ARRAY<INT>", "INT")

    :param t1: a BigQuery type
    :param t2: another BigQuery type
    :return: the smallest common type for the two, or None if they are incompatible
    """
    candidates = set(_list_wider_types(t1))
    return next((wider for wider in _list_wider_types(t2) if wider in candidates), None)


def get_common_columns(
    left_schema: List[SchemaField], right_schema: List[SchemaField]
) -> List[Tuple[str, Optional[str]]]:
    """Return the columns present in both DataFrame schemas, along with the type they should be cast to.

    Columns are returned in the order of ``left_schema``. For each of them, the second element of the
    tuple is the smallest common type into which both columns can be cast. It is None when the columns
    already have the same type or have incompatible types, in which case they are simply not cast.

    >>> bq = BigQueryBuilder(get_bq_client())
    >>> df1 = bq.sql('''SELECT 'A' as id, CAST(1 as BIGINT) as d, 'a' as a''')
    >>> df2 = bq.sql('''SELECT 'A' as id, CAST(1 as FLOAT64) as d, ['a'] as a''')
    >>> get_common_columns(df1.schema, df2.schema)
    [('id', None), ('d', 'FLOAT64'), ('a', None)]

    :param left_schema: schema of the left DataFrame
    :param right_schema: schema of the right DataFrame
    :return: a list of (column name, common type or None) tuples
    """
    left_types = {field.name: field.field_type for field in left_schema}
    right_types = {field.name: field.field_type for field in right_schema}

    common_columns: List[Tuple[str, Optional[str]]] = []
    for name, left_type in left_types.items():
        if name not in right_types:
            continue
        right_type = right_types[name]
        # Identical types need no cast; otherwise the result is still None when the types are incompatible.
        common_type = None if left_type == right_type else find_wider_type_for_two(left_type, right_type)
        common_columns.append((name, common_type))
    return common_columns


def _build_nested_struct_tree(common_columns: List[Tuple[str, str]], struct_separator: str) -> OrderedTree:
    """Turn a list of flattened column names into a tree of nested fields.

    Each name is split on ``struct_separator``: every part but the last one designates an
    intermediate node (an OrderedDict), and the last part becomes a leaf holding the column's type.
    Insertion order is preserved at every level of the tree.

    :param common_columns: (flattened column name, type or None) pairs
    :param struct_separator: separator between a struct and its fields in the flattened column names
    :return: the root of the tree
    """
    root = OrderedDict()
    for full_name, col_type in common_columns:
        *path, leaf_name = full_name.split(struct_separator)
        node = root
        for part in path:
            if part not in node:
                node[part] = OrderedDict()
            node = node[part]
        node[leaf_name] = col_type
    return root


def _build_struct_from_tree(node: OrderedTree, prefix: str = "") -> List[Column]:
    """Build the columns that rebuild the nested structure described by a tree of fields.

    Each key of ``node`` is a field name, suffixed with "!" when the field is repeated (an ARRAY).
    A value that is a string or None is a leaf: the column is cast to that type when it is a string,
    and kept as is when it is None. Any other value is a subtree describing the fields of a STRUCT.

    >>> tree = OrderedDict([('s!', OrderedDict([('c', None), ('d', "FLOAT64")]))])
    >>> for c in _build_struct_from_tree(tree): print(c)
    (
      SELECT AS STRUCT
        `c` as c,
        CAST(`d` as FLOAT64) as d
      FROM UNNEST(s)
    ) as s
    >>> tree = OrderedDict([('s', OrderedDict([('e!', "FLOAT64")]))])
    >>> for c in _build_struct_from_tree(tree): print(c)
    STRUCT(ARRAY(
      SELECT
        CAST(`X` as FLOAT64)
      FROM UNNEST(`s`.`e`) as X
    ) as e) as s

    :param node: the (sub)tree of fields to rebuild
    :param prefix: path of ``node`` in the current scope, used to qualify leaf column names
        (e.g. "s." for the fields of struct ``s``). It is reset to "" inside arrays of structs,
        whose fields are read directly from the unnested elements.
    :return: one column per key of ``node``, aliased with the key's name (without "!")
    """
    array_sep = "!"
    struct_sep = "."

    result = []
    for key, subtree in node.items():
        repeated = key[-1] == array_sep
        name = key.replace(array_sep, "")
        is_leaf = subtree is None or isinstance(subtree, str)

        if not is_leaf:
            if repeated:
                # ARRAY<STRUCT<...>>: rebuild each element with a SELECT AS STRUCT over the unnested array,
                # in which the element's fields are referenced directly (hence the empty prefix).
                inner_fields = _build_struct_from_tree(subtree, prefix="")
                result.append(select_struct_from_unnest(Column(prefix + name), inner_fields).alias(name))
            else:
                inner_fields = _build_struct_from_tree(subtree, prefix + key + struct_sep)
                result.append(f.struct(*inner_fields).alias(name))
            continue

        leaf = f.col(prefix + name)
        if subtree is None:
            # No cast needed: keep the column as is
            result.append(leaf.alias(name))
        elif repeated:
            # ARRAY of scalars: cast every element of the array
            result.append(select_from_unnest(leaf, "X", f.cast("X", subtree)).alias(name))
        else:
            result.append(f.cast(leaf, subtree).alias(name))
    return result


def make_dataframes_comparable(left_df: DataFrame, right_df: DataFrame) -> Tuple[DataFrame, DataFrame]:
    """Given two DataFrames, returns two new corresponding DataFrames where only the common columns are kept,
    and where the order of the columns has been reorganized to be the same in both DataFrames.
    When possible, it also widens the type of the fields to the most narrow common type.
    This transformation is also applied recursively on nested columns, including those inside
    repeated records (a.k.a. ARRAY<STRUCT<>>).

    >>> bq = BigQueryBuilder(get_bq_client())
    >>> df1 = bq.sql('SELECT 1 as id, STRUCT(1 as a, [STRUCT(2 as c, 3 as d)] as b, [4, 5] as e) as s')
    >>> df2 = bq.sql('SELECT 1 as id, STRUCT(2 as a, [STRUCT(3.0 as c, "4" as d)] as b, [5.0, 6.0] as e) as s')
    >>> df1.union(df2).show() # doctest: +ELLIPSIS
    Traceback (most recent call last):
      ...
    google.api_core.exceptions.BadRequest: 400 Column 2 in UNION ALL has incompatible types: ...
    >>> df1, df2 = make_dataframes_comparable(df1, df2)
    >>> df1.union(df2).show()
    +----+------------------------------------------------------+
    | id |                                                    s |
    +----+------------------------------------------------------+
    |  1 | {'a': 1, 'b': {'c': 2.0, 'd': '3'}, 'e': [4.0, 5.0]} |
    |  1 | {'a': 2, 'b': {'c': 3.0, 'd': '4'}, 'e': [5.0, 6.0]} |
    +----+------------------------------------------------------+

    :param left_df: a DataFrame
    :param right_df: another DataFrame
    :return: a tuple containing the transformed versions of ``left_df`` and ``right_df``, in that order
    """
    struct_separator, array_separator = ".", "!"

    def flat_schema(df: DataFrame):
        # Flatten nested and repeated fields, joining their names with the separators above
        return flatten_schema(
            df.schema, explode=True, struct_separator=struct_separator, array_separator=array_separator
        )

    common_columns = get_common_columns(flat_schema(left_df), flat_schema(right_df))

    # Rebuild the nested structure of both DataFrames from the same tree of common columns,
    # i.e. recursively write something like "SELECT struct(a, struct(s.b.c, s.b.d)) as s".
    tree = _build_nested_struct_tree(common_columns, struct_separator)
    comparable_left = left_df.select(*_build_struct_from_tree(tree))
    comparable_right = right_df.select(*_build_struct_from_tree(tree))
    return comparable_left, comparable_right


def select_struct_from_unnest(unnest: StringOrColumn, cols: List[StringOrColumn]):
    """Build a ``(SELECT AS STRUCT ... FROM UNNEST(...))`` subquery expression.

    The given columns are selected from each element of the unnested array and packed back into a STRUCT.

    >>> select_struct_from_unnest('s.b', ['c', 'd'])
    Column('(
      SELECT AS STRUCT
        `c`,
        `d`
      FROM UNNEST(s.b)
    )')

    :param unnest: the array expression to unnest
    :param cols: the columns (or column names) to select from each element
    :return: a Column wrapping the subquery
    """
    struct_fields = cols_to_str(str_to_col(cols), indentation=4)
    subquery = f"""
    |(
    |  SELECT AS STRUCT
    |{struct_fields}
    |  FROM UNNEST({unnest})
    |)"""
    return Column(strip_margin(subquery))


def select_from_unnest(unnest: StringOrColumn, element_name: str, col: StringOrColumn):
    """Build an ``ARRAY(SELECT ... FROM UNNEST(...) as element_name)`` expression.

    ``col`` is evaluated once per element of the unnested array, where each element is
    reachable under the name ``element_name``.

    >>> select_from_unnest("s", element_name="a", col=f.cast("a", "STRING"))
    Column('ARRAY(
      SELECT
        CAST(`a` as STRING)
      FROM UNNEST(s) as a
    )')

    :param unnest: the array expression to unnest
    :param element_name: the alias given to each element of the array
    :param col: the expression (or column name) computed for each element
    :return: a Column wrapping the ARRAY subquery
    """
    element_expr = str_to_col(col)
    subquery = f"""
    |ARRAY(
    |  SELECT
    |    {element_expr}
    |  FROM UNNEST({unnest}) as {element_name}
    |)"""
    return Column(strip_margin(subquery))


def join(df1: DataFrame, df2: DataFrame, on: Optional[StringOrColumn] = None, how: Optional[str] = None):
    """Joins with another :class:`DataFrame`, using the given join expression.

    Examples
    --------
    The following performs a full outer join between ``df1`` and ``df2``.
    >>> bq = BigQueryBuilder(get_bq_client())
    >>> df1 = bq.sql('SELECT * FROM UNNEST([STRUCT("Alice" as name), STRUCT("Bob" as name)])').alias("df1")
    >>> df2 = bq.sql('SELECT "Bob" as name, 85 as height').alias("df2")
    >>> join(df1, df2, f.col("df1.name") == f.col("df2.name"), "left").select(f.col("df1.name"), f.col("df2.height")).sort("name DESC").show()
    +-------+--------+
    |  name | height |
    +-------+--------+
    |   Bob |     85 |
    | Alice |   null |
    +-------+--------+
    >>> join(df1, df2, "name", "left").select(f.col("df1.name"), f.col("df2.height")).sort("name DESC").show()
    +-------+--------+
    |  name | height |
    +-------+--------+
    |   Bob |     85 |
    | Alice |   null |
    +-------+--------+

    # >>> join(df1, df2, 'name', 'outer').select('name', 'height').sort(desc("name")).collect()
    # [Row(name='Tom', height=80), Row(name='Bob', height=85), Row(name='Alice', height=None)]
    # >>> cond = [df.name == df3.name, df.age == df3.age]
    # >>> join(df1, df3, cond, 'outer').select(df.name, df3.age).collect()
    # [Row(name='Alice', age=2), Row(name='Bob', age=5)]
    # >>> join(df1, df2, 'name').select(df.name, df2.height).collect()
    # [Row(name='Bob', height=85)]
    # >>> join(df1, df4, ['name', 'age']).select(df.name, df.age).collect()
    # [Row(name='Bob', age=5)]

    :param df1: Left side of the join
    :param df2: Right side of the join
    :param on: str, list or :class:`Column`, optional
        a string for the join column name, a list of column names,
        a join expression (Column), or a list of Columns.
        If `on` is a string or a list of strings indicating the name of the join column(s),
        the column(s) must exist on both sides, and this performs an equi-join (``USING``).
        A list of Columns is combined with ``AND`` into a single ``ON`` condition.
    :param how: str, optional
        default ``inner``. Must be one of: ``inner``, ``cross``, ``outer``,
        ``full``, ``fullouter``, ``full_outer``, ``left``, ``leftouter``, ``left_outer``,
        ``right``, ``rightouter``, ``right_outer``, ``semi``, ``leftsemi``, ``left_semi``,
        ``anti``, ``leftanti`` and ``left_anti``.
        Spark-style names are translated into BigQuery syntax: for instance ``outer``, ``fullouter``
        and ``full_outer`` all become ``FULL OUTER JOIN``. The ``semi`` and ``anti`` variants are
        accepted but not translated yet: BigQuery has no such join keywords and will reject the query.
    :return: a new DataFrame selecting ``df1`` and ``df2`` (by their short aliases) from the join
    :raises AssertionError: if ``how`` is not one of the supported join types
    """
    # TODO: Handle SEMI and ANTI joins: these keywords do not exist in BigQuery,
    #  but the WHERE IN and WHERE NOT IN syntax does.

    # Spark-style join type -> BigQuery join keyword(s).
    # BigQuery only accepts INNER, CROSS, FULL [OUTER], LEFT [OUTER] and RIGHT [OUTER] joins,
    # so names such as "outer" or "left_outer" must be rewritten.
    join_keywords = {
        "INNER": "INNER",
        "CROSS": "CROSS",
        "OUTER": "FULL OUTER",
        "FULL": "FULL",
        "FULLOUTER": "FULL OUTER",
        "FULL_OUTER": "FULL OUTER",
        "LEFT": "LEFT",
        "LEFTOUTER": "LEFT OUTER",
        "LEFT_OUTER": "LEFT OUTER",
        "RIGHT": "RIGHT",
        "RIGHTOUTER": "RIGHT OUTER",
        "RIGHT_OUTER": "RIGHT OUTER",
        # Not supported by BigQuery yet (see TODO above): passed through unchanged.
        "SEMI": "SEMI",
        "LEFTSEMI": "LEFTSEMI",
        "LEFT_SEMI": "LEFT_SEMI",
        "ANTI": "ANTI",
        "LEFTANTI": "LEFTANTI",
        "LEFT_ANTI": "LEFT_ANTI",
    }

    if how is not None:
        join_type = how.upper()
        assert_true(
            join_type in join_keywords,
            f"Unsupported join type: {how!r}. Must be one of: {', '.join(k.lower() for k in join_keywords)}",
        )
        how = join_keywords[join_type] + " "
    else:
        how = ""
    df1_short_alias = quote(df1._alias.replace("`", "").split(".")[-1])
    df2_short_alias = quote(df2._alias.replace("`", "").split(".")[-1])
    if isinstance(on, str):
        on = [on]
    on_clause = ""
    if on is not None:
        if isinstance(on, list):
            if on and all(isinstance(condition, Column) for condition in on):
                # A list of join conditions: they must all hold, so combine them with AND.
                on_clause = "\nON " + " AND ".join(f"({condition})" for condition in on)
            else:
                # A list of column names present on both sides: equi-join on these columns.
                on_clause = f"\nUSING ({cols_to_str(on)})"
        else:
            on_clause = f"\nON {on}"
    query = strip_margin(
        f"""
        |SELECT 
        |    {df1_short_alias}, 
        |    {df2_short_alias} 
        |FROM {quote(df1._alias)} 
        |{how}JOIN {quote(df2._alias)}{on_clause}
        |"""
    )
    return df1._apply_query(query, deps=[df1, df2])


def assert_true(assertion: bool, error_message: Any = None) -> None:
    """Check that ``assertion`` is truthy, and raise an :class:`AssertionError` otherwise.

    Unlike the ``assert`` statement, this is a plain function call, so it is not stripped when Python runs with ``-O``.

    >>> assert_true(3==4, "3 <> 4")
    Traceback (most recent call last):
    ...
    AssertionError: 3 <> 4

    >>> assert_true(3==3, "3 <> 4")

    :param assertion: assertion that will be checked
    :param error_message: optional message, converted with ``str``, attached to the error if the assertion is false
    :raises AssertionError: if ``assertion`` is falsy
    """
    if assertion:
        return
    raise AssertionError() if error_message is None else AssertionError(str(error_message))
