# Copyright 2013 Mirantis Inc.
# All Rights Reserved
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.
#

import os
import tempfile

from migrate import exceptions as migrate_exception
from migrate.versioning import api as versioning_api
import mock
import sqlalchemy

from oslo.db import exception as db_exception
from oslo.db.sqlalchemy import migration
from oslo.db.sqlalchemy import test_base
from oslo_db.sqlalchemy import migration as private_migration
from oslo_db.tests.old_import_api import utils as test_utils


class TestMigrationCommon(test_base.DbTestCase):
    """Tests for the old-namespace ``oslo.db.sqlalchemy.migration`` API.

    ``Repository`` on the ``oslo_db`` migration module and ``migrate``'s
    ``versioning_api.db_version`` are patched for every test, so no real
    migration repository or version table is required.
    """

    def setUp(self):
        super(TestMigrationCommon, self).setUp()

        migration._REPOSITORY = None
        # Both temporary directories are created here, so both must be
        # removed.  Each cleanup is registered immediately after the resource
        # is acquired so it still runs if a later step of setUp fails.
        self.path = tempfile.mkdtemp('test_migration')
        self.addCleanup(os.rmdir, self.path)
        self.path1 = tempfile.mkdtemp('test_migration')
        self.addCleanup(os.rmdir, self.path1)
        self.return_value = '/home/openstack/migrations'
        self.return_value1 = '/home/extension/migrations'
        self.init_version = 1
        self.test_version = 123

        # Successive Repository(...) calls return these two fake repositories
        # in order.
        self.patcher_repo = mock.patch.object(private_migration, 'Repository')
        self.repository = self.patcher_repo.start()
        self.addCleanup(self.patcher_repo.stop)
        self.repository.side_effect = [self.return_value, self.return_value1]

        # By default the database reports itself at ``self.test_version``.
        self.mock_api_db = mock.patch.object(versioning_api, 'db_version')
        self.mock_api_db_version = self.mock_api_db.start()
        self.addCleanup(self.mock_api_db.stop)
        self.mock_api_db_version.return_value = self.test_version

    def test_db_version_control(self):
        """db_version_control passes the repo through and returns version."""
        with test_utils.nested(
            mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'),
            mock.patch.object(versioning_api, 'version_control'),
        ) as (mock_find_repo, mock_version_control):
            mock_find_repo.return_value = self.return_value

            version = migration.db_version_control(
                self.engine, self.path, self.test_version)

            self.assertEqual(self.test_version, version)
            mock_version_control.assert_called_once_with(
                self.engine, self.return_value, self.test_version)

    def test_db_version_return(self):
        """db_version returns the version reported by migrate."""
        ret_val = migration.db_version(self.engine, self.path,
                                       self.init_version)
        self.assertEqual(self.test_version, ret_val)

    def test_db_version_raise_not_controlled_error_first(self):
        """An uncontrolled DB is put under version control, then re-read."""
        patcher = mock.patch.object(private_migration, 'db_version_control')
        with patcher as mock_ver:

            # First lookup fails as "not controlled", the retry succeeds.
            self.mock_api_db_version.side_effect = [
                migrate_exception.DatabaseNotControlledError('oups'),
                self.test_version]

            ret_val = migration.db_version(self.engine, self.path,
                                           self.init_version)
            self.assertEqual(self.test_version, ret_val)
            mock_ver.assert_called_once_with(self.engine, self.path,
                                             version=self.init_version)

    def test_db_version_raise_not_controlled_error_tables(self):
        """An uncontrolled DB that already has tables is a migration error."""
        with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta:
            self.mock_api_db_version.side_effect = (
                migrate_exception.DatabaseNotControlledError('oups'))
            my_meta = mock.MagicMock()
            my_meta.tables = {'a': 1, 'b': 2}
            mock_meta.return_value = my_meta

            self.assertRaises(
                db_exception.DbMigrationError, migration.db_version,
                self.engine, self.path, self.init_version)

    @mock.patch.object(versioning_api, 'version_control')
    def test_db_version_raise_not_controlled_error_no_tables(self, mock_vc):
        """An uncontrolled, empty DB gets version control applied."""
        with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta:
            self.mock_api_db_version.side_effect = (
                migrate_exception.DatabaseNotControlledError('oups'),
                self.init_version)
            my_meta = mock.MagicMock()
            my_meta.tables = {}
            mock_meta.return_value = my_meta
            migration.db_version(self.engine, self.path, self.init_version)

            # NOTE(review): the second fake repository (return_value1) is
            # expected here, presumably because the repository is looked up
            # a second time during version control - confirm if this changes.
            mock_vc.assert_called_once_with(self.engine, self.return_value1,
                                            self.init_version)

    def test_db_sync_wrong_version(self):
        """A non-integer target version is rejected."""
        self.assertRaises(db_exception.DbMigrationError,
                          migration.db_sync, self.engine, self.path, 'foo')

    def test_db_sync_upgrade(self):
        """A target above the current version triggers an upgrade."""
        init_ver = 55
        with test_utils.nested(
            mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'),
            mock.patch.object(versioning_api, 'upgrade')
        ) as (mock_find_repo, mock_upgrade):

            mock_find_repo.return_value = self.return_value
            self.mock_api_db_version.return_value = self.test_version - 1

            migration.db_sync(self.engine, self.path, self.test_version,
                              init_ver)

            mock_upgrade.assert_called_once_with(
                self.engine, self.return_value, self.test_version)

    def test_db_sync_downgrade(self):
        """A target below the current version triggers a downgrade."""
        with test_utils.nested(
            mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'),
            mock.patch.object(versioning_api, 'downgrade')
        ) as (mock_find_repo, mock_downgrade):

            mock_find_repo.return_value = self.return_value
            self.mock_api_db_version.return_value = self.test_version + 1

            migration.db_sync(self.engine, self.path, self.test_version)

            mock_downgrade.assert_called_once_with(
                self.engine, self.return_value, self.test_version)

    def test_db_sync_sanity_called(self):
        """The schema sanity check runs twice by default."""
        with test_utils.nested(
            mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'),
            mock.patch('oslo_db.sqlalchemy.migration._db_schema_sanity_check'),
            mock.patch.object(versioning_api, 'downgrade')
        ) as (mock_find_repo, mock_sanity, mock_downgrade):

            mock_find_repo.return_value = self.return_value
            migration.db_sync(self.engine, self.path, self.test_version)

            self.assertEqual([mock.call(self.engine), mock.call(self.engine)],
                             mock_sanity.call_args_list)

    def test_db_sync_sanity_skipped(self):
        """sanity_check=False suppresses the schema sanity check."""
        with test_utils.nested(
            mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'),
            mock.patch('oslo_db.sqlalchemy.migration._db_schema_sanity_check'),
            mock.patch.object(versioning_api, 'downgrade')
        ) as (mock_find_repo, mock_sanity, mock_downgrade):

            mock_find_repo.return_value = self.return_value
            migration.db_sync(self.engine, self.path, self.test_version,
                              sanity_check=False)

            self.assertFalse(mock_sanity.called)
