"""
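Fact-checking module using Gemini models (with Google Search grounding) for strict verification.
By default an article gets at most 2 fact-check attempts, with at most one revision in between, before failing.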
"""

import json
import logging
import re
import time
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional

from google import genai
from google.genai import types

from config.settings import get_settings
from database.models import Topic

logger = logging.getLogger(__name__)  # module-level logger

# Gemini models tried in order by fact_check(), from best to most stable;
# when a model fails, the next entry is tried.
TEXT_MODELS = [
    "gemini-3-pro-preview",  # preferred
    "gemini-3-flash-preview",
    "gemini-2.5-pro",  # last resort
]

# Attempts per model when the API error looks like rate limiting
# ("429" / "RESOURCE_EXHAUSTED" in the error text).
MAX_RETRIES = 3

# Seconds to sleep before retrying a rate-limited model.
RETRY_DELAY = 30


class FactVerdict(str, Enum):
    """Possible outcomes when verifying a single factual claim.

    The class mixes in ``str``, so every member compares equal to its
    lowercase value (e.g. ``FactVerdict.VERIFIED == "verified"``).
    """

    VERIFIED = "verified"
    MINOR_VARIATION = "minor_variation"
    INCORRECT = "incorrect"
    UNVERIFIABLE = "unverifiable"
    OUTDATED = "outdated"


@dataclass
class FactCheckResult:
    """Outcome of one fact-check run over an article."""

    # Whether the article passed; this module sets it to True exactly when
    # ``verdict == "PASS"``.
    passed: bool
    # Overall verdict: "PASS", "NEEDS_REVISION" or "FAIL".
    verdict: str
    # Issue entries as reported by the model; this module reads the keys
    # "claim", "issue" and "category" from them.
    issues: list[dict] = field(default_factory=list)
    # Free-text summary from the model, or a description of an error.
    summary: str = ""
    # Model-reported confidence; presumably a 0-1 fraction (it is logged as a
    # percentage and compared against 0.75).
    confidence: float = 0.0


def get_gemini_client() -> genai.Client:
    """Create a Gemini API client using the API key from the app settings."""
    api_key = get_settings().gemini_api_key
    return genai.Client(api_key=api_key)


def fact_check(article_content: str) -> FactCheckResult:
    """
    Perform strict fact-check on article using Gemini with Google Search grounding.

    Identifies factual claims and verifies each one. The request is sent
    through the ``TEXT_MODELS`` fallback chain by
    ``_generate_with_model_fallback``.

    Args:
        article_content: Article markdown to verify. Only the first 15,000
            characters are sent to the model.

    Returns:
        The parsed ``FactCheckResult``. Errors are not raised: a missing or
        malformed prompt template, a failure of every model, or any other
        exception is logged and reported as ``passed=False, verdict="FAIL"``
        with the error text in ``summary``.
    """
    try:
        # Prompt loading lives inside the try so that a missing or malformed
        # template is reported as a FAIL result like every other error,
        # instead of escaping to callers that expect a result object.
        prompt = _load_prompt("fact_check.txt").format(
            article_content=article_content[:15000]  # Truncate for token limits
        )
        response_text = _generate_with_model_fallback(prompt)
        result = _parse_fact_check_response(response_text)

        logger.info(
            f"Fact-check completed: {result.verdict} "
            f"({len(result.issues)} issues, confidence: {result.confidence:.0%})"
        )
        return result

    except Exception as e:
        logger.error(f"Fact-check failed: {e}")
        return FactCheckResult(
            passed=False,
            verdict="FAIL",
            summary=f"Fact-check error: {str(e)}",
            confidence=0.0,
        )


def _generate_with_model_fallback(prompt: str) -> str:
    """Send *prompt* (with the Google Search tool enabled) through the model chain.

    Models in ``TEXT_MODELS`` are tried in order. A model is retried up to
    ``MAX_RETRIES`` times, sleeping ``RETRY_DELAY`` seconds in between, only
    when the error text looks like rate limiting (``429`` or
    ``RESOURCE_EXHAUSTED``). Any other error, exhausted retries, or a
    response without usable text moves on to the next model.

    Returns:
        The text of the first non-empty response.

    Raises:
        RuntimeError: if no model produced a non-empty response.
    """
    client = get_gemini_client()
    google_search_tool = types.Tool(google_search=types.GoogleSearch())
    config = types.GenerateContentConfig(tools=[google_search_tool])
    last_error: Optional[Exception] = None

    for model in TEXT_MODELS:
        for attempt in range(MAX_RETRIES):
            try:
                logger.debug(f"Trying model {model} (attempt {attempt + 1}/{MAX_RETRIES})")
                response = client.models.generate_content(
                    model=model,
                    contents=prompt,
                    config=config,
                )
            except Exception as e:
                last_error = e
                rate_limited = "429" in str(e) or "RESOURCE_EXHAUSTED" in str(e)
                if rate_limited and attempt < MAX_RETRIES - 1:
                    logger.warning(f"Rate limited on {model}, waiting {RETRY_DELAY}s (attempt {attempt + 1}/{MAX_RETRIES})")
                    time.sleep(RETRY_DELAY)
                    continue
                # Non-retryable or exhausted retries - try next model
                logger.warning(f"Model {model} failed: {e}")
                break

            text = response.text
            if text and text.strip():
                return text
            # The call succeeded but produced no usable text - try next model
            # rather than handing None/"" to the parser.
            last_error = RuntimeError(f"Model {model} returned an empty response")
            logger.warning(str(last_error))
            break

    raise RuntimeError(f"All models failed. Last error: {last_error}")


def fact_check_with_retry(
    article_content: str,
    topic: Topic,
    max_attempts: int = 2,
    regenerate_func=None,
    min_confidence: float = 0.75,
    max_critical_issues: int = 2,
) -> tuple[bool, str, dict]:
    """
    Fact-check with retry logic.

    Each attempt runs ``fact_check`` on the current content and acts on the
    verdict:

    * ``PASS``: accepted.
    * ``NEEDS_REVISION`` with confidence >= ``min_confidence`` and at most
      ``max_critical_issues`` INCORRECT/OUTDATED issues: accepted, with
      ``final_verdict == "PASS_WITH_WARNINGS"``.
    * ``FAIL``: rejected immediately. Note that ``fact_check`` also returns
      ``FAIL`` when the check itself errors out.
    * Any other ``NEEDS_REVISION``: if ``regenerate_func`` is given and
      attempts remain, the article is revised and re-checked; on the last
      attempt it is rejected. Without ``regenerate_func`` the unchanged
      content is simply checked again on the next attempt.

    Args:
        article_content: The article markdown to check
        topic: The topic for regeneration context (passed to regenerate_func)
        max_attempts: Maximum fact-check attempts (default 2)
        regenerate_func: Optional callable ``(topic, revision_prompt) -> str``
            returning revised article content
        min_confidence: Minimum confidence for accepting a NEEDS_REVISION
            verdict with warnings (default 0.75)
        max_critical_issues: Maximum number of INCORRECT/OUTDATED issues
            tolerated in such an accepted result (default 2)

    Returns:
        (passed, final_content, fact_check_log), where fact_check_log holds
        one entry per attempt under "attempts" plus "final_verdict".
    """
    current_content = article_content
    log = {"attempts": [], "final_verdict": None}

    for attempt in range(max_attempts):
        logger.info(f"Fact-check attempt {attempt + 1}/{max_attempts}")

        result = fact_check(current_content)
        log["attempts"].append(
            {
                "attempt": attempt + 1,
                "verdict": result.verdict,
                "issues": result.issues,
                "confidence": result.confidence,
            }
        )

        if result.verdict == "PASS":
            log["final_verdict"] = "PASS"
            return True, current_content, log

        # NEEDS_REVISION with high confidence is acceptable as long as only a
        # limited number of issues are critical (INCORRECT/OUTDATED);
        # MINOR_VARIATION and UNVERIFIABLE issues do not count against it.
        if result.verdict == "NEEDS_REVISION" and result.confidence >= min_confidence:
            # Issues come from model-generated JSON: tolerate non-dict entries
            # and a missing/None category instead of crashing.
            critical_issues = [
                issue
                for issue in result.issues or []
                if isinstance(issue, dict)
                and str(issue.get("category") or "").strip().upper() in ("INCORRECT", "OUTDATED")
            ]
            if len(critical_issues) <= max_critical_issues:
                log["final_verdict"] = "PASS_WITH_WARNINGS"
                logger.info(
                    f"Fact-check passed with warnings: {len(result.issues)} issues "
                    f"({len(critical_issues)} critical)"
                )
                return True, current_content, log

        if result.verdict == "FAIL":
            # Immediate failure - contradicting evidence found (or the check errored)
            log["final_verdict"] = "FAIL"
            logger.error(f"Fact-check FAIL: {result.summary}")
            return False, current_content, log

        if result.verdict == "NEEDS_REVISION" and regenerate_func:
            # Try to fix issues through regeneration
            if attempt < max_attempts - 1:
                logger.info("Attempting article revision based on fact-check feedback")
                revision_prompt = _build_revision_prompt(
                    topic, current_content, result.issues
                )
                try:
                    current_content = regenerate_func(topic, revision_prompt)
                except Exception as e:
                    # Keep the previous content; the next attempt re-checks it.
                    logger.error(f"Revision failed: {e}")
                    continue
            else:
                # Last attempt failed
                log["final_verdict"] = "NEEDS_REVISION"
                return False, current_content, log

    log["final_verdict"] = "MAX_ATTEMPTS_REACHED"
    return False, current_content, log


def _load_prompt(filename: str) -> str:
    """Load prompt template from prompts directory."""
    prompt_path = Path(__file__).parent.parent / "prompts" / filename
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
    return prompt_path.read_text()


def _build_revision_prompt(topic: Topic, content: str, issues: list[dict]) -> str:
    """Build a revision prompt based on fact-check issues."""
    issues_text = "\n".join(
        [
            f"- {issue.get('claim', 'Unknown claim')}: {issue.get('issue', 'Needs verification')}"
            for issue in issues
        ]
    )

    return f"""Überarbeite den folgenden Artikel und korrigiere die identifizierten Probleme.

# PROBLEME ZU BEHEBEN
{issues_text}

# ORIGINAL-ARTIKEL
{content}

# ANWEISUNGEN
1. Korrigiere ALLE genannten Probleme mit aktuellen, verifizierbaren Fakten
2. Behalte die Struktur und den Stil bei
3. Stelle sicher, dass alle Zahlen und Behauptungen korrekt sind
4. Füge Quellenangaben hinzu wo möglich

Liefere den korrigierten Artikel:
"""


def _parse_fact_check_response(response_text: str) -> FactCheckResult:
    """Parse Gemini fact-check response into FactCheckResult.

    The response is expected to contain a JSON object with ``verdict``,
    ``issues``, ``confidence`` and ``summary`` keys, optionally wrapped in a
    Markdown code fence or surrounded by prose.

    A missing verdict maps to FAIL, an unrecognised one to NEEDS_REVISION,
    and a missing confidence defaults to 0.5. A response that cannot be
    parsed (empty/None text, invalid JSON, a non-object payload, a
    non-numeric confidence) yields a conservative NEEDS_REVISION result with
    confidence 0.0 instead of raising.
    """
    try:
        text = (response_text or "").strip()
        data = json.loads(_extract_json_text(text))
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")

        # str() guards against non-string verdicts (e.g. null), which then
        # fall into the NEEDS_REVISION branch like any unknown value.
        verdict = str(data.get("verdict", "FAIL")).strip().upper()
        if verdict not in ("PASS", "NEEDS_REVISION", "FAIL"):
            verdict = "NEEDS_REVISION"

        # Callers take len() of and iterate over issues, so always hand back a list.
        issues = data.get("issues") or []
        if not isinstance(issues, list):
            issues = [issues]

        # float(None) raises TypeError, which is handled below like bad JSON.
        confidence = float(data.get("confidence", 0.5))

        return FactCheckResult(
            passed=verdict == "PASS",
            verdict=verdict,
            issues=issues,
            summary=str(data.get("summary") or ""),
            confidence=confidence,
        )

    except (json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
        logger.warning(f"Failed to parse fact-check response: {e}")
        # Conservative default: needs revision
        return FactCheckResult(
            passed=False,
            verdict="NEEDS_REVISION",
            summary=f"Parse error: {str(e)}",
            confidence=0.0,
        )


def _extract_json_text(text: str) -> str:
    """Return the part of a model response that most likely holds the JSON payload."""
    # Prefer a ```json fence (any case), then any fence, skipping a language
    # tag. An unterminated fence extends to the end of the text.
    for pattern in (r"```json\s*(.*?)(?:```|$)", r"```[\w-]*\s*(.*?)(?:```|$)"):
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1)
    if text.startswith(("{", "[")):
        return text
    # Prose around a bare object: keep the outermost {...} span.
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        return text[start : end + 1]
    return text
