# postgresql.py

import logging
from typing import List, Tuple

import psycopg2
import psycopg2.extensions
import psycopg2.extras
import psycopg2.sql

from . import CfDatabase

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class CfPostgresql(CfDatabase):
    """PostgreSQL database helper built on psycopg2.

    Opens a connection and a cursor on construction, and provides helpers to
    report the transaction status and to check that tables and columns exist.
    """

    def __init__(self, host: str = '127.0.0.1', port: str = '5432', database: str = None, user: str = None,
                 password: str = None, autocommit: bool = False, dictionary_cursor: bool = False,
                 encoding: str = 'utf8'):
        """Interact with a PostgreSQL database.

        Args:
            host: Database server hostname or IP.
            port: Database server port.
            database: Database name.
            user: Database server account username.
            password: Database server account password.
            autocommit: Use autocommit on changes?
            dictionary_cursor: Return the results as a dictionary?
            encoding: Database client encoding ("utf8", "latin1", "usascii")
        """
        super().__init__(host=host, port=port, database=database, user=user, password=password)

        # Credentials are deliberately masked in the log line.
        logger.debug(f'Connecting: postgresql://[hidden]:[hidden]@{host}:{port}/{database}')
        self.connection = psycopg2.connect(
            host=host,
            port=port,
            dbname=database,
            user=user,
            password=password,
            client_encoding=encoding
        )

        if autocommit is True:
            self.connection.autocommit = True

        # DictCursor rows can be read both by column name and by position.
        if dictionary_cursor is True:
            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)
        else:
            self.cursor = self.connection.cursor()

        logger.debug('Database connection established.')

    @property
    def connection_status(self) -> str:
        """Return the transaction status of the open connection.

        Returns:
            One of 'Idle', 'Active', 'In Transaction', 'In Error' or 'Unknown'.
        """
        # The TRANSACTION_STATUS_* constants describe the *transaction* status,
        # which is exposed by the live connection's info object (not by the
        # connection-level `status` attribute).
        labels = {
            psycopg2.extensions.TRANSACTION_STATUS_IDLE: 'Idle',
            psycopg2.extensions.TRANSACTION_STATUS_ACTIVE: 'Active',
            psycopg2.extensions.TRANSACTION_STATUS_INTRANS: 'In Transaction',
            psycopg2.extensions.TRANSACTION_STATUS_INERROR: 'In Error',
        }
        return labels.get(self.connection.info.transaction_status, 'Unknown')

    @staticmethod
    def split_table_path(table_path: str = None) -> Tuple[str, str]:
        """Split table_path into its schema and table parts.

        Only the first two dot-separated components are used; any further
        components are ignored.

        Args:
            table_path: Table name in the form <schema>.<table>
        Returns:
            schema
            table
        Raises:
            ValueError: If table_path is None or contains no period.
        """
        if table_path is None:
            raise ValueError('table_path cannot be None.')

        parts = table_path.split('.')
        if len(parts) < 2:  # A missing period leaves no table component.
            raise ValueError(f'table_path not in proper <schema>.<table> format: {table_path}')

        return parts[0], parts[1]

    def table_exists(self, table_path: str = None) -> bool:
        """Test for the existence of a schema table.

        Args:
            table_path: Table name in the form <schema>.<table>
        Returns:
            True if table exists. False if not.
        Raises:
            ValueError: If table_path is None or not in <schema>.<table> form.
        """
        schema, table = self.split_table_path(table_path)

        # Bind the names as SQL literals rather than interpolating them, so
        # quotes in a name cannot break or alter the statement. The catalog is
        # referenced without a database prefix so the check works on whichever
        # database the connection is attached to.
        query = psycopg2.sql.SQL('''
            SELECT EXISTS(
              SELECT 1
              FROM pg_catalog.pg_tables
              WHERE schemaname = {schema}
              AND tablename = {table}
            );
        ''').format(
            schema=psycopg2.sql.Literal(schema),
            table=psycopg2.sql.Literal(table),
        )
        results = self.execute(query.as_string(self.connection))

        return results[0][0] is True

    def fields_exist(self, table_path: str = None, table_fields: List[str] = None) -> bool:
        """Tests for the existence of table fields.

        Every missing field is logged as a warning.

        Args:
            table_path: Table name in the form <schema>.<table>
            table_fields: A list of fields to check.
        Returns:
            True if every field in table_fields is a column of the table.
            False if at least one is missing.
        Raises:
            ValueError: If table_path is None or malformed, or table_fields is None.
        """
        schema, table = self.split_table_path(table_path)

        if table_fields is None:
            raise ValueError('table_fields cannot be None.')

        # Get the live table's column names; values are bound as SQL literals.
        query = psycopg2.sql.SQL('''
            SELECT column_name
            FROM information_schema.columns
            WHERE
              table_schema = {schema}
              AND table_name = {table};
        ''').format(
            schema=psycopg2.sql.Literal(schema),
            table=psycopg2.sql.Literal(table),
        )
        results = self.execute(query.as_string(self.connection))

        # Positional access works for both plain tuple rows and DictCursor rows.
        live_table_fields = {row[0] for row in results}

        all_fields_present = True
        for table_field in table_fields:
            if table_field not in live_table_fields:
                logger.warning(f'field ({table_field}) not in live table {table_path}.')
                all_fields_present = False

        return all_fields_present
