"""
Test suite to validate the parallel-llm package structure and basic functionality.
These tests ensure the package is properly configured and can be imported.
"""

import sys
import pytest
from importlib.metadata import version, entry_points


class TestPackageStructure:
    """Test package structure and version metadata"""

    def test_package_import(self):
        """Test that the main package can be imported"""
        import parallel_llm
        assert parallel_llm is not None

    def test_package_version(self):
        """Test that __version__ is a dotted string with at least MAJOR.MINOR"""
        import parallel_llm
        assert hasattr(parallel_llm, '__version__')
        pkg_version = parallel_llm.__version__
        assert isinstance(pkg_version, str)
        # Require at least two dot-separated components (e.g. "1.2" or
        # "1.2.3"). Suffixed forms such as "1.2.0rc1" are intentionally
        # not rejected here.
        parts = pkg_version.split('.')
        assert len(parts) >= 2, f"Invalid version format: {pkg_version}"
        # Reject empty components such as "1..2" or a trailing "1.2.".
        assert all(parts), f"Invalid version format: {pkg_version}"

    def test_package_version_matches_metadata(self):
        """Test that __version__ matches the installed distribution metadata"""
        import parallel_llm
        pkg_version = version('parallel-llm')
        assert parallel_llm.__version__ == pkg_version


class TestCoreImports:
    """Verify that the public classes and factory functions are importable"""

    def test_import_config_classes(self):
        """Config classes and the default-config factory import from the package root"""
        from parallel_llm import (
            ModelConfig,
            MultimodalConfig,
            TrainingConfig,
            InferenceConfig,
            get_default_config,
        )
        config_classes = (ModelConfig, MultimodalConfig, TrainingConfig, InferenceConfig)
        for config_cls in config_classes:
            assert config_cls is not None
        assert callable(get_default_config)

    def test_import_diffusion_transformer(self):
        """DiffusionTransformer is exposed at the package root"""
        from parallel_llm import DiffusionTransformer as model_cls
        assert model_cls is not None

    def test_import_trainer(self):
        """DistributedTrainer is exposed at the package root"""
        from parallel_llm import DistributedTrainer as trainer_cls
        assert trainer_cls is not None

    def test_import_generator(self):
        """Generator class, its config, and its factory are exposed at the package root"""
        from parallel_llm import (
            ParallelGenerator,
            GenerationConfig,
            create_generator,
        )
        for generator_obj in (ParallelGenerator, GenerationConfig):
            assert generator_obj is not None
        assert callable(create_generator)

    def test_all_exports(self):
        """Every expected public name is listed in the package's __all__"""
        import parallel_llm
        assert hasattr(parallel_llm, '__all__')
        public_api = parallel_llm.__all__
        required_names = (
            '__version__',
            'ModelConfig',
            'MultimodalConfig',
            'TrainingConfig',
            'InferenceConfig',
            'get_default_config',
            'DiffusionTransformer',
            'DistributedTrainer',
            'ParallelGenerator',
            'GenerationConfig',
            'create_generator',
        )
        for name in required_names:
            assert name in public_api, f"{name} not in __all__"


class TestConfigInstantiation:
    """Test that configuration classes can be instantiated"""
    
    def test_model_config_creation(self):
        """Test ModelConfig can be created with default values"""
        from parallel_llm import ModelConfig
        try:
            config = ModelConfig(
                vocab_size=1000,
                hidden_size=256,
                num_hidden_layers=4,
                num_attention_heads=4,
            )
            assert config.vocab_size == 1000
            assert config.hidden_size == 256
        except Exception as e:
            pytest.fail(f"ModelConfig instantiation failed: {e}")
    
    def test_training_config_creation(self):
        """Test TrainingConfig can be created with default values"""
        from parallel_llm import TrainingConfig
        try:
            config = TrainingConfig(
                batch_size=8,
                learning_rate=1e-4,
            )
            assert config.batch_size == 8
            assert config.learning_rate == 1e-4
        except Exception as e:
            pytest.fail(f"TrainingConfig instantiation failed: {e}")
    
    def test_generation_config_creation(self):
        """Test GenerationConfig can be created"""
        from parallel_llm import GenerationConfig
        try:
            config = GenerationConfig(
                max_new_tokens=100,
                temperature=1.0,
            )
            assert config.max_new_tokens == 100
            assert config.temperature == 1.0
        except Exception as e:
            pytest.fail(f"GenerationConfig instantiation failed: {e}")


class TestCLIEntryPoints:
    """Test that CLI entry points are registered"""
    
    def test_train_cli_registered(self):
        """Test that parallel-llm-train entry point is registered"""
        eps = entry_points()
        # Try both old and new API
        if hasattr(eps, 'select'):
            # Python 3.10+
            console_scripts = eps.select(group='console_scripts')
        else:
            # Python 3.9
            console_scripts = eps.get('console_scripts', [])
        
        script_names = [ep.name for ep in console_scripts]
        assert 'parallel-llm-train' in script_names, \
            "parallel-llm-train entry point not registered"
    
    def test_infer_cli_registered(self):
        """Test that parallel-llm-infer entry point is registered"""
        eps = entry_points()
        if hasattr(eps, 'select'):
            console_scripts = eps.select(group='console_scripts')
        else:
            console_scripts = eps.get('console_scripts', [])
        
        script_names = [ep.name for ep in console_scripts]
        assert 'parallel-llm-infer' in script_names, \
            "parallel-llm-infer entry point not registered"


class TestModuleStructure:
    """Check that each expected submodule is reachable from the package"""

    def test_cli_module_exists(self):
        """The cli submodule can be imported from parallel_llm"""
        from parallel_llm import cli as module
        assert module is not None

    def test_config_module_exists(self):
        """The config submodule can be imported from parallel_llm"""
        from parallel_llm import config as module
        assert module is not None

    def test_diffusion_transformer_module_exists(self):
        """The diffusion_transformer submodule can be imported from parallel_llm"""
        from parallel_llm import diffusion_transformer as module
        assert module is not None

    def test_trainer_module_exists(self):
        """The trainer submodule can be imported from parallel_llm"""
        from parallel_llm import trainer as module
        assert module is not None

    def test_parallel_generator_module_exists(self):
        """The parallel_generator submodule can be imported from parallel_llm"""
        from parallel_llm import parallel_generator as module
        assert module is not None


if __name__ == '__main__':
    # Run this file under pytest and propagate its exit status, so failures
    # are visible to shells and CI when the file is executed directly.
    sys.exit(pytest.main([__file__, '-v']))
