# pylint: disable=missing-module-docstring
import os
from functools import partial
from typing import Generator, Callable, Tuple, List, Dict, Any
from datetime import datetime
import oracledb
from oracledb.cursor import Cursor
from dwh_oppfolging.apis.secrets_api_v1 import get_secrets


def log_oracle_etl(
    cur: Cursor,
    schema: str,
    table: str,
    etl_date: datetime,
    rows_inserted: int = -1,
    rows_updated: int = -1,
    rows_deleted: int = -1,
    log_text: str = "",
):
    """writes one row describing an etl run into {schema}.etl_logg

    the values are bound positionally, in this order:
        table, etl_date, rows_inserted, rows_updated, rows_deleted, log_text
    counts that the caller does not supply are logged as -1.
    nothing is committed here; the caller owns the transaction.
    """
    values = [table, etl_date, rows_inserted, rows_updated, rows_deleted, log_text]
    statement = f"insert into {schema}.etl_logg select :0,:1,:2,:3,:4,:5 from dual"
    cur.execute(statement, values)


def _fix_timestamp_inputtypehandler(cur, val, arrsize):
    if isinstance(val, datetime) and val.microsecond > 0:
        # pylint: disable=no-member
        return cur.var(oracledb.DB_TYPE_TIMESTAMP, arraysize=arrsize) # type: ignore
        # pylint: enable=no-member
    # No return value implies default type handling


def create_oracle_connection(user: str):
    """opens an oracle connection using credentials from the secrets store

    the secrets for the environment named by the ORACLE_ENV environment
    variable are read, and the keys "<user>_user", "<user>_pw" and "dsn"
    supply the login. the returned connection has
    _fix_timestamp_inputtypehandler installed as its input type handler.
    nothing is committed for you: call .commit() yourself.
    """
    all_secrets = get_secrets()
    env_secrets = all_secrets[os.environ["ORACLE_ENV"]]
    # looked up in the same order as the connect() keywords they feed
    username = env_secrets[user + "_user"]
    password = env_secrets[user + "_pw"]
    dsn = env_secrets["dsn"]
    connection = oracledb.connect(  # type: ignore
        user=username,
        password=password,
        dsn=dsn,
        encoding="utf-8",
        nencoding="utf-8",
    )
    connection.inputtypehandler = _fix_timestamp_inputtypehandler
    return connection


def is_table_empty(cur: Cursor, schema: str, table: str) -> bool:
    """returns True if {schema}.{table} contains no rows

    the count is limited with `rownum = 1`, so oracle can stop after the first
    row it finds instead of counting the whole table. the result is therefore
    0 for an empty table and 1 otherwise.
    """
    sql = f"select count(*) from {schema}.{table} where rownum = 1"
    rowcount = cur.execute(sql).fetchone()[0]  # type: ignore
    return rowcount == 0


def update_table_immediate(cur: Cursor, schema: str, table: str, update_sql: str):
    """basic update of table using provided sql, then logs the run to etl_logg
    the sql must have a :today datetime bind (used for lastet_dato etc.)
    does not commit.

    rows_inserted is the net change in the table's row count.
    rows_updated is the number of rows the statement reported as affected
    minus rows_inserted. this is meant for UPDATE/MERGE statements; it assumes
    update_sql does not delete rows (TODO confirm for callers using a
    MERGE ... DELETE clause).
    """
    update_date = datetime.today()
    count_sql = f"select count(*) from {schema}.{table}"
    num_rows_old = cur.execute(count_sql).fetchone()[0]  # type: ignore
    cur.execute(update_sql, today=update_date)  # type: ignore
    # capture the DML rowcount now: running the count query below on the same
    # cursor overwrites cur.rowcount with the number of rows fetched (1)
    rows_affected = cur.rowcount
    num_rows_new = cur.execute(count_sql).fetchone()[0]  # type: ignore
    rows_inserted = num_rows_new - num_rows_old
    rows_updated = rows_affected - rows_inserted
    print("inserted", rows_inserted, "new records")
    print("updated", rows_updated, "existing records")
    log_oracle_etl(cur, schema, table, update_date, rows_inserted, rows_updated)
    print("logged etl for", table)


def _fetch_last_modified_date(cur: Cursor, schema: str, table: str, column: str) -> Any:
    """returns max(column) from schema.table, or datetime(1900, 1, 1) when that is null
    (null happens when the table is empty or the column has no non-null values)"""
    sql = f"select max({column}) from {schema}.{table}"
    last_modified_date = cur.execute(sql).fetchone()[0]  # type: ignore
    return datetime(1900, 1, 1) if last_modified_date is None else last_modified_date


def _build_batch_insert_sql(
    schema: str, table: str, cols: List[str], skip_on_column: Tuple[bool, str]
) -> str:
    """builds `insert into t (cols) select :col,... from dual` with one named bind per column;
    when skip_on_column is set, a `where not exists` guard skips rows whose value
    in that column is already present in the table"""
    target = f"{schema}.{table}"
    col_names = ",".join(cols)
    col_binds = ",".join(f":{col}" for col in cols)
    sql = f"insert into {target} ({col_names}) select {col_binds} from dual"
    if skip_on_column[0]:
        skip_col = skip_on_column[1]
        # reuses the row's own bind for skip_col, so skip_col must be a key in every row
        sql += f" where not exists (select null from {target} where {skip_col} = :{skip_col})"
    return sql


def insert_to_table_batched(
    cur: Cursor,
    schema: str,
    table: str,
    batch_factory: Callable[..., Generator[List[Dict[str, Any]], None, None]],
    need_insert_date: bool = True,
    needs_last_modified_date: Tuple[bool, str] = (False, ""),
    skip_on_column: Tuple[bool, str] = (False, "")
):
    """inserts to table in batches generated by the batch factory, then logs the run to etl_logg
    it is assumed that the number and name of columns in each row of each batch remain constant
    (the columns are taken from the first row of the first non-empty batch).
    empty batches are skipped. does not commit.

    if need_insert_date is set,
        insert_date will be sent as a keyword parameter to the batch factory (useful for lastet_dato)
    if needs_last_modified_date is set,
        max() of the column given is fetched from the table (useful for oppdatert_dato_kilde)
        and then sent as the keyword parameter last_modified_date to the batch factory.
        If the table is empty (or the column is all null), it defaults to datetime(1900, 1, 1)
    if skip_on_column is set,
        rows whose value in this column already exists in the table
        will not be inserted (useful for hash uniqueness).
        the column must be one of the keys of every row.
    """
    insert_date = datetime.today()
    if need_insert_date:
        batch_factory = partial(batch_factory, insert_date=insert_date)
    if needs_last_modified_date[0]:
        last_modified_date = _fetch_last_modified_date(
            cur, schema, table, needs_last_modified_date[1]
        )
        batch_factory = partial(batch_factory, last_modified_date=last_modified_date)
    insert_sql = ""
    rows_inserted = 0
    for batch in batch_factory():
        if not batch:
            continue
        if not insert_sql:
            insert_sql = _build_batch_insert_sql(schema, table, list(batch[0]), skip_on_column)
        cur.executemany(insert_sql, batch)
        # with the skip guard, rowcount counts only the rows actually inserted
        rows_inserted += cur.rowcount
        print("inserted", cur.rowcount, "rows")
    log_oracle_etl(cur, schema, table, insert_date, rows_inserted)
