#!/usr/bin/env python
# coding: utf8
#
# Copyright (c) 2021 Centre National d'Etudes Spatiales (CNES).
#
# This file is part of Pandora plugin MC-CNN
#
#     https://github.com/CNES/Pandora_plugin_mccnn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module contains functions to test Pandora + plugin_mc-cnn
"""

from tempfile import TemporaryDirectory
import unittest
import rasterio
import numpy as np
import xarray as xr

import pandora
from pandora import matching_cost
from pandora_plugin_mc_cnn.plugin_mc_cnn import MCCNN  # pylint: disable=unused-import


# pylint: disable=unsubscriptable-object
class TestPlugin(unittest.TestCase):
    """
    TestPlugin class allows to test Pandora + plugin_mc-cnn
    """

    def setUp(self):
        """
        Method called to prepare the test fixture

        """
        self.disp_ref = rasterio.open("tests/image/disp_left.tif").read(1)
        self.disp_sec = rasterio.open("tests/image/disp_right.tif").read(1)

        # We define a nan array to improve tests readability
        self.nan_array = [
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan],
        ]

    @staticmethod
    def error(data, gt, threshold, unknown_disparity=0):
        """
        Percentage of bad pixels whose error is > threshold

        """
        nb_row, nb_col = data.shape
        nb_error = 0
        for row in range(nb_row):
            for col in range(nb_col):
                if gt[row, col] != unknown_disparity:
                    if abs((data[row, col] + gt[row, col])) > threshold:
                        nb_error += 1

        return nb_error / float(nb_row * nb_col)

    def test_mc_cnn(self):
        """
        Test Pandora + plugin_mc-cnn

        """
        # Create temporary directory
        with TemporaryDirectory() as tmp_dir:
            pandora.main("tests/test_cfg_mccnn_fast.json", tmp_dir, verbose=False)

            # Check the reference disparity map
            if self.error(rasterio.open(tmp_dir + "/left_disparity.tif").read(1), self.disp_ref, 1) > 0.17:
                raise AssertionError

            # Check the secondary disparity map
            if self.error(-1 * rasterio.open(tmp_dir + "/right_disparity.tif").read(1), self.disp_sec, 1) > 0.17:
                raise AssertionError

    def test_mc_cnn_default_values(self):
        """
        Test Pandora + plugin_mc-cnn without specifying parameters window size, subpix and model_path
        Uses the default model path : "mc_cnn_fast_mb_weights.pt" stored in MC-CNN pip package

        """
        # Create temporary directory
        with TemporaryDirectory() as tmp_dir:
            pandora.main("tests/test_cfg_mccnn_fast_default_values.json", tmp_dir, verbose=False)

            # Check the reference disparity map
            if self.error(rasterio.open(tmp_dir + "/left_disparity.tif").read(1), self.disp_ref, 1) > 0.17:
                raise AssertionError

            # Check the secondary disparity map
            if self.error(-1 * rasterio.open(tmp_dir + "/right_disparity.tif").read(1), self.disp_sec, 1) > 0.17:
                raise AssertionError

    def test_mc_cnn_multiband_values(self):
        """
        Test Pandora + plugin_mc-cnn with multiband input images
        Uses the default model path : "mc_cnn_fast_mb_weights.pt" stored in MC-CNN pip package

        """
        # Create temporary directory
        with TemporaryDirectory() as tmp_dir:
            pandora.main("tests/test_cfg_mccnn_fast_default_values.json", tmp_dir, verbose=False)

            # Check the reference disparity map
            if self.error(rasterio.open(tmp_dir + "/left_disparity.tif").read(1), self.disp_ref, 1) > 0.17:
                raise AssertionError

            # Check the secondary disparity map
            if self.error(-1 * rasterio.open(tmp_dir + "/right_disparity.tif").read(1), self.disp_sec, 1) > 0.17:
                raise AssertionError

    def test_invalidates_cost(self):
        """
        Test the pipeline compute cost volume, and invalid cost with pandora function

        """
        # ------------ Test the method with a reference mask ( secondary mask contains valid pixels ) ------------
        # Mask convention
        # cfg['image']['valid_pixels'] = 0
        # cfg['image']['no_data'] = 1
        # invalid_pixels all other values
        data = np.zeros((13, 13), dtype=np.float64)
        data += 0.1
        mask = np.array(
            (
                [
                    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                ]
            ),
            dtype=np.int16,
        )

        left = xr.Dataset(
            {"im": (["row", "col"], data), "msk": (["row", "col"], mask)},
            coords={"row": np.arange(data.shape[0]), "col": np.arange(data.shape[1])},
        )
        left.attrs["valid_pixels"] = 0
        left.attrs["no_data_mask"] = 1
        left.attrs["crs"] = None
        left.attrs["transform"] = None

        data = np.zeros((13, 13), dtype=np.float64)
        data += 0.1
        # Secondary mask contains valid pixels
        mask = np.zeros((13, 13), dtype=np.int16)
        right = xr.Dataset(
            {"im": (["row", "col"], data), "msk": (["row", "col"], mask)},
            coords={"row": np.arange(data.shape[0]), "col": np.arange(data.shape[1])},
        )
        right.attrs["valid_pixels"] = 0
        right.attrs["no_data_mask"] = 1

        # Cost volume before invalidation, disparities = -1, 0, 1
        cv_before_invali = np.array(
            [
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, -1.0],
                    [-1.0, -1.0, -1.0],
                    [-1.0, -1.0, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, -1.0],
                    [-1.0, -1.0, -1.0],
                    [-1.0, -1.0, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, -1.0],
                    [-1.0, -1.0, -1.0],
                    [-1.0, -1.0, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
            ],
            dtype=np.float32,
        )

        # Cost volume ground truth after invalidation
        cv_ground_truth = np.array(
            [
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [-1.0, -1.0, -1.0],
                    [-1.0, -1.0, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, -1.0],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, -1.0],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
            ],
            dtype=np.float32,
        )

        matching_cost_ = matching_cost.AbstractMatchingCost(
            **{
                "matching_cost_method": "mc_cnn",
                "window_size": 11,
                "subpix": 1,
                "model_path": "tests/weights/mc_cnn_fast_mb_weights.pt",
            }
        )
        cv = matching_cost_.compute_cost_volume(left, right, disp_min=-1, disp_max=1)

        # Check if the calculated cost volume is equal to the ground truth (same shape and all elements equals)
        np.testing.assert_allclose(cv["cost_volume"].data, cv_before_invali, rtol=1e-06)

        # Masked cost volume with pandora function
        matching_cost_.cv_masked(left, right, cv, -1, 1)
        # Check if the calculated cost volume is equal to the ground truth (same shape and all elements equals)
        np.testing.assert_allclose(cv["cost_volume"].data, cv_ground_truth, rtol=1e-06)
        # ------------ Test the method with a secondary mask ( reference mask contains valid pixels ) ------------
        # Mask convention
        # cfg['image']['valid_pixels'] = 0
        # cfg['image']['no_data'] = 1
        # invalid_pixels all other values
        data = np.zeros((13, 13), dtype=np.float64)
        data += 0.1
        # Reference mask contains valid pixels
        mask = np.zeros((13, 13), dtype=np.int16)

        left = xr.Dataset(
            {"im": (["row", "col"], data), "msk": (["row", "col"], mask)},
            coords={"row": np.arange(data.shape[0]), "col": np.arange(data.shape[1])},
        )
        left.attrs["valid_pixels"] = 0
        left.attrs["no_data_mask"] = 1
        left.attrs["crs"] = None
        left.attrs["transform"] = None

        data = np.zeros((13, 13), dtype=np.float64)
        data += 0.1
        mask = np.array(
            (
                [
                    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                ]
            ),
            dtype=np.int16,
        )

        right = xr.Dataset(
            {"im": (["row", "col"], data), "msk": (["row", "col"], mask)},
            coords={"row": np.arange(data.shape[0]), "col": np.arange(data.shape[1])},
        )
        right.attrs["valid_pixels"] = 0
        right.attrs["no_data_mask"] = 1

        # Cost volume before invalidation, disparities = -1, 0, 1
        cv_before_invali = np.array(
            [
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, -1.0],
                    [-1.0, -1.0, -1.0],
                    [-1.0, -1.0, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, -1.0],
                    [-1.0, -1.0, -1.0],
                    [-1.0, -1.0, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, -1.0],
                    [-1.0, -1.0, -1.0],
                    [-1.0, -1.0, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
            ],
            dtype=np.float32,
        )

        # Cost volume ground truth after invalidation
        cv_ground_truth = np.array(
            [
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, -1.0],
                    [np.nan, -1.0, -1.0],
                    [-1.0, -1.0, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, np.nan],
                    [-1.0, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                [
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, -1.0, np.nan],
                    [-1.0, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                    [np.nan, np.nan, np.nan],
                ],
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
                self.nan_array,
            ],
            dtype=np.float32,
        )

        matching_cost_ = matching_cost.AbstractMatchingCost(
            **{
                "matching_cost_method": "mc_cnn",
                "window_size": 11,
                "subpix": 1,
                "model_path": "tests/weights/mc_cnn_fast_mb_weights.pt",
            }
        )
        cv = matching_cost_.compute_cost_volume(left, right, disp_min=-1, disp_max=1)

        # Check if the calculated cost volume is equal to the ground truth (same shape and all elements equals)
        np.testing.assert_allclose(cv["cost_volume"].data, cv_before_invali, rtol=1e-06)

        # Masked cost volume with pandora function
        matching_cost_.cv_masked(left, right, cv, -1, 1)

        # Check if the calculated cost volume is equal to the ground truth (same shape and all elements equals)
        np.testing.assert_allclose(cv["cost_volume"].data, cv_ground_truth, rtol=1e-06)


if __name__ == "__main__":
    # Allow running this module directly (python <this file>), outside of a test runner
    unittest.main(verbosity=1)
