from dbt.adapters.sql import SQLAdapter
from dbt.adapters.db2_for_i import DB2ForIConnectionManager
import dbt.exceptions
from dbt.adapters.db2_for_i.relation import DB2ForIRelation

import agate

from typing import (
    Optional, List
)


class DB2ForIAdapter(SQLAdapter):
    """dbt adapter for IBM DB2 for i.

    Plugs in the DB2 for i connection manager and relation class, maps agate
    column types to DB2 for i SQL types, and provides the SQL-generating
    helpers used by the adapter test suite.
    """

    ConnectionManager = DB2ForIConnectionManager
    Relation = DB2ForIRelation

    @classmethod
    def date_function(cls) -> str:
        """Return the SQL expression for the current timestamp."""
        return "current_timestamp"

    @classmethod
    def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Return a ``varchar`` type wide enough for a text column.

        The width is the longest non-null value measured in UTF-8 bytes (not
        characters), with a floor of 16. A column with no non-null values gets
        a width of 64.
        """
        column = agate_table.columns[col_idx]
        # Measure UTF-8 bytes rather than characters (presumably because the
        # varchar length is a byte count); see
        # https://github.com/fishtown-analytics/dbt/pull/2255
        byte_lengths = [len(d.encode("utf-8")) for d in column.values_without_nulls()]
        max_len = max(byte_lengths) if byte_lengths else 64
        return "varchar({})".format(max(max_len, 16))

    @classmethod
    def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Map an agate datetime column to ``timestamp``."""
        return "timestamp"

    @classmethod
    def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Map an agate boolean column to ``smallint``."""
        # Some DB2 for i versions do not support a boolean data type, so
        # smallint is used for now.
        return "smallint"

    @classmethod
    def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Map an agate number column to ``float`` or ``int``.

        ``float`` is chosen when agate reports a non-zero maximum precision
        (decimal places) for the column, otherwise ``int``.
        """
        decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
        return "float" if decimals else "int"

    @classmethod
    def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Map an agate time column to ``time``."""
        return "time"

    @classmethod
    def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Map an agate date column to ``date``."""
        return "date"

    def debug_query(self) -> None:
        """Execute a trivial query to confirm the connection can run SQL."""
        self.execute('select 1 as one from sysibm.sysdummy1')

    # Methods used in adapter tests

    def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = 'hour') -> str:
        """Return SQL that adds ``number`` ``interval`` units to ``add_to``.

        Uses a DB2 labeled duration, e.g. ``updated_at + 1 hour``.
        """
        # Must add (not subtract) to honour the method's name and the base
        # SQLAdapter.timestamp_add_sql contract.
        return f"{add_to} + {number} {interval}"

    def string_add_sql(self, add_to: str, value: str, location: str = 'append') -> str:
        """Return SQL concatenating the string literal ``value`` with ``add_to``.

        ``location`` selects ``'append'`` (``add_to || 'value'``) or
        ``'prepend'`` (``'value' || add_to``). ``value`` is interpolated
        verbatim inside single quotes, so it must not itself contain a single
        quote.

        :raises dbt.exceptions.DbtRuntimeError: for any other ``location``.
        """
        if location == 'append':
            return f"{add_to} || '{value}'"
        elif location == 'prepend':
            return f"'{value}' || {add_to}"
        else:
            raise dbt.exceptions.DbtRuntimeError(
                f'Got an unexpected location value of "{location}"'
            )

    def get_rows_different_sql(
        self,
        relation_a: DB2ForIRelation,
        relation_b: DB2ForIRelation,
        column_names: Optional[List[str]] = None,
        except_operator: str = 'EXCEPT',
    ) -> str:
        """Generate SQL that compares the rows of two relations.

        The query returns a single row with two columns:
        ``row_count_difference`` (row count of ``relation_a`` minus that of
        ``relation_b``) and ``num_mismatched`` (rows returned by the
        set difference in each direction, over the compared columns).

        :param column_names: columns to compare; when ``None``, every column
            of ``relation_a`` is used. Names are quoted and sorted.
        :param except_operator: set-difference operator placed between the
            two selects (``EXCEPT`` by default).
        """
        # This method only really exists for test reasons.
        if column_names is None:
            column_names = [c.name for c in self.get_columns_in_relation(relation_a)]
        columns_csv = ", ".join(sorted(self.quote(name) for name in column_names))

        return COLUMNS_EQUAL_SQL.format(
            columns=columns_csv,
            relation_a=str(relation_a),
            relation_b=str(relation_b),
            except_op=except_operator,
        )


    
# SQL template rendered by DB2ForIAdapter.get_rows_different_sql via
# str.format. Placeholders:
#   {columns}    - comma-separated list of quoted column names
#   {relation_a} - first relation, rendered with str()
#   {relation_b} - second relation, rendered with str()
#   {except_op}  - set-difference operator (EXCEPT by default)
# The query yields one row with two columns:
#   row_count_difference - COUNT(*) of relation_a minus COUNT(*) of relation_b
#   num_mismatched       - rows in (a except_op b) UNION ALL (b except_op a)
# The constant "1 as id" in both CTEs only serves as a join key so the two
# single-row results can be combined.
COLUMNS_EQUAL_SQL = """
with diff_count as (
    SELECT
        1 as id,
        COUNT(*) as num_missing FROM (
            (SELECT {columns} FROM {relation_a} {except_op}
             SELECT {columns} FROM {relation_b})
             UNION ALL
            (SELECT {columns} FROM {relation_b} {except_op}
             SELECT {columns} FROM {relation_a})
        ) as a
), table_a as (
    SELECT COUNT(*) as num_rows FROM {relation_a}
), table_b as (
    SELECT COUNT(*) as num_rows FROM {relation_b}
), row_count_diff as (
    select
        1 as id,
        table_a.num_rows - table_b.num_rows as difference
    from table_a, table_b
)
select
    row_count_diff.difference as row_count_difference,
    diff_count.num_missing as num_mismatched
from row_count_diff
join diff_count on row_count_diff.id = diff_count.id
""".strip()