import json
import os
from unittest.mock import MagicMock, mock_open, patch
import pytest
from amazon_sagemaker_jupyter_scheduler.environment_detector import (
    JupyterLabEnvironment,
)

from amazon_sagemaker_jupyter_scheduler.internal_metadata_adapter import (
    InternalMetadataAdapter,
)

# Path the adapter is expected to use as its config file (asserted against
# ``adapter.config_file`` in test_init).
SAGEMAKER_INTERNAL_METADATA_FILE = "/opt/.sagemakerinternal/internal-metadata.json"

# Baseline metadata document. The tests JSON-serialize it and serve it through
# a mocked ``open``, then expect the adapter to expose the same values.
INTERNAL_METADATA_CONTENT = {
    "Stage": "Production",
    "FirstPartyImages": ["image1", "image2"],
    "CustomImages": ["custom_image1", "custom_image2"],
}


@pytest.fixture(autouse=True)
def mock_jupyter_lab_environment():
    """Pin the detected Jupyter environment to SageMaker Studio for every test.

    ``get_sagemaker_environment`` is replaced where ``internal_metadata_adapter``
    looks it up, so code in that module sees
    ``JupyterLabEnvironment.SAGEMAKER_STUDIO`` for the duration of each test.
    """
    target = "amazon_sagemaker_jupyter_scheduler.internal_metadata_adapter.get_sagemaker_environment"
    studio = JupyterLabEnvironment.SAGEMAKER_STUDIO
    with patch(target, return_value=studio):
        yield


class TestInternalMetadataAdapterStudio:
    """Tests for ``InternalMetadataAdapter`` with the environment pinned to Studio.

    The module's autouse ``mock_jupyter_lab_environment`` fixture makes
    ``get_sagemaker_environment`` return
    ``JupyterLabEnvironment.SAGEMAKER_STUDIO`` for each test here.
    """

    @staticmethod
    def _file_handle(content):
        """Build a fake file handle whose ``read()`` yields *content* as JSON.

        Only the handle (``mock_open(...).return_value``) is returned. Handing
        it out directly skips the data reset ``mock_open`` performs on calls to
        its own mock, so each handle is effectively good for one read. That is
        why a fresh handle is installed before every expected (re)load.
        """
        return mock_open(read_data=json.dumps(content)).return_value

    def test_init(self):
        fake_open = mock_open(read_data=json.dumps(INTERNAL_METADATA_CONTENT))
        with patch("os.path.getmtime", return_value=0), patch(
            "builtins.open", fake_open
        ):
            adapter = InternalMetadataAdapter()

            assert adapter.config_file == SAGEMAKER_INTERNAL_METADATA_FILE
            assert adapter.metadata == INTERNAL_METADATA_CONTENT

    def test_get_stage(self):
        fake_open = mock_open(read_data=json.dumps(INTERNAL_METADATA_CONTENT))
        with patch("os.path.getmtime", return_value=0), patch(
            "builtins.open", fake_open
        ):
            adapter = InternalMetadataAdapter()

            assert adapter.get_stage() == "Production"

    @patch("os.path.getmtime")
    @patch("builtins.open")
    def test_get_images(self, open_mock, getmtime_mock):
        def serve(content):
            # The next open() call hands back a fresh handle holding `content`.
            open_mock.return_value = self._file_handle(content)

        def rewrite_file(content):
            # Serve new content and report mtimes 0 then 1, so the test
            # expects the adapter to pick up the new content.
            # NOTE(review): the [0, 1] sequence presumably matches how often
            # the adapter calls getmtime per lookup; confirm against
            # InternalMetadataAdapter.
            serve(content)
            getmtime_mock.side_effect = [0, 1]

        # Initial load: getmtime is left unconfigured, as in the original setup.
        serve(INTERNAL_METADATA_CONTENT)
        adapter = InternalMetadataAdapter()
        assert adapter.get_first_party_images() == ["image1", "image2"]
        assert adapter.get_custom_images() == ["custom_image1", "custom_image2"]

        new_content = {
            "Stage": "Production",
            "FirstPartyImages": ["new_image1", "new_image2"],
            "CustomImages": ["new_custom_image1", "new_custom_image2"],
        }
        other_content = {
            "Stage": "Production",
            "FirstPartyImages": ["other_image1", "other_image2"],
            "CustomImages": ["other_custom_image1", "other_custom_image2"],
        }

        # Two successive on-disk updates. Every lookup is preceded by its own
        # simulated rewrite, because each fake handle supports a single read.
        for content in (new_content, other_content):
            rewrite_file(content)
            assert adapter.get_first_party_images() == content["FirstPartyImages"]
            rewrite_file(content)
            assert adapter.get_custom_images() == content["CustomImages"]
