"""
OCR Engine with multi-backend support and memory management.

SOLUTION 1: Simplified prompt design.
- Fewer markers (reduced from 20+ to ~10 core markers)
- Uses ## HEADING ## instead of [HEADING]...[/HEADING] to avoid bracket priming
- Paragraph breaks signaled by blank lines (natural for VLMs)
- Explicit instructions against code fences and preamble
- Custom tags use the same simplified format
"""

import gc
import os
import torch
from typing import List, Tuple, Optional, Callable, Dict
from PIL import Image
from .device import get_best_device, get_device_info, get_torch_dtype

# Set at module import, and only when the variable is not already present in
# the environment, so a value exported by the user is left untouched.
# NOTE(review): per the PyTorch MPS environment-variable documentation a ratio
# of "0.0" appears to disable the allocator's high-watermark cap — confirm
# that removing the cap (at the risk of system-wide memory pressure) is the
# intended trade-off, and that this runs early enough to take effect.
os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")


class OCREngine:
    """Page-level OCR driven by a Qwen2.5-VL vision-language model.

    The engine resolves a model identifier, loads and unloads the model and
    its processor on demand, and exposes helpers for page transcription,
    bibliographic metadata inference, and language detection.
    """

    # Short model keys accepted by the constructor, mapped to the model ids
    # handed to ``from_pretrained``. A key that is not listed here is used
    # verbatim as the model id.
    MODEL_OPTIONS = {
        "qwen2.5-vl-7b": "Qwen/Qwen2.5-VL-7B-Instruct",
        "qwen2.5-vl-3b": "Qwen/Qwen2.5-VL-3B-Instruct",
    }
    # Key used when the constructor is called without ``model_key``.
    DEFAULT_MODEL = "qwen2.5-vl-7b"

    # Longest allowed image edge (in pixels) per device type; larger page
    # images are downscaled to fit before inference.
    MPS_MAX_DIM = 1280
    CUDA_MAX_DIM = 2048
    CPU_MAX_DIM = 1024

    def __init__(self, model_key: str = None, cache_dir: str = None):
        """Record the model choice and the detected device; nothing is loaded.

        Args:
            model_key: Short key from ``MODEL_OPTIONS`` or a full model id.
                A falsy value selects ``DEFAULT_MODEL``.
            cache_dir: Optional directory forwarded to ``from_pretrained``
                when the model is loaded later.
        """
        # Model and processor stay empty until load_model() is called.
        self.model = None
        self.processor = None

        chosen_key = model_key if model_key else self.DEFAULT_MODEL
        self.model_key = chosen_key
        # Unknown keys are treated as a model id in their own right.
        self.model_id = self.MODEL_OPTIONS.get(chosen_key, chosen_key)
        self.cache_dir = cache_dir

        detected = get_device_info()
        self.device_info = detected
        self.device = detected["device"]

    def is_loaded(self) -> bool:
        """Return True once a model object is attached to the engine."""
        model_missing = self.model is None
        return not model_missing

    def get_max_image_dim(self) -> int:
        """Return the longest image edge allowed for the active device.

        ``"mps"`` and ``"cuda"`` have dedicated limits; every other device
        value uses the CPU limit.
        """
        limits = {
            "mps": self.MPS_MAX_DIM,
            "cuda": self.CUDA_MAX_DIM,
        }
        return limits.get(self.device, self.CPU_MAX_DIM)

    def load_model(self, progress_callback: Optional[Callable] = None):
        """Load the processor and the model weights for ``self.model_id``.

        Does nothing when a model is already loaded. On CUDA the model is
        loaded with ``device_map="auto"``; FlashAttention-2 and 4-bit NF4
        quantization are requested only when the ``flash_attn`` and
        ``bitsandbytes`` packages can actually be found, so a missing optional
        package falls back to the defaults instead of failing inside
        ``from_pretrained``. On MPS the model is moved to ``"mps"`` after
        loading unless it carries an ``hf_device_map``. Finally the model is
        switched to eval mode.

        Args:
            progress_callback: Optional callable invoked with a single
                human-readable status string at each loading stage.
        """
        if self.is_loaded():
            return

        from importlib.util import find_spec

        def notify(message: str) -> None:
            if progress_callback:
                progress_callback(message)

        def has_package(name: str) -> bool:
            # find_spec may raise for partially initialised modules; treat
            # that the same as "not installed".
            try:
                return find_spec(name) is not None
            except (ImportError, ValueError):
                return False

        notify(f"Detected device: {self.device_info['device_name']}")

        # Imported lazily so that importing this module stays cheap.
        from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

        notify("Loading tokenizer and processor...")

        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; keep model ids restricted to trusted sources.
        proc_kwargs = {"trust_remote_code": True}
        if self.cache_dir:
            proc_kwargs["cache_dir"] = self.cache_dir
        self.processor = AutoProcessor.from_pretrained(self.model_id, **proc_kwargs)

        notify("Loading model weights (this may take a few minutes on first run)...")

        dtype = get_torch_dtype(self.device_info["recommended_dtype"])
        model_kwargs = {"trust_remote_code": True, "torch_dtype": dtype}
        if self.cache_dir:
            model_kwargs["cache_dir"] = self.cache_dir

        if self.device == "cuda":
            model_kwargs["device_map"] = "auto"
            # Assigning the option can never fail; the failure would only
            # surface inside from_pretrained. Probe for the package instead.
            if has_package("flash_attn"):
                model_kwargs["attn_implementation"] = "flash_attention_2"
            # BitsAndBytesConfig is importable from transformers even when
            # the bitsandbytes backend itself is absent, so check for both.
            if has_package("bitsandbytes"):
                try:
                    from transformers import BitsAndBytesConfig
                    model_kwargs["quantization_config"] = BitsAndBytesConfig(
                        load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16,
                        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
                except ImportError:
                    # Best effort: load unquantized when the config class is
                    # unavailable in the installed transformers version.
                    pass
        else:
            model_kwargs["device_map"] = None
            if self.device == "mps":
                notify("Loading for Apple Silicon (MPS)...")

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_id, **model_kwargs)

        # Without a device map the weights are not yet on the MPS device.
        if self.device == "mps" and not hasattr(self.model, "hf_device_map"):
            self.model = self.model.to("mps")

        self.model.eval()
        notify("Model loaded successfully.")

    def unload_model(self):
        """Drop the model and processor and release cached accelerator memory.

        Safe to call repeatedly and on any platform: the CUDA cache is only
        emptied when CUDA is available, and the MPS cache only when the MPS
        backend reports itself available. Calling ``torch.mps.empty_cache()``
        without that check can fail on machines that have no MPS backend.
        """
        # Rebinding to None drops the engine's references to both objects.
        self.model = None
        self.processor = None
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # getattr guards keep this working on torch builds that predate the
        # torch.backends.mps / torch.mps modules.
        mps_backend = getattr(torch.backends, "mps", None)
        mps_module = getattr(torch, "mps", None)
        if (mps_backend is not None and mps_backend.is_available()
                and hasattr(mps_module, "empty_cache")):
            mps_module.empty_cache()

    def _flush_memory(self):
        """Run the garbage collector and empty the active device's cache.

        CUDA: empties the CUDA cache. MPS: empties the cache and then
        synchronizes, each only if the installed torch exposes the call.
        Other devices: garbage collection only.
        """
        gc.collect()

        active = self.device
        if active == "cuda":
            torch.cuda.empty_cache()
            return
        if active != "mps":
            return

        # Same order as before: release cached blocks, then synchronize.
        if hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()
        if hasattr(torch.mps, "synchronize"):
            torch.mps.synchronize()

    def _downscale_for_device(self, image: Image.Image) -> Image.Image:
        """Shrink ``image`` so neither edge exceeds the device limit.

        Images that already fit are returned unchanged (same object). Larger
        images are resized with LANCZOS resampling, preserving aspect ratio.

        Args:
            image: Page image to fit to ``get_max_image_dim()``.

        Returns:
            The original image, or a resized copy.
        """
        max_dim = self.get_max_image_dim()
        w, h = image.size
        if w <= max_dim and h <= max_dim:
            return image
        scale = min(max_dim / w, max_dim / h)
        # Clamp to one pixel: for extreme aspect ratios the truncated short
        # edge would otherwise become 0, which resize() rejects.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        return image.resize((new_w, new_h), Image.LANCZOS)

    def _run_inference(self, image: Image.Image, prompt: str, max_tokens: int = 4096) -> str:
        """Run one image + prompt through the model and return the reply text.

        The image is downscaled for the active device, wrapped in a single
        user chat message together with ``prompt``, and passed to
        ``generate``. Only the newly generated tokens are decoded. Tensors
        are released and the device cache flushed even when generation or
        decoding fails, so a failed page does not pin memory for the next one.

        Args:
            image: Page image to read.
            prompt: Instruction text sent alongside the image.
            max_tokens: Upper bound on newly generated tokens.

        Returns:
            The decoded model output with surrounding whitespace stripped.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        if not self.is_loaded():
            raise RuntimeError("Model not loaded.")

        image = self._downscale_for_device(image)

        messages = [{"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ]}]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)

        # Imported lazily; only needed once inference actually runs.
        from qwen_vl_utils import process_vision_info
        image_inputs, video_inputs = process_vision_info(messages)

        # Pre-bind so the cleanup in ``finally`` is valid on every path.
        inputs = output_ids = generated_ids = None
        try:
            inputs = self.processor(
                text=[text], images=image_inputs, videos=video_inputs,
                padding=True, return_tensors="pt")

            target_device = self.model.device if hasattr(self.model, "device") else self.device
            inputs = {k: v.to(target_device) if hasattr(v, "to") else v for k, v in inputs.items()}

            with torch.inference_mode():
                output_ids = self.model.generate(
                    **inputs, max_new_tokens=max_tokens,
                    temperature=0.1, top_p=0.9, do_sample=True,
                    repetition_penalty=1.05)

            # generate() returns prompt + completion; keep the completion only.
            generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
            result = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        finally:
            del inputs, output_ids, generated_ids
            self._flush_memory()

        return result.strip()

    def extract_text_from_page(self, image: Image.Image, page_label: str = "1",
                                genre: str = "auto",
                                custom_tags: List[Dict] = None) -> str:
        """Transcribe a single page image.

        Builds the OCR prompt for ``genre`` and ``custom_tags`` and runs it
        against ``image``. ``page_label`` is accepted but not used here.

        Returns:
            The transcription text produced by the model.
        """
        return self._run_inference(image, self._build_prompt(genre, custom_tags))

    def infer_metadata(self, images: List[Image.Image]) -> Dict:
        """Ask the model for bibliographic metadata visible on the first image.

        Only ``images[0]`` is examined. The reply is cleaned of Markdown code
        fences and a leading ``json`` tag, then parsed as JSON. If the cleaned
        reply as a whole is not a JSON object, the span from the first ``{``
        to the last ``}`` is tried, which recovers replies that wrap the
        object in extra text.

        Args:
            images: Page images; may be empty.

        Returns:
            The parsed JSON object as a dict, or ``{}`` when ``images`` is
            empty or no JSON object can be recovered from the reply.
        """
        import json

        if not images:
            return {}

        prompt = (
            "Examine this page image carefully. It may be a title page or first page of a book, "
            "poem, play, letter, or other historical document.\n\n"
            "Extract the following bibliographic metadata if visible. "
            "Respond ONLY with a JSON object. No markdown fences, no explanation:\n"
            '{"title": "...", "author": "...", "date": "...", "publisher": "...", '
            '"place": "...", "language": "ISO 639 code e.g. en, fr, de, la, grc"}\n\n'
            "Use empty strings for fields not found. For language, identify the "
            "primary language of the main text (not front matter if different)."
        )
        raw = self._run_inference(images[0], prompt, max_tokens=512)

        # Strip an opening fence line, a closing fence and a bare "json" tag.
        text = raw.strip()
        if text.startswith("```"):
            text = text.split("\n", 1)[1] if "\n" in text else text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
        if text.startswith("json"):
            text = text[4:].strip()

        candidates = [text]
        first_brace, last_brace = text.find("{"), text.rfind("}")
        if 0 <= first_brace < last_brace:
            # Fallback: the object embedded in preamble or trailing remarks.
            candidates.append(text[first_brace:last_brace + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            # Valid JSON that is not an object would break the Dict contract.
            if isinstance(parsed, dict):
                return parsed
        return {}

    def detect_language(self, image: Image.Image) -> str:
        """Ask the model for the primary language code of the text in ``image``.

        The first whitespace-separated token of the reply is lower-cased and
        stripped of surrounding punctuation. It is accepted only when it is
        non-empty, at most five characters long, and alphabetic apart from
        ``-``/``_`` separators; anything else falls back to ``"en"``.

        Returns:
            A lower-case language code, ``"en"`` when none can be extracted.
        """
        prompt = (
            "What is the primary language of the text in this image? "
            "Respond with ONLY the ISO 639 language code (e.g., en, fr, de, la, grc). "
            "No explanation."
        )
        raw = self._run_inference(image, prompt, max_tokens=32)

        tokens = raw.strip().lower().split()
        if not tokens:
            return "en"

        code = tokens[0].strip(".,;:\"'`")
        # A token made only of punctuation strips down to "", which must not
        # be returned as a language code.
        letters = code.replace("-", "").replace("_", "")
        if code and len(code) <= 5 and letters.isalpha():
            return code
        return "en"

    def _build_prompt(self, genre: str, custom_tags: List[Dict] = None) -> str:
        """Assemble the OCR instruction prompt for one page.

        The prompt consists of, in order: the shared transcription and
        structural rules, the rules for ``genre`` (unknown genres use the
        ``"auto"`` rules), one line per entry of ``custom_tags`` if any, and a
        closing reminder against code fences, preamble and line numbers.

        Design principles:
        - Few markers, to reduce confusion and bracket priming
        - ## ... ## for headings instead of [HEADING]...[/HEADING]
        - Paragraph breaks are blank lines (natural for VLMs)
        - Paired markers only for elements that need content extraction
        - Explicit anti-code-fence and anti-preamble instructions

        Args:
            genre: One of ``"poetry"``, ``"drama"``, ``"manuscript"``,
                ``"auto"``; anything else is treated as ``"auto"``.
            custom_tags: Optional dicts with ``"name"`` and ``"description"``
                keys describing extra markers to request.
        """
        base_rules = (
            "Transcribe ALL text visible in this image exactly as printed. "
            "Preserve original spelling, punctuation, and capitalization.\n\n"
            "Format your output as follows. Output ONLY the transcription, "
            "with no preamble, no commentary, and no code fences.\n\n"
            "STRUCTURAL RULES:\n"
            "- Mark headings/titles on their own line: ## Heading text ##\n"
            "- Separate paragraphs with a blank line.\n"
            "- Mark the running header (repeated title at page top): RH: text\n"
            "- Mark the page number: PN: number\n"
            "- Mark footnotes inline: FN{{footnote text}}\n"
            "- Mark italic text: _italic text_\n"
            "- Mark bold text: **bold text**\n"
            "- Mark small capitals: SC{{text}}\n"
            "- Mark decorative/dropped initial capitals: DI{{letter}}\n"
            "- Mark figures or illustrations: FIG{{description}}\n"
            "- Mark catchwords at page bottom: CW{{word}}\n"
        )

        poetry_rules = (
            "- Mark each verse line on its own line (no blank line between lines in a stanza).\n"
            "- Separate stanzas with a blank line.\n"
            "- Mark indented lines by starting with >>.\n"
        )
        drama_rules = (
            "- Mark speaker names: SP{{name}}\n"
            "- Mark stage directions: SD{{direction}}\n"
        )
        manuscript_rules = (
            "- Mark deleted/crossed-out text: DEL{{text}}\n"
            "- Mark added/inserted text: ADD{{text}}\n"
            "- Mark unclear text: UNC{{text}}\n"
            "- Mark illegible text: [illegible]\n"
            "- Mark salutations: SAL{{text}}\n"
            "- Mark closings: CLOSE{{text}}\n"
            "- Mark signatures: SIGN{{text}}\n"
            "- Mark datelines: DATE{{text}}\n"
        )
        auto_rules = (
            "- If this is poetry, put each verse line on its own line, separate stanzas with blank lines.\n"
            "- If this is a play, mark speakers as SP{{name}} and stage directions as SD{{direction}}.\n"
            "- If this is a letter, mark salutations, closings, and signatures.\n"
        )
        rules_by_genre = {
            "poetry": poetry_rules,
            "drama": drama_rules,
            "manuscript": manuscript_rules,
            "auto": auto_rules,
        }

        closing = (
            "\nIMPORTANT: Start the transcription immediately. "
            "Do NOT wrap your output in ``` code fences. "
            "Do NOT include any introductory text like 'Here is the transcription'. "
            "Do NOT add line numbers. Just output the text with the markup above."
        )

        # Collect the pieces in output order and join once at the end.
        pieces = [base_rules, rules_by_genre.get(genre, auto_rules)]

        if custom_tags:
            pieces.append("\nAlso mark the following:\n")
            for tag in custom_tags:
                name = tag.get("name", "").upper()
                desc = tag.get("description", "")
                # Quadrupled braces render as literal "{{" / "}}" in the prompt.
                pieces.append(f"- Mark {desc}: {name}{{{{{desc}}}}}\n")

        pieces.append(closing)
        return "".join(pieces)

    def process_pages(self, pages: List[Tuple[Image.Image, str]], genre: str = "auto",
                       custom_tags: List[Dict] = None,
                       progress_callback: Optional[Callable] = None) -> List[Tuple[str, str]]:
        """Transcribe every page in order, flushing device memory after each.

        Args:
            pages: ``(image, label)`` pairs to process.
            genre: Genre hint forwarded to the prompt builder.
            custom_tags: Extra marker definitions forwarded to the prompt.
            progress_callback: Optional callable invoked before each page as
                ``callback(message, current, total)`` with a 1-based index.

        Returns:
            ``(text, label)`` pairs in the same order as ``pages``.
        """
        page_count = len(pages)
        transcribed = []
        for position, page in enumerate(pages, start=1):
            image, label = page
            if progress_callback:
                progress_callback(
                    f"OCR page {label} ({position}/{page_count})...", position, page_count)
            page_text = self.extract_text_from_page(image, label, genre, custom_tags)
            transcribed.append((page_text, label))
            # Release per-page tensors before moving on to the next image.
            self._flush_memory()
        return transcribed
