"""
FAISS-based Menu Search Service
================================
Ultra-fast vector search for menu items using FAISS.

Features:
- 0.1-0.5 ms search time
- Supports Arabic and English
- Word correction and synonym expansion
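- Index is built from the menu items passed to `load_menu` (e.g. rows fetched from Supabase on startup)
- ~2 MB memory footprint for the index embeddings (the embedding model itself is ~120 MB)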
"""

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Tuple, Optional
from loguru import logger
import asyncio
from functools import lru_cache
import json
import os
import re


class FAISSMenuSearch:
    """FAISS-based menu search with multilingual support"""

    def __init__(self) -> None:
        """Create an empty, not-yet-loaded search service and read the query-expansion data."""
        # Sentence-transformer model; stays None until _load_model() runs.
        self.model: Optional[SentenceTransformer] = None
        # FAISS index; stays None until load_menu() builds it.
        self.index: Optional[faiss.Index] = None
        # Menu item dicts; search() uses FAISS result indices as positions in this list.
        self.menu_items: List[dict] = []
        # Embedding matrix produced by load_menu().
        self.embeddings: Optional[np.ndarray] = None
        # Flipped to True at the end of load_menu().
        self._loaded = False
        # Query-expansion data; filled in right away by _load_training_data() below.
        self.training_data: Optional[dict] = None
        self._load_training_data()
        
    def _load_training_data(self):
        """
        Read the query-expansion training data into ``self.training_data``.

        The file is looked up at ``../../data/user_queries_training.json``
        relative to this module's directory. If the file is missing, or
        anything goes wrong while reading it, ``self.training_data`` falls back
        to ``{"queries": []}`` so that callers can always use ``.get('queries')``.
        """
        try:
            module_dir = os.path.dirname(__file__)
            training_file = os.path.join(module_dir, '../../data/user_queries_training.json')

            # Guard clause: nothing to read, use the empty default.
            if not os.path.exists(training_file):
                logger.warning(f"⚠️ Training data file not found: {training_file}")
                self.training_data = {"queries": []}
                return

            with open(training_file, 'r', encoding='utf-8') as f:
                self.training_data = json.load(f)
            logger.info(f"✅ Loaded {len(self.training_data.get('queries', []))} training queries")
        except Exception as e:
            # Best effort: a broken file must not prevent the service from starting.
            logger.error(f"❌ Error loading training data: {e}")
            self.training_data = {"queries": []}

    def _load_model(self):
        """
        Instantiate the sentence-transformer model on first use.

        Does nothing when ``self.model`` is already set, so it is safe to call
        before every index build.
        """
        if self.model is not None:
            return

        logger.info("🔄 Loading multilingual sentence transformer model...")
        # This model supports 50+ languages including Arabic and English
        # Size: ~120 MB, Dimension: 384
        model_name = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
        self.model = SentenceTransformer(model_name)
        logger.info("✅ Model loaded successfully")
    
    async def load_menu(self, items: List[dict]):
        """
        Load menu items and build FAISS index
        
        Args:
            items: List of menu items from Supabase
        """
        if not items:
            logger.warning("⚠️ No menu items to load")
            return
        
        logger.info(f"🔄 Building FAISS index for {len(items)} menu items...")
        
        # Load model if not loaded
        self._load_model()
        
        # Store menu items
        self.menu_items = items
        
        # Create text representations for embedding
        texts = []
        for item in items:
            # Combine all searchable fields
            # Support both DB format (description_ar) and local format (description)
            text_parts = [
                item.get('name_ar', ''),
                item.get('name_en', ''),
                item.get('description_ar', item.get('description', '')),
                item.get('description_en', ''),
                item.get('category', '')
            ]

            # Add tags if available
            tags = item.get('tags', [])
            if tags and isinstance(tags, list):
                text_parts.extend(tags)

            # === NEW: Add aliases for better search ===
            aliases = item.get('aliases', [])
            if aliases and isinstance(aliases, list):
                text_parts.extend(aliases)

            # === NEW: Add search_keywords for better semantic matching ===
            keywords = item.get('search_keywords', [])
            if keywords and isinstance(keywords, list):
                text_parts.extend(keywords)

            text = ' '.join(filter(None, text_parts))
            texts.append(text)
        
        # Generate embeddings (run in thread pool to avoid blocking)
        loop = asyncio.get_event_loop()
        self.embeddings = await loop.run_in_executor(
            None,
            lambda: self.model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
        )
        
        # Build FAISS index
        dimension = self.embeddings.shape[1]
        
        # Use IndexFlatIP for cosine similarity (faster for small datasets)
        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(self.embeddings)
        
        self.index = faiss.IndexFlatIP(dimension)  # Inner Product (cosine after normalization)
        self.index.add(self.embeddings.astype('float32'))
        
        self._loaded = True
        
        logger.info(f"✅ FAISS index built successfully")
        logger.info(f"   - Items: {len(items)}")
        logger.info(f"   - Dimension: {dimension}")
        logger.info(f"   - Memory: ~{(self.embeddings.nbytes / 1024 / 1024):.1f} MB")
    
    def _correct_query(self, query: str) -> str:
        """
        Run the query through the project's word-correction dictionary.

        Correction is best effort: if the dictionary cannot be imported or the
        correction itself raises, a warning is logged and the query is returned
        unchanged.

        Args:
            query: Original query

        Returns:
            The corrected query, or ``query`` itself when correction failed
        """
        result = query
        try:
            # Imported lazily to avoid a circular dependency at module import time
            from app.services.word_correction_dict import word_correction_dict
            result = word_correction_dict.correct_query(query)
        except Exception as e:
            logger.warning(f"⚠️ Word correction failed: {e}")
        return result

    def _normalize_arabic_text(self, text: str) -> str:
        """
        Normalize Arabic text so that spelling variants compare equal.

        Steps:
        - remove the combining marks in U+064B..U+065F (diacritics)
        - unify the alef/hamza forms into a bare alef, and alef maqsura into ya
        - drop the leading definite article "ال" from words longer than 3 characters

        Words are re-joined with single spaces, so runs of whitespace collapse.
        """
        # Remove diacritics
        without_marks = re.sub(r'[\u064B-\u065F]', '', text)

        # Unify alef/hamza variants, then alef maqsura / ya
        unified = re.sub(r'[ىي]', 'ي', re.sub(r'[إأآا]', 'ا', without_marks))

        # Strip the definite article only from words longer than 3 characters,
        # so short words that merely start with the same two letters are kept.
        return ' '.join(
            token[2:] if token.startswith('ال') and len(token) > 3 else token
            for token in unified.split()
        )

    def _expand_query_with_variations(self, query: str) -> List[str]:
        """
        Expand a query with known variations from the training data.

        The result always starts with the query itself, followed by its
        Arabic-normalized form when that differs. Training entries are then
        consulted: if the query equals one of an entry's ``variations`` the
        entry's ``user_query`` is added; if it equals the ``user_query`` all of
        the entry's variations are added. Comparison is case-insensitive and
        duplicates are skipped.

        Entries with a missing or empty ``user_query`` and non-string
        variations are tolerated instead of raising or adding empty strings.

        Args:
            query: Original query

        Returns:
            List of at most 5 query variations, the original query first
        """
        # Normalize the Arabic text first
        normalized_query = self._normalize_arabic_text(query)

        variations = [query]

        # Add the normalized form if it differs
        if normalized_query != query:
            variations.append(normalized_query)

        if not self.training_data:
            return variations

        query_lower = query.lower().strip()
        # Lowercased forms already collected; a set keeps the duplicate check O(1)
        seen = {v.lower() for v in variations}

        for entry in self.training_data.get('queries', []):
            main_query = entry.get('user_query')
            if not isinstance(main_query, str):
                main_query = ''
            main_lower = main_query.lower()
            entry_variations = [
                v for v in (entry.get('variations') or []) if isinstance(v, str)
            ]

            # Query is a known variation -> add the entry's main query
            if query_lower in {v.lower() for v in entry_variations}:
                if main_query and main_lower not in seen:
                    variations.append(main_query)
                    seen.add(main_lower)

            # Query is the main query -> add all of its variations
            elif query_lower == main_lower:
                for var in entry_variations:
                    var_lower = var.lower()
                    if var_lower not in seen:
                        variations.append(var)
                        seen.add(var_lower)

        limited = variations[:5]  # Limit to 5 variations

        if len(limited) > 1:
            logger.info(f"🔍 Expanded '{query}' to {len(limited)} variations")

        return limited

    async def search(
        self,
        query: str,
        top_k: int = 5,  # ← تقليل من 10 إلى 5
        min_score: float = 0.35,  # ← زيادة من 0.3 إلى 0.35
        keyword_filter: bool = True,
        exact_match_boost: bool = True
    ) -> List[Tuple[dict, float]]:
        """
        Search for menu items using FAISS with optional keyword filtering

        التحسينات:
        - تقليل top_k الافتراضي إلى 5 لنتائج أكثر دقة
        - زيادة min_score إلى 0.35 لتجنب النتائج الضعيفة
        - إضافة exact_match_boost لتعزيز التطابق التام
        - فلتر must_contain_core قبل إرجاع النتائج

        Args:
            query: Search query (Arabic or English)
            top_k: Number of results to return (default: 5)
            min_score: Minimum similarity score (default: 0.35)
            keyword_filter: If True, filter results to include query keyword
            exact_match_boost: If True, boost exact matches to top

        Returns:
            List of (item, score) tuples sorted by score
        """
        if not self._loaded or not self.index:
            logger.warning("⚠️ FAISS index not loaded")
            return []

        if not query or not query.strip():
            return []

        try:
            # Step 1: Correct query
            corrected_query = self._correct_query(query)

            # Step 2: Expand with variations
            query_variations = self._expand_query_with_variations(corrected_query)

            # Step 3: Search with all variations and combine results
            all_results = {}  # Use dict to deduplicate by item code

            for q in query_variations:
                # Encode query
                loop = asyncio.get_event_loop()
                query_embedding = await loop.run_in_executor(
                    None,
                    lambda: self.model.encode([q], convert_to_numpy=True, show_progress_bar=False)
                )

                # Normalize for cosine similarity
                faiss.normalize_L2(query_embedding)

                # Search (get more results for filtering)
                search_k = top_k * 3 if keyword_filter else top_k
                scores, indices = self.index.search(query_embedding.astype('float32'), search_k)

                # Convert to results and add to combined results
                for score, idx in zip(scores[0], indices[0]):
                    if idx < len(self.menu_items) and score >= min_score:
                        item = self.menu_items[idx]
                        item_code = item.get('code')
                        similarity = float(score) * 100

                        # Keep highest score for each item
                        if item_code not in all_results or all_results[item_code][1] < similarity:
                            all_results[item_code] = (item, similarity)

            # Convert dict back to list
            combined_results = list(all_results.values())

            # Sort by score
            combined_results.sort(key=lambda x: x[1], reverse=True)

            # Apply exact match boost if enabled
            if exact_match_boost and combined_results:
                combined_results = self._boost_exact_matches(query, combined_results)
                logger.debug(f"🎯 Applied exact match boost")

            # Apply keyword filtering if enabled
            if keyword_filter and combined_results:
                # Try filtering with original query first
                filtered_results = self._filter_by_keyword(query, combined_results)

                # If no results, try with corrected query
                if not filtered_results and corrected_query != query:
                    filtered_results = self._filter_by_keyword(corrected_query, combined_results)

                # Fallback to unfiltered if no matches
                if not filtered_results:
                    logger.warning(
                        f"⚠️ FAISS keyword filter removed all {len(combined_results)} results for '{query}'. "
                        f"Falling back to semantic results."
                    )
                    results = combined_results[:top_k]

                    # Log mismatch warning for top result
                    if results:
                        top_item = results[0][0]
                        logger.warning(
                            f"⚠️ SEMANTIC MISMATCH: Top result '{top_item.get('name_ar')}' "
                            f"does not contain keyword '{query}' but has high semantic similarity"
                        )
                else:
                    results = filtered_results[:top_k]
                    logger.info(
                        f"✅ Keyword filter: {len(filtered_results)} relevant results "
                        f"(from {len(combined_results)} semantic matches)"
                    )
            else:
                results = combined_results[:top_k]

            # Log results with scores
            if results:
                scores_str = ', '.join([f"{item['name_ar']}: {score:.1f}%" for item, score in results])
                logger.info(f"🔍 FAISS search '{query}' found {len(results)} results: {scores_str}")
            else:
                logger.warning(f"⚠️ FAISS search '{query}' found 0 results (min_score={min_score*100:.0f}%)")

            return results

        except Exception as e:
            logger.error(f"❌ Error searching: {e}")
            return []

    def _boost_exact_matches(
        self,
        query: str,
        results: List[Tuple[dict, float]]
    ) -> List[Tuple[dict, float]]:
        """
        Raise the score of results whose name matches the query exactly.

        An item counts as an exact match when the lowercased, stripped query
        equals its ``name_ar`` or ``name_en``, or equals one whitespace-separated
        word of either name. Matching items gain 20 points, capped at 100.

        Args:
            query: Search query
            results: List of (item, score) results

        Returns:
            A new list of (item, score) tuples, re-sorted by score descending
        """
        needle = query.lower().strip()

        def _is_exact(item: dict) -> bool:
            # Whole-name equality or equality with any single word of the name
            ar = (item.get('name_ar') or '').lower()
            en = (item.get('name_en') or '').lower()
            return needle in (ar, en) or needle in ar.split() or needle in en.split()

        rescored = []
        for item, score in results:
            if not _is_exact(item):
                rescored.append((item, score))
                continue

            boosted_score = min(score + 20.0, 100.0)  # add 20 points, cap at 100
            logger.debug(f"🎯 Exact match boost: '{item.get('name_ar')}' {score:.1f}% → {boosted_score:.1f}%")
            rescored.append((item, boosted_score))

        # Re-rank after boosting
        return sorted(rescored, key=lambda pair: pair[1], reverse=True)

    def _filter_by_keyword(
        self,
        query: str,
        results: List[Tuple[dict, float]]
    ) -> List[Tuple[dict, float]]:
        """
        Filter FAISS results to only include items containing the query keyword.

        This reduces semantic hallucinations by ensuring results actually match
        the user's search term, not just semantically similar items.

        An item is kept when the lowercased, stripped query
        - is a substring of ``name_ar`` or ``name_en`` (this also covers exact
          and whole-word name matches),
        - contains, or is contained in, one of the item's ``aliases`` or
          ``search_keywords``, or
        - equals one of the item's ``tags``.

        Empty, whitespace-only or non-string entries in ``aliases``,
        ``search_keywords`` and ``tags`` are ignored; an empty string would
        otherwise be "contained" in every query and let every item through.

        Args:
            query: Original search query
            results: List of (item, score) tuples from FAISS

        Returns:
            Filtered list of (item, score) tuples, in the input order. A blank
            query filters nothing and returns a copy of ``results``.
        """
        query_lower = query.lower().strip()
        if not query_lower:
            # No keyword to filter on; every name trivially contains ''.
            return list(results)

        def _clean(values) -> List[str]:
            # Lowercased, stripped, non-empty strings from a list-like field
            if not isinstance(values, (list, tuple)):
                return []
            return [
                value.strip().lower()
                for value in values
                if isinstance(value, str) and value.strip()
            ]

        def _overlaps(candidates: List[str]) -> bool:
            # Either side may be the longer phrase ("burger" vs "beef burger")
            return any(
                query_lower in candidate or candidate in query_lower
                for candidate in candidates
            )

        filtered = []
        for item, score in results:
            name_ar = (item.get('name_ar') or '').lower()
            name_en = (item.get('name_en') or '').lower()

            if (
                query_lower in name_ar
                or query_lower in name_en
                or _overlaps(_clean(item.get('aliases')))
                or _overlaps(_clean(item.get('search_keywords')))
                or query_lower in _clean(item.get('tags'))
            ):
                filtered.append((item, score))

        return filtered
    
    def is_loaded(self) -> bool:
        """Return the ``_loaded`` flag, which ``load_menu`` sets to True once the index is built."""
        return self._loaded
    
    async def reload_menu(self, items: List[dict]):
        """
        Rebuild the index from a fresh list of menu items (useful when the menu is updated).

        Logs a reload message and delegates to ``load_menu``. Because
        ``load_menu`` returns early on an empty list, reloading with no items
        leaves the previously loaded index in place.

        Args:
            items: The new list of menu item dicts
        """
        logger.info("🔄 Reloading menu...")
        await self.load_menu(items)
    
    def get_stats(self) -> dict:
        """Get index statistics"""
        if not self._loaded:
            return {"loaded": False}

        return {
            "loaded": True,
            "items_count": len(self.menu_items),
            "dimension": self.embeddings.shape[1] if self.embeddings is not None else 0,
            "memory_mb": (self.embeddings.nbytes / 1024 / 1024) if self.embeddings is not None else 0,
            "model": "paraphrase-multilingual-MiniLM-L12-v2"
        }

    async def search_by_category(self, category: str) -> List[dict]:
        """Search items by category"""
        if not self._loaded:
            return []

        return [
            item for item in self.menu_items
            if item.get('category', '').lower() == category.lower()
        ]

    async def get_all_categories(self) -> List[str]:
        """Get all unique categories"""
        if not self._loaded:
            return []

        categories = set()
        for item in self.menu_items:
            category = item.get('category')
            if category:
                categories.add(category)
        return sorted(list(categories))


# Module-level shared instance. Constructing it runs __init__, which attempts to
# read the training-data JSON; the embedding model and the FAISS index are not
# created until load_menu() is called.
faiss_menu_search = FAISSMenuSearch()

