from unittest import TestCase
import luigi
from luigi import postgres
import luigi.notifications
# Module-level test setup. NOTE(review): DEBUG presumably stops luigi from
# sending real notification emails while the tests run — confirm.
luigi.notifications.DEBUG = True
# Task classes defined below are registered under the 'postgres_test'
# namespace; the namespace is closed again at the bottom of this module.
luigi.namespace('postgres_test')

"""
Typical use cases that should be tested:

* Daily overwrite of all data in a table
* Daily insertion of a new segment into a table
* (Daily insertion into, or creation of, a new table)
* Daily insertion of multiple (different) new segments into a table
"""

# Shared connection settings for the local test database, so that each task below does not have to repeat them:
class CopyToTestDB(postgres.CopyToTable):
    """CopyToTable base class bound to the local ``spotify`` test database.

    Test tasks subclass this and only supply ``table``, ``columns`` and
    ``rows``; the connection parameters live here.
    """
    # Credentials for the test database.
    user = 'spotify'
    password = 'guest'
    # Location of the test database.
    host = 'localhost'
    database = 'spotify'


class TestPostgresTask(CopyToTestDB):
    """Copy three fixed rows into ``test_table``.

    The rows deliberately mix value types (``None``, numeric strings) and
    include characters that must be escaped, so the copy path is exercised
    with awkward input.
    """
    table = 'test_table'
    columns = (
        ('test_text', 'text'),
        ('test_int', 'int'),
        ('test_float', 'float'),
    )

    def create_table(self, connection):
        """Create the target table with an auto-incrementing ``id`` key."""
        # The ``id`` column lets the test read the rows back in insertion order.
        ddl = ("CREATE TABLE {table} (id SERIAL PRIMARY KEY, test_text TEXT, "
               "test_int INT, test_float FLOAT)").format(table=self.table)
        cursor = connection.cursor()
        cursor.execute(ddl)

    def rows(self):
        """Yield the fixture rows in the order they should be inserted."""
        fixture = (
            ('foo', 123, 123.45),
            (None, '-100', '5143.213'),
            ('\t\n\r\\N', 0, 0),
        )
        for row in fixture:
            yield row



class MetricBase(CopyToTestDB):
    """Shared ``metrics`` table definition for the metric-loading tasks."""
    table = 'metrics'
    columns = [
        ('metric', 'text'),
        ('value', 'int'),
    ]


class Metric1(MetricBase):
    """Load three ``metric1`` rows (values 1, 2, 3) into ``metrics``."""
    # Not read by rows(); presumably only distinguishes task instances.
    param = luigi.Parameter()

    def rows(self):
        """Yield ``(metric, value)`` pairs for ``metric1``."""
        for value in (1, 2, 3):
            yield 'metric1', value

class Metric2(MetricBase):
    """Load three ``metric2`` rows (values 1, 4, 3) into ``metrics``."""
    # Not read by rows(); presumably only distinguishes task instances.
    param = luigi.Parameter()

    def rows(self):
        """Yield ``(metric, value)`` pairs for ``metric2``."""
        for value in (1, 4, 3):
            yield 'metric2', value


class TestPostgresImportTask(TestCase):
    """Integration tests for ``postgres.CopyToTable`` against a live database.

    The tests connect to the database configured on ``CopyToTestDB`` and drop
    tables, so they need a reachable, disposable PostgreSQL instance.
    """

    def _fresh_connection(self, task):
        """Return an autocommit connection for *task* with its tables dropped.

        Drops the task's data table and the marker table so the test starts
        from a clean state. The connection is closed when the test finishes.
        """
        conn = task.output().connect()
        self.addCleanup(conn.close)
        # Autocommit so the DROPs are committed before luigi.build runs the
        # tasks (which presumably open their own connections).
        conn.autocommit = True
        cursor = conn.cursor()
        cursor.execute('DROP TABLE IF EXISTS {table}'.format(table=task.table))
        cursor.execute('DROP TABLE IF EXISTS {marker_table}'.format(
            marker_table=postgres.PostgresTarget.marker_table))
        return conn

    def test_default_escape(self):
        escape = postgres.default_escape
        self.assertEqual(escape('foo'), 'foo')
        self.assertEqual(escape('\n'), '\\n')
        self.assertEqual(escape('\\\n'), '\\\\\\n')
        self.assertEqual(escape('\n\r\\\t\\N\\'),
                         '\\n\\r\\\\\\t\\\\N\\\\')

    def test_repeat(self):
        task = TestPostgresTask()
        conn = self._fresh_connection(task)

        luigi.build([task], local_scheduler=True)
        # Scheduling the same task again must not insert the rows a second time.
        luigi.build([task], local_scheduler=True)

        cursor = conn.cursor()
        cursor.execute("""SELECT test_text, test_int, test_float
                          FROM {table}
                          ORDER BY id ASC""".format(table=task.table))
        rows = tuple(cursor)

        self.assertEqual(rows, (
            ('foo', 123, 123.45),
            (None, -100, 5143.213),
            ('\t\n\r\\N', 0, 0.0),
        ))

    def test_multimetric(self):
        metrics = MetricBase()
        conn = self._fresh_connection(metrics)

        # Two Metric1 instances and one Metric2 (3 rows each) write into the
        # same table, so all 9 rows must be present.
        luigi.build([Metric1(20), Metric1(21), Metric2("foo")],
                    local_scheduler=True)

        cursor = conn.cursor()
        cursor.execute('select count(*) from {table}'.format(table=metrics.table))
        self.assertEqual(tuple(cursor), ((9,),))

    def test_clear(self):
        class Metric2Copy(Metric2):
            def init_copy(self, connection):
                # Empty the table so previously loaded rows are replaced.
                # NOTE(review): assumes CopyToTable calls init_copy before
                # copying rows — confirm.
                query = "TRUNCATE {0}".format(self.table)
                connection.cursor().execute(query)

        clearer = Metric2Copy(21)
        conn = self._fresh_connection(clearer)

        # The two Metric1 runs load 6 rows; Metric2Copy then truncates the
        # table and writes only its own 3 rows.
        luigi.build([Metric1(0), Metric1(1)], local_scheduler=True)
        luigi.build([clearer], local_scheduler=True)

        cursor = conn.cursor()
        cursor.execute('select count(*) from {table}'.format(table=clearer.table))
        self.assertEqual(tuple(cursor), ((3,),))

# Close the 'postgres_test' namespace opened at the top of this module;
# calling luigi.namespace() with no argument presumably restores the default.
luigi.namespace()