import subprocess
from itertools import chain

import sqlalchemy
from sqlalchemy import inspect
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import CreateColumn
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.types import Text

from puddl.conf import DBConfig

# Reference for the use_identity() compile hook defined below:
# https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#postgresql-10-identity-columns
from puddl.typing import URL


@compiles(CreateColumn, 'postgresql')
def use_identity(element, compiler, **kw):
    """
    Render PostgreSQL column DDL with identity columns instead of ``SERIAL``.

    Registered as the PostgreSQL compilation hook for ``CreateColumn``: the
    stock column DDL is produced first and every ``SERIAL`` occurrence in it
    is swapped for ``INT GENERATED BY DEFAULT AS IDENTITY``.

    :param element: the ``CreateColumn`` construct being compiled.
    :param compiler: the DDL compiler used for the default rendering.
    :param kw: passed through unchanged to ``visit_create_column``.
    :return: the rewritten DDL text, or ``None`` when the compiler renders
        nothing for the column.
    """
    text = compiler.visit_create_column(element, **kw)
    if text is None:
        # The stock compiler returns None for columns it skips (system
        # columns); pass that through instead of failing on .replace().
        return None
    return text.replace("SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY")


class Upserter:
    """
    Upsert rows of ``model`` keyed on its unique-constraint columns.

    On construction the unique constraints of the model's table are reflected
    from the database and the names of all participating columns are
    collected into ``unique_columns``. ``upsert()`` uses that set as the
    ``ON CONFLICT`` target.

    NOTE(review): the columns of *all* unique constraints are merged into one
    conflict target. PostgreSQL infers the arbiter index from the target, so
    this presumably only works for tables with exactly one unique constraint
    -- confirm against the models that use it.
    """

    def __init__(self, session, model):
        """
        :param session: session whose ``bind`` is used for reflection and
            which later executes the upsert statements.
        :param model: mapped class; ``__tablename__`` and
            ``__table__.schema`` locate the table to reflect.
        """
        self.session = session
        self.model = model
        inspector = Inspector.from_engine(session.bind)
        constraints = inspector.get_unique_constraints(
            self.model.__tablename__,
            schema=self.model.__table__.schema,
        )
        self.unique_columns = set(
            chain.from_iterable(c['column_names'] for c in constraints)
        )

    def upsert(self, **data):
        """
        Insert ``data`` as one row, updating the existing row on conflict.

        Columns in ``data`` that are not part of a unique constraint are
        overwritten with the incoming values when a conflict occurs. If
        ``data`` contains only unique-constraint columns there is nothing to
        update, so a conflicting row is left untouched.

        :param data: column name -> value for the row.
        """
        # read this carefully:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#sqlalchemy.dialects.postgresql.dml.Insert.on_conflict_do_update
        update_columns = set(data) - self.unique_columns
        insert_stmt = insert(self.model).values(**data)
        # index_elements is documented as a sequence; sorting also keeps the
        # rendered SQL stable regardless of set iteration order.
        conflict_target = sorted(self.unique_columns)
        if update_columns:
            excluded_mapping = {  # LHS: column name, RHS: Column instance
                col: getattr(insert_stmt.excluded, col) for col in update_columns
            }
            stmt = insert_stmt.on_conflict_do_update(
                index_elements=conflict_target,
                set_=excluded_mapping,
            )
        else:
            # on_conflict_do_update() raises ValueError for an empty ``set_``;
            # with no updatable columns the correct upsert is a no-op.
            stmt = insert_stmt.on_conflict_do_nothing(
                index_elements=conflict_target,
            )
        self.session.execute(stmt)


# https://docs.sqlalchemy.org/en/14/core/functions.html#sqlalchemy.sql.functions.GenericFunction
# noinspection PyPep8Naming
class puddl_upsert_role(GenericFunction):
    # Declares the SQL function ``puddl_upsert_role`` to SQLAlchemy. The class
    # name is presumably kept lowercase because GenericFunction derives the
    # SQL function name from it (see the docs link above) -- hence the
    # PyPep8Naming suppression.
    # Return type of the function as seen by SQLAlchemy.
    type = Text


