from __future__ import annotations

import json
import logging
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence

import requests

from .config import APIConfig

# Instruction text sent ahead of the document excerpt. It asks the model for a
# single JSON object with the keys "is_sustainability_report", "certainty",
# "company" and "year".
# NOTE(review): step 3 mentions an "unknown" classification that step 1 never
# offers as an answer — confirm whether a third outcome is intended.
SYSTEM_PROMPT = """
You are a strict JSON classifier assessing whether company reports are sustainability-related.

The user will provide text extracted from the first pages of a document (English or Italian).

Your task:
1. Decide if the document is a sustainability report.
   - Reply "true" if you are confident it is.
   - Reply "false" if you are confident it is not.
   - Certainty is the level of the confidence of your decision ( certainty is a number in range 0.0 to 1.0)
2. If the classification is true, extract:
   - Company name
   - Year of the report
3. If the classification is false or "unknown", set company and year to null.

Output requirements:
- Respond with one JSON object only.
- Do not include any explanations, comments, or text outside the JSON.
- The JSON must match this exact schema:
{
  "is_sustainability_report": true/false,
  "certainty": 0.0-1.0,
  "company": "Company Name or null",
  "year": "YYYY or null"
}
"""


class LLMCallError(RuntimeError):
    """Raised when an LLM call fails.

    Used in this module for a non-200 HTTP status, a response body that is
    not valid JSON, a body that lacks the expected ``choices``/``message``
    fields, and message content from which no JSON could be parsed.
    """


@dataclass
class LLMResponse:
    """Parsed result of one successful classification call, plus audit data.

    The first four fields come from the JSON object the model returned; the
    remaining fields record which endpoint produced the answer and exactly
    what was sent and received.
    """

    # Model's verdict, taken from the "is_sustainability_report" key.
    is_sustainability_report: bool
    # Model's self-reported confidence, taken from the "certainty" key
    # (the prompt asks for a number between 0.0 and 1.0).
    certainty: float
    # Company name from the "company" key; None when the model gave no value.
    company: Optional[str]
    # Report year from the "year" key; None when the model gave no value.
    year: Optional[str]
    # Message content of the first choice, before JSON extraction.
    raw_content: str
    # Name of the API configuration that answered.
    config_name: str
    # Model identifier sent in the request payload.
    model: str
    # Base URL of the endpoint that answered.
    api_base: str
    # Copy of the chat messages that were sent in the request.
    messages: List[Dict[str, str]]
    # Full decoded JSON body returned by the endpoint.
    raw_api_response: Dict[str, Any]


class LLMClient:
    """Classify document text through chat-completion endpoints with failover.

    Each entry of ``configs`` describes one endpoint. This class reads its
    ``name``, ``model``, ``api_base`` and ``headers`` attributes and calls its
    ``is_available()`` and ``block(seconds)`` methods. Endpoints are tried in
    list order; one that fails is blocked for ``cooldown_seconds`` and the
    next one is tried.
    """

    # Literal strings a model may emit in place of a real JSON boolean.
    _TRUE_TOKENS = frozenset({"true", "yes", "1"})
    _FALSE_TOKENS = frozenset({"false", "no", "0", "", "unknown", "null", "none"})

    def __init__(self, configs: List[APIConfig], cooldown_seconds: float = 60.0, timeout_seconds: int = 300) -> None:
        """Store the endpoint list and the per-call settings.

        Args:
            configs: Endpoint configurations, tried in the given order.
            cooldown_seconds: Value passed to ``config.block`` after a failure.
            timeout_seconds: Timeout handed to ``requests.post`` for each call.
        """
        self.configs = configs
        self.cooldown_seconds = cooldown_seconds
        self.timeout_seconds = timeout_seconds

    def classify(self, chunks: Sequence[str]) -> Optional[LLMResponse]:
        """Classify the document made of ``chunks`` using the first endpoint that works.

        Args:
            chunks: Text excerpts of the document, in reading order.

        Returns:
            The parsed response of the first successful call, or ``None`` when
            ``chunks`` is empty, no endpoint is available, or every attempted
            endpoint fails.
        """
        if not chunks:
            logging.warning("No chunks available for LLM classification")
            return None
        messages = self._build_messages(chunks)
        last_error: Optional[Exception] = None
        attempted = False
        for attempt in range(2):  # two full passes over the config list
            for config in self.configs:
                if not config.is_available():
                    continue
                attempted = True
                try:
                    return self._call_config(config, messages)
                except Exception as exc:
                    # Deliberately broad: any failure of one endpoint must not
                    # stop the remaining endpoints from being tried.
                    logging.warning("LLM call failed with config %s: %s", config.name, exc)
                    config.block(self.cooldown_seconds)
                    last_error = exc
            if attempt == 0 and last_error is not None:
                # Brief pause before the second pass over the list.
                time.sleep(1.0)
        if last_error is not None:
            logging.error("All LLM configs failed: %s", last_error)
        elif not attempted:
            # Previously this case returned None without leaving any trace.
            logging.warning("No LLM config is currently available; skipping classification")
        return None

    @staticmethod
    def _build_messages(chunks: Sequence[str]) -> List[Dict[str, str]]:
        """Build the chat messages: the instructions, then the labelled chunks.

        Chunks are numbered from 1 (``[Chunk 1]``, ``[Chunk 2]``, ...), joined
        with blank lines and prefixed with a ``[Document]`` header.
        """
        labelled = [f"[Chunk {idx}]\n{chunk}" for idx, chunk in enumerate(chunks, start=1)]
        doc_content = "[Document]\n" + "\n\n".join(labelled)
        return [
            # NOTE(review): the instructions are sent with role "user", not
            # "system". Left unchanged because some endpoints may reject a
            # system role — confirm whether this is intentional.
            {"role": "user", "content": SYSTEM_PROMPT},
            {"role": "user", "content": doc_content},
        ]

    def _call_config(self, config: APIConfig, messages: List[Dict[str, str]]) -> LLMResponse:
        """POST ``messages`` to one endpoint and parse the model's JSON answer.

        Args:
            config: Endpoint to call.
            messages: Chat messages to send.

        Returns:
            The parsed classification together with the raw request/response data.

        Raises:
            LLMCallError: On a non-200 status, a body that is not JSON, a body
                without ``choices[0].message.content``, content that contains
                no JSON object, or field values that cannot be interpreted.
        """
        # Tolerate a trailing slash on the base URL (avoids "//chat/completions").
        url = f"{config.api_base.rstrip('/')}/chat/completions"
        payload = {
            "model": config.model,
            "messages": messages,
            "temperature": 0,
            "response_format": {"type": "json_object"},
        }
        response = requests.post(
            url,
            headers=config.headers,
            json=payload,
            timeout=self.timeout_seconds,
        )
        if response.status_code != 200:
            raise LLMCallError(f"HTTP {response.status_code}: {response.text}")
        try:
            data = response.json()
        except ValueError as exc:
            # ValueError is the common base of the stdlib and simplejson
            # JSONDecodeError types that requests may raise here.
            raise LLMCallError(f"Invalid JSON response: {response.text}") from exc

        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            # TypeError covers a body that is not a dict or a null "choices".
            raise LLMCallError(f"Unexpected response schema: {data}") from exc

        parsed = self._extract_json_payload(content)
        if parsed is None:
            raise LLMCallError(f"Model returned non-JSON content: {content}")

        try:
            is_report = self._coerce_bool(parsed.get("is_sustainability_report"))
            raw_certainty = parsed.get("certainty")
            # A missing or null certainty counts as 0.0.
            certainty = 0.0 if raw_certainty is None else float(raw_certainty)
        except (TypeError, ValueError) as exc:
            raise LLMCallError(f"Unexpected response schema: {data}") from exc

        return LLMResponse(
            is_sustainability_report=is_report,
            certainty=certainty,
            company=self._normalize_optional(parsed.get("company")),
            year=self._normalize_optional(parsed.get("year")),
            raw_content=content,
            config_name=config.name,
            model=config.model,
            api_base=config.api_base,
            messages=[dict(message) for message in messages],
            raw_api_response=data,
        )

    @classmethod
    def _coerce_bool(cls, value: Any) -> bool:
        """Interpret a model-supplied flag as a boolean.

        ``bool("false")`` is ``True`` in Python, so strings are matched
        against known tokens instead of being passed to ``bool``.

        Raises:
            ValueError: If ``value`` is an unrecognised string or another type.
        """
        if value is None or isinstance(value, (bool, int, float)):
            return bool(value)
        if isinstance(value, str):
            token = value.strip().lower()
            if token in cls._TRUE_TOKENS:
                return True
            if token in cls._FALSE_TOKENS:
                return False
        raise ValueError(f"Cannot interpret {value!r} as a boolean")

    @staticmethod
    def _normalize_optional(value: Any) -> Optional[str]:
        """Return ``value`` as stripped text, or ``None`` when it carries no information.

        ``None``, blank strings and the literal string "null" (any case) map
        to ``None``; other values, such as a numeric year, are converted with
        ``str``.
        """
        if value is None:
            return None
        text = str(value).strip()
        if not text or text.lower() == "null":
            return None
        return text

    @staticmethod
    def _extract_json_payload(raw_response: str) -> Optional[Dict[str, Any]]:
        """Parse the first JSON object found in ``raw_response``.

        Decoding starts at the first ``{`` and stops at the end of that
        object, so code fences, leading prose and trailing text (even text
        containing braces) are ignored.

        Returns:
            The decoded object, or ``None`` when the input is not a string,
            contains no ``{``, or the object starting there is not valid JSON.
        """
        if not isinstance(raw_response, str):
            logging.warning("No JSON found in LLM response: %s", raw_response)
            return None
        start = raw_response.find("{")
        if start == -1:
            logging.warning("No JSON found in LLM response: %s", raw_response)
            return None
        try:
            # raw_decode parses exactly one value and ignores what follows it.
            payload, _ = json.JSONDecoder().raw_decode(raw_response, start)
        except json.JSONDecodeError as exc:
            logging.error("Failed to parse JSON from LLM response: %s", exc)
            return None
        return payload
