"""
Training script for Multimodal (Text + Image) Parallel-LLM
Demonstrates training a model that can handle both text and image inputs.

KNOWN ISSUE: Importing PyTorch can fail on Windows.
When a required dependency cannot be imported, this script reports the
import error, prints setup guidance, and outlines the training steps
instead of running them.
For actual execution, use Linux/WSL with CUDA support.
"""
import os
import sys

# ---------------------------------------------------------------------------
# Startup banner and environment report
# ---------------------------------------------------------------------------
print("=" * 60, "Parallel-LLM Multimodal Training Example", "=" * 60, sep="\n")

# Show the interpreter and platform first so an import failure reported
# below can be related to the environment it happened in.
print(f"Python version: {sys.version}", f"Platform: {sys.platform}", sep="\n")

# To import parallel_llm from a source checkout rather than an installed
# package, put the project root on sys.path by uncommenting:
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
# print(f"Added to path: {os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))}")

# ---------------------------------------------------------------------------
# Dependency probing
# ---------------------------------------------------------------------------
# Both flags start out False and are flipped only after the corresponding
# group of imports has succeeded; main() reads them to decide what to do.
PYTORCH_AVAILABLE = False
PARALLEL_LLM_AVAILABLE = False

print("\n[1/3] Checking PyTorch availability...")
try:
    # These names are used further down by setup_distributed() and main().
    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, DistributedSampler
    from transformers import AutoTokenizer, AutoImageProcessor
    from datasets import load_dataset
    PYTORCH_AVAILABLE = True
    print("✅ PyTorch, transformers, and datasets imported successfully")
except Exception as import_error:
    # Deliberately broad: any failure while importing is reported, not raised.
    print(
        "❌ PyTorch/transformers/datasets not available:",
        f"   Error: {import_error}",
        "   This is a known issue on Windows with PyTorch binaries.",
        "   Solutions:",
        "   - Use WSL (Windows Subsystem for Linux)",
        "   - Use a Linux environment",
        "   - Use Docker with CUDA support",
        sep="\n",
    )

if not PYTORCH_AVAILABLE:
    print("\n[2/3] Skipping parallel_llm check (PyTorch not available)")
else:
    print("\n[2/3] Checking parallel_llm package...")
    try:
        # Project package: model, trainer and dataset helpers used by main().
        from parallel_llm.core import DiffusionTransformer, MultimodalConfig
        from parallel_llm.training import DistributedTrainer, TrainingConfig
        from parallel_llm.utils import MultimodalDataset
        PARALLEL_LLM_AVAILABLE = True
        print("✅ parallel_llm package imported successfully")
    except ImportError as import_error:
        print(
            "❌ parallel_llm package not available:",
            f"   Error: {import_error}",
            "   Please install the package: pip install -e .",
            sep="\n",
        )

print("\n[3/3] Configuration check...")
if not (PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE):
    print(
        "❌ Dependencies missing - example cannot run on this system",
        "\nTo run this example:",
        "1. Use Linux or WSL environment",
        "2. Install CUDA 12.1+",
        "3. Install PyTorch: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121",
        "4. Install dependencies: pip install transformers datasets",
        "5. Install parallel-llm: pip install -e .",
        sep="\n",
    )
else:
    print("✅ All dependencies available - example can run")

def setup_distributed():
    """Initialise ``torch.distributed`` when started by a distributed launcher.

    The presence of the ``LOCAL_RANK`` environment variable (exported by
    launchers such as ``torchrun``) is taken as the signal that this is a
    distributed run. Without it nothing is initialised.

    Returns:
        int: The local rank parsed from ``LOCAL_RANK``, or ``0`` when the
        variable is not set.

    Raises:
        ValueError: If ``LOCAL_RANK`` is set but is not an integer.
    """
    if "LOCAL_RANK" not in os.environ:
        return 0

    # Parse first so a malformed value fails before any process group exists.
    local_rank = int(os.environ["LOCAL_RANK"])

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        # Bind this process to its GPU before the process group is created so
        # NCCL communicators are set up on the intended device.
        torch.cuda.set_device(local_rank)

    # NCCL needs CUDA devices; use Gloo so CPU-only launches still work.
    backend = "nccl" if cuda_available else "gloo"

    # Leave an existing process group alone instead of failing on re-init.
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)

    return local_rank

def _print_example_outline():
    """Print the steps this example performs, for systems that cannot run it."""
    print("\n📋 Example structure demonstration:")
    print("This example would perform the following steps:")
    print("1. Set up distributed training environment")
    print("2. Configure Multimodal DiffusionTransformer with ViT vision encoder")
    # Items 3, 6 and 7 mirror the demo configuration in _run_training():
    # contrastive loss, gradient checkpointing and mixed precision are all
    # switched off there, and the run lasts 25 steps.
    print("3. Set up cross-attention fusion (contrastive objective disabled)")
    print("4. Load GPT-2 tokenizer and ViT image processor")
    print("5. Create MultimodalDataset from image-text pairs")
    print("6. Initialize DistributedTrainer (gradient checkpointing disabled)")
    print("7. Run a 25-step training loop without mixed precision")
    print("\nTo run this example, use a Linux environment with CUDA support.")


def _build_mock_samples(num_samples=100):
    """Build in-memory image/text records that stand in for a real dataset.

    Args:
        num_samples: Number of records to generate.

    Returns:
        list[dict]: Records with an ``"image"`` tensor of shape
        ``(3, 224, 224)`` and a ``"text"`` string.
    """
    # torch.rand keeps pixel values in [0, 1). Normally distributed values
    # (torch.randn) include negatives, which are not valid image data;
    # NOTE(review): Hugging Face image processors are expected to reject
    # float images outside [0, 1] - confirm against MultimodalDataset.
    return [
        {"image": torch.rand(3, 224, 224), "text": "demo"}
        for _ in range(num_samples)
    ]


def _run_training(local_rank):
    """Configure model, data and trainer, then run the demo training loop.

    Args:
        local_rank: Value returned by ``setup_distributed()``. Only the
            process with local rank 0 prints progress messages.
    """
    # NOTE: local rank 0 exists on every node, so a multi-node launch prints
    # these progress messages once per node.
    is_main_process = local_rank == 0

    if is_main_process:
        print("="*50)
        print("Starting Multimodal Training")
        print("="*50)

    # 1. Multimodal Configuration (smaller for demo)
    model_config = MultimodalConfig(
        # Text parameters
        vocab_size=50257,
        hidden_size=256,  # Smaller for demo
        num_hidden_layers=4,  # Fewer layers for demo

        # Vision parameters
        vision_encoder="vit",
        image_size=224,
        patch_size=16,
        vision_hidden_size=384,  # Smaller ViT

        # Fusion parameters
        fusion_type="cross_attention",
        num_cross_attention_layers=2,  # Fewer layers for demo

        # Training objectives
        use_contrastive=False,  # Disable for demo compatibility
        contrastive_temperature=0.07
    )

    train_config = TrainingConfig(
        output_dir="./checkpoints/multimodal",
        num_train_steps=25,    # Very small for ultra-quick demo
        batch_size=2,  # Smaller batch size for demo
        learning_rate=1e-3,    # Higher learning rate for demo
        warmup_steps=5,        # Short warmup
        mixed_precision="no",  # Disable for compatibility
        gradient_checkpointing=False,  # Disable for compatibility
        use_fsdp=False,  # Disable FSDP for single-device
        use_deepspeed=False,  # Disable DeepSpeed for single-device
        logging_steps=5,
        save_steps=25,
        eval_steps=10,
        use_wandb=False
    )

    # 2. Data Preparation
    if is_main_process:
        print("Loading processors and dataset...")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a padding token; reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token

    # Use a standard ViT image processor
    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

    # With a real dataset (e.g. COCO captions with 'image' and 'caption'
    # columns), pass the HuggingFace dataset directly:
    # dataset = load_dataset("coco", split="train")
    # train_dataset = MultimodalDataset(dataset, tokenizer, image_processor, text_column="caption")
    #
    # For the demo, feed synthetic records through the same library class.
    train_dataset = MultimodalDataset(
        dataset=_build_mock_samples(),
        tokenizer=tokenizer,
        image_processor=image_processor,
        text_column="text",
        image_column="image"
    )

    sampler = None
    # Only use distributed sampler if actually running distributed training
    if "LOCAL_RANK" in os.environ and dist.is_initialized():
        sampler = DistributedSampler(train_dataset)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_config.batch_size,
        sampler=sampler,
        # DataLoader does not allow shuffle together with a custom sampler.
        shuffle=(sampler is None),
        num_workers=2
    )

    # 3. Model & Trainer
    if is_main_process:
        print("Initializing Multimodal DiffusionTransformer...")

    model = DiffusionTransformer(model_config)

    trainer = DistributedTrainer(
        model=model,
        train_config=train_config,
        model_config=model_config,
        train_dataloader=train_dataloader
    )

    # 4. Train
    if is_main_process:
        print("Starting training...")

    trainer.train()


def main():
    """Run the multimodal training demo, or explain why it cannot run.

    When either dependency flag set at import time is False, only an outline
    of the example is printed. Otherwise the distributed environment is set
    up, the demo training loop runs, and any process group that is still
    initialised afterwards is destroyed, including when training raises.
    """
    if not (PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE):
        _print_example_outline()
        return

    print("\n🚀 Running actual multimodal training...")

    try:
        local_rank = setup_distributed()
        _run_training(local_rank)
    finally:
        # Release the process group on every exit path. The checks make this
        # a no-op for single-process runs or if it was already torn down.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

# Start the example only when this file is executed as a script; importing
# the module does not call main().
if __name__ == "__main__":
    main()
