#!/usr/bin/env python3
"""
Rebuild FAISS Index
===================
Rebuilds the FAISS index with fresh data from Supabase.
Run this script weekly or when menu items are updated.

Usage:
    python3 scripts/rebuild_faiss_index.py
"""

import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import asyncio
from supabase import create_client
from app.config import settings
from app.services.faiss_menu_search import FAISSMenuSearch
from loguru import logger


def _enhance_menu_items(menu_items):
    """Return copies of *menu_items* with the category prefixed to each description.

    The input dicts are not modified; each one is shallow-copied first.
    """
    enhanced_items = []
    for item in menu_items:
        enhanced_item = item.copy()

        # Use `or` rather than a .get() default: a row whose category is
        # present but NULL/empty would otherwise put the text "None" into
        # the descriptions. This matches the fallback used when the
        # per-category counts are printed.
        category = item.get('category') or 'أخرى'
        description_ar = item.get('description_ar', '')
        description_en = item.get('description_en', '')

        # Enhance Arabic description
        if description_ar:
            enhanced_item['description_ar'] = f"الفئة: {category}. {description_ar}"
        else:
            enhanced_item['description_ar'] = f"الفئة: {category}"

        # Enhance English description
        if description_en:
            enhanced_item['description_en'] = f"Category: {category}. {description_en}"
        else:
            enhanced_item['description_en'] = f"Category: {category}"

        enhanced_items.append(enhanced_item)

    return enhanced_items


async def rebuild_index():
    """Rebuild the FAISS menu index from Supabase and smoke-test the result.

    The work is done in six stages, each reported on stdout:

    1. Create a Supabase client from ``settings``.
    2. Fetch every ``menu_items`` row whose ``active`` column is true.
    3. Prefix each item's Arabic/English description with its category.
    4. Load the enhanced items into a new ``FAISSMenuSearch`` instance.
    5. Run a fixed list of Arabic test queries against the index.
    6. Time ten repeated searches and report the average latency.

    Returns:
        bool: ``True`` when the index was built, every test query returned
        at least one result, and the timing loop completed. ``False`` when
        any stage failed or any test query came back empty or raised.
        A slow average latency only prints a warning; it does not affect
        the return value.
    """
    # Local imports keep the block self-contained; the file's top-level
    # import list does not include these modules.
    import time
    from collections import Counter

    print("\n" + "="*60)
    print("🔄 Rebuilding FAISS Index")
    print("="*60)

    # Step 1: Connect to Supabase
    print("\n📡 Step 1: Connecting to Supabase...")
    try:
        supabase = create_client(settings.SUPABASE_URL, settings.SUPABASE_KEY)
        print("✅ Connected to Supabase")
    except Exception as e:
        print(f"❌ Failed to connect: {e}")
        return False

    # Step 2: Fetch menu items
    print("\n📥 Step 2: Fetching menu items...")
    try:
        result = supabase.table('menu_items')\
            .select('*')\
            .eq('active', True)\
            .execute()

        if not result.data:
            print("❌ No menu items found")
            return False

        menu_items = result.data
        print(f"✅ Fetched {len(menu_items)} menu items")

        # Show categories (Counter keeps first-seen order, like the dict it replaces)
        categories = Counter(
            item.get('category') or 'أخرى' for item in menu_items
        )

        print("\n📊 Items by category:")
        for cat, count in categories.items():
            print(f"   • {cat}: {count} items")

    except Exception as e:
        print(f"❌ Failed to fetch menu items: {e}")
        return False

    # Step 3: Enhance menu items with category context
    print("\n🔨 Step 3: Enhancing menu items with category context...")
    try:
        enhanced_items = _enhance_menu_items(menu_items)
        print(f"✅ Enhanced {len(enhanced_items)} items with category context")
    except Exception as e:
        print(f"❌ Failed to enhance items: {e}")
        return False

    # Step 4: Build FAISS index
    # NOTE(review): nothing here writes the index to disk explicitly;
    # whether load_menu() persists it is not visible from this script.
    print("\n🔨 Step 4: Building FAISS index...")
    try:
        faiss_search = FAISSMenuSearch()
        await faiss_search.load_menu(enhanced_items)

        stats = faiss_search.get_stats()
        print("✅ FAISS index built successfully")
        print("\n📊 Index statistics:")
        print(f"   • Items: {stats['items_count']}")
        print(f"   • Dimension: {stats['dimension']}")
        print(f"   • Memory: {stats['memory_mb']:.2f} MB")
        print(f"   • Model: {stats['model']}")

    except Exception as e:
        print(f"❌ Failed to build index: {e}")
        return False

    # Step 5: Test search
    print("\n🧪 Step 5: Testing search...")
    test_queries = ['هريس', 'مصلونة', 'شوربة', 'سلطة', 'حمص']

    all_passed = True
    for query in test_queries:
        try:
            results = await faiss_search.search(query, top_k=3, keyword_filter=True)
            if results:
                print(f"   ✅ '{query}': {len(results)} results")
                for item, score in results[:2]:  # Show top 2
                    print(f"      • {item['name_ar']} ({score:.1f}%)")
            else:
                print(f"   ⚠️  '{query}': No results")
                all_passed = False
        except Exception as e:
            print(f"   ❌ '{query}': Error - {e}")
            all_passed = False

    # Step 6: Performance test
    print("\n⚡ Step 6: Performance test...")
    iterations = 10

    # Guarded like every other stage so a failing search is reported and
    # reflected in the return value instead of escaping as a traceback.
    try:
        total_time = 0
        for _ in range(iterations):
            # perf_counter() is monotonic and high-resolution, so a system
            # clock adjustment cannot skew the measured interval.
            start = time.perf_counter()
            await faiss_search.search('هريس', top_k=3)
            total_time += (time.perf_counter() - start) * 1000
    except Exception as e:
        print(f"   ❌ Performance test failed: {e}")
        all_passed = False
    else:
        avg_time = total_time / iterations
        print(f"   Average search time: {avg_time:.2f} ms ({iterations} iterations)")

        if avg_time < 200:
            print("   ✅ Performance is good")
        else:
            print("   ⚠️  Performance is slower than expected")

    # Summary
    print("\n" + "="*60)
    if all_passed:
        print("✅ FAISS index rebuilt successfully!")
    else:
        print("⚠️  FAISS index rebuilt with warnings")
    print("="*60)

    return all_passed


async def main():
    """Run the rebuild and map its outcome to a process exit code.

    Returns:
        int: 0 when ``rebuild_index()`` reports success, otherwise 1.
    """
    if await rebuild_index():
        return 0
    return 1


if __name__ == "__main__":
    # Run the async entry point to completion and exit with the status it returns.
    sys.exit(asyncio.run(main()))

