import logging
from datetime import datetime
from typing import Iterable, List, Optional, SupportsFloat, Tuple
from uuid import UUID

from fastapi import HTTPException, status
from sqlalchemy import select, func, desc, asc, and_, or_, nulls_last, nulls_first
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from app.models.violation import Violation
from app.models.vendor import Vendor
from app.models.product import Product
from app.schemas.violation import ViolationCreate, ViolationUpdate


class ViolationService:
    """Query, create, update and delete ``Violation`` rows.

    All methods are static and operate on the ``AsyncSession`` passed in by
    the caller.

    NOTE(review): the read paths fill ``price_difference``,
    ``percentage_difference`` and ``reference_id`` on the loaded ORM objects
    when they are NULL. Those assignments are not committed by the read
    methods themselves, but a later ``commit()`` on the same session (as in
    ``update_violation``) will presumably persist them -- confirm this is
    intended.
    """

    _logger = logging.getLogger(__name__)

    # Values accepted for ``sort_by``; each one is the name of a ``Violation``
    # attribute. Anything else falls back to the default ordering.
    _SORTABLE_FIELDS = frozenset(
        {
            "created_at",
            "product_name",
            "price_difference",
            "percentage_difference",
            "vendor_name",
            "violation_date",
        }
    )

    @staticmethod
    def _compute_price_differences(
        msp: Optional[SupportsFloat],
        scraped_price: Optional[SupportsFloat],
    ) -> Tuple[Optional[float], Optional[float]]:
        """Return ``(price_difference, percentage_difference)`` for a price pair.

        Args:
            msp: Reference price the scraped price is compared against.
            scraped_price: Price observed for the listing.

        Returns:
            ``(msp - scraped_price, difference as a percentage of msp)``, both
            as floats. The percentage is ``0.0`` when ``msp`` is not positive,
            which avoids dividing by zero. ``(None, None)`` is returned when
            either input is ``None`` because no difference can be derived.
        """
        if msp is None or scraped_price is None:
            return None, None

        msp_value = float(msp)
        price_diff = msp_value - float(scraped_price)
        percentage_diff = (price_diff / msp_value) * 100 if msp_value > 0 else 0.0
        return price_diff, percentage_diff

    @staticmethod
    def _calculate_and_set_price_differences(violation: Violation) -> Violation:
        """Fill in the price-difference fields of ``violation`` when one is NULL.

        If either ``price_difference`` or ``percentage_difference`` is ``None``
        both are recomputed from ``msp`` and ``scraped_price``. When the prices
        themselves are missing the fields are left untouched instead of raising.

        Returns:
            The same ``violation`` object, mutated in place.
        """
        if violation.price_difference is None or violation.percentage_difference is None:
            price_diff, percentage_diff = ViolationService._compute_price_differences(
                violation.msp, violation.scraped_price
            )
            if price_diff is not None:
                violation.price_difference = price_diff
                violation.percentage_difference = percentage_diff
        return violation

    @staticmethod
    async def _enrich_reference_ids(
        db: AsyncSession, violations: Iterable[Violation]
    ) -> None:
        """Populate missing ``reference_id`` values using one Product query.

        Only violations whose ``reference_id`` is ``None`` and whose
        ``barcode_number`` is truthy are considered. Their barcodes are looked
        up in a single ``IN`` query rather than one query per violation.

        The lookup is best-effort: if the query fails the error is logged and
        every violation keeps its NULL ``reference_id``.
        """
        pending = [
            violation
            for violation in violations
            if violation.reference_id is None and violation.barcode_number
        ]
        if not pending:
            return

        # De-duplicate while keeping a stable order for the IN list.
        barcodes = list(dict.fromkeys(v.barcode_number for v in pending))
        try:
            product_result = await db.execute(
                select(Product).where(Product.barcode.in_(barcodes))
            )
            products = product_result.scalars().all()
        except Exception:
            # Enrichment is optional; keep serving the violations without it,
            # but leave a trace so the failure is not invisible.
            ViolationService._logger.warning(
                "reference_id lookup failed for %d barcode(s)",
                len(barcodes),
                exc_info=True,
            )
            return

        # First product carrying a reference_id wins for each barcode.
        # NOTE(review): matching is by exact string equality on the barcode;
        # assumes the database comparison is not looser than that -- confirm.
        reference_by_barcode = {}
        for product in products:
            if product.reference_id:
                reference_by_barcode.setdefault(
                    str(product.barcode), str(product.reference_id)
                )

        for violation in pending:
            reference_id = reference_by_barcode.get(str(violation.barcode_number))
            if reference_id is not None:
                violation.reference_id = reference_id

    @staticmethod
    async def _enrich_reference_id(db: AsyncSession, violation: Violation) -> Violation:
        """
        Populate reference_id from Product table if it's NULL but barcode_number exists.
        Uses barcode to lookup the product and get its reference_id.

        Returns:
            The same ``violation`` object; it is unchanged when no matching
            product with a reference_id is found or the lookup fails.
        """
        await ViolationService._enrich_reference_ids(db, [violation])
        return violation

    @staticmethod
    async def get_violations(
        db: AsyncSession,
        page: int = 1,
        limit: int = 10,
        sort_by: str = "violation_date",
        sort_order: str = "desc",
        product_name: Optional[str] = None,
        search: Optional[str] = None,
        vendor_name: Optional[str] = None,
        marketplace: Optional[str] = None,
        url: Optional[str] = None,
        vendor_id: Optional[UUID] = None,
        source_type: Optional[str] = None,
        violation_status: Optional[str] = None,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
    ) -> Tuple[List[Violation], int]:
        """
        Retrieve violations with pagination, sorting, filtering and searching.

        Args:
            db: Database session
            page: Page number (1-indexed); values below 1 are treated as 1
            limit: Items per page; negative values are treated as 0
            sort_by: Field to sort by (violation_date, created_at, product_name,
                price_difference, percentage_difference, vendor_name). Any other
                value sorts by violation_date descending.
            sort_order: "asc" for ascending; anything else sorts descending
            product_name: Partial match on product_name (ignored when `search` is given)
            search: Partial match across product_name, vendor_name, marketplace and url
            vendor_name: Partial match on vendor_name
            marketplace: Partial match on marketplace
            url: Partial match on URL
            vendor_id: Filter by vendor ID
            source_type: Filter by source type (registered, discovered)
            violation_status: "notified" keeps rows with notification_sent_at set,
                "open" keeps rows without it; other values apply no status filter
            date_from: Filter violations from this date (inclusive)
            date_to: Filter violations until this date (inclusive)

        Returns:
            Tuple of (violations list, total count). The count reflects all rows
            matching the filters, not just the returned page.
        """
        # Guard against a negative OFFSET/LIMIT, which the database rejects.
        page = max(page, 1)
        limit = max(limit, 0)
        offset = (page - 1) * limit

        # Build filters
        filters = []

        # `search` is a global search across key fields and takes precedence
        # over the narrower `product_name` term.
        if search:
            like_term = f"%{search}%"
            filters.append(
                or_(
                    Violation.product_name.ilike(like_term),
                    Violation.vendor_name.ilike(like_term),
                    Violation.marketplace.ilike(like_term),
                    Violation.url.ilike(like_term),
                )
            )
        elif product_name:
            filters.append(Violation.product_name.ilike(f"%{product_name}%"))

        partial_matches = (
            (Violation.vendor_name, vendor_name),
            (Violation.marketplace, marketplace),
            (Violation.url, url),
        )
        for column, term in partial_matches:
            if term:
                filters.append(column.ilike(f"%{term}%"))

        # Status is derived from whether a notification timestamp exists.
        if violation_status == "notified":
            filters.append(Violation.notification_sent_at.isnot(None))
        elif violation_status == "open":
            filters.append(Violation.notification_sent_at.is_(None))

        if vendor_id:
            filters.append(Violation.vendor_id == vendor_id)

        if source_type:
            filters.append(Violation.source_type == source_type)

        if date_from:
            filters.append(Violation.violation_date >= date_from)

        if date_to:
            filters.append(Violation.violation_date <= date_to)

        # Sorting: only whitelisted names are resolved to columns; an unknown
        # sort_by keeps the historical fallback of newest violation first.
        if sort_by in ViolationService._SORTABLE_FIELDS:
            order_fn = asc if sort_order == "asc" else desc
            sort_column = getattr(Violation, sort_by)
        else:
            order_fn = desc
            sort_column = Violation.violation_date

        query = select(Violation).options(joinedload(Violation.vendor))
        count_query = select(func.count()).select_from(Violation)
        if filters:
            criteria = and_(*filters)
            query = query.where(criteria)
            count_query = count_query.where(criteria)

        # The id tie-breaker makes the order total, so rows with equal sort
        # keys cannot be duplicated or skipped across pages.
        query = query.order_by(nulls_last(order_fn(sort_column)), Violation.id)

        # Get total count for pagination
        total = await db.scalar(count_query)

        # Apply pagination
        result = await db.execute(query.offset(offset).limit(limit))
        violations = list(result.scalars().unique().all())

        # Enrich violations with calculated price differences and reference_id
        for violation in violations:
            ViolationService._calculate_and_set_price_differences(violation)
        await ViolationService._enrich_reference_ids(db, violations)

        return violations, total or 0

    @staticmethod
    async def get_violation_by_id(db: AsyncSession, violation_id: UUID) -> Violation:
        """Get a single violation by ID, with its vendor loaded.

        Raises:
            HTTPException: 404 when no violation has the given ID.
        """
        result = await db.execute(
            select(Violation)
            .options(joinedload(Violation.vendor))
            .where(Violation.id == violation_id)
        )
        violation = result.scalars().unique().first()

        if not violation:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Violation not found",
            )

        # Enrich with price differences and reference_id
        ViolationService._calculate_and_set_price_differences(violation)
        await ViolationService._enrich_reference_id(db, violation)
        return violation

    @staticmethod
    async def create_violation(db: AsyncSession, violation_in: ViolationCreate) -> Violation:
        """Create a new violation record and return the refreshed row.

        The price differences are derived from ``msp`` and ``scraped_price``.
        If the input schema itself carries non-null ``price_difference`` /
        ``percentage_difference`` values they are kept as given.
        """
        price_diff, percentage_diff = ViolationService._compute_price_differences(
            violation_in.msp, violation_in.scraped_price
        )

        # Merge into the payload rather than passing extra keyword arguments,
        # so a schema that already declares these fields cannot cause a
        # duplicate-keyword TypeError.
        payload = violation_in.model_dump()
        if payload.get("price_difference") is None:
            payload["price_difference"] = price_diff
        if payload.get("percentage_difference") is None:
            payload["percentage_difference"] = percentage_diff

        db_violation = Violation(**payload)
        db.add(db_violation)
        await db.commit()
        await db.refresh(db_violation)

        # Enrich with reference_id from Product if available. This happens
        # after the commit, so the value is set on the returned object only.
        await ViolationService._enrich_reference_id(db, db_violation)

        return db_violation

    @staticmethod
    async def update_violation(
        db: AsyncSession, violation_id: UUID, violation_in: ViolationUpdate
    ) -> Violation:
        """Update an existing violation with the fields set on ``violation_in``.

        Raises:
            HTTPException: 404 when no violation has the given ID.
        """
        violation = await ViolationService.get_violation_by_id(db, violation_id)

        update_data = violation_in.model_dump(exclude_unset=True)

        # Recalculate price differences if price fields are updated, using the
        # stored value for whichever price is not part of this update.
        if "msp" in update_data or "scraped_price" in update_data:
            price_diff, percentage_diff = ViolationService._compute_price_differences(
                update_data.get("msp", violation.msp),
                update_data.get("scraped_price", violation.scraped_price),
            )
            update_data["price_difference"] = price_diff
            update_data["percentage_difference"] = percentage_diff

        for field, value in update_data.items():
            setattr(violation, field, value)

        db.add(violation)
        await db.commit()
        await db.refresh(violation)
        return violation

    @staticmethod
    async def delete_violation(db: AsyncSession, violation_id: UUID) -> None:
        """Delete a violation record.

        Raises:
            HTTPException: 404 when no violation has the given ID.
        """
        violation = await ViolationService.get_violation_by_id(db, violation_id)
        await db.delete(violation)
        await db.commit()

    @staticmethod
    async def get_violations_by_vendor(
        db: AsyncSession, vendor_id: UUID, page: int = 1, limit: int = 10
    ) -> Tuple[List[Violation], int]:
        """Get one page of violations for a specific vendor, default ordering."""
        return await ViolationService.get_violations(
            db, page=page, limit=limit, vendor_id=vendor_id
        )

    @staticmethod
    async def get_statistics(db: AsyncSession) -> dict:
        """Get violation statistics.

        Returns:
            Dict with the total count, the counts for the "registered" and
            "discovered" source types, and the average price / percentage
            differences. Counts default to 0 and averages to 0.0 when the
            database returns NULL.
        """
        total = await db.scalar(select(func.count(Violation.id)))

        # Count by source type
        registered = await db.scalar(
            select(func.count(Violation.id)).where(Violation.source_type == "registered")
        )
        discovered = await db.scalar(
            select(func.count(Violation.id)).where(Violation.source_type == "discovered")
        )

        # Averages are NULL when there are no non-NULL values to average.
        avg_price_diff = await db.scalar(
            select(func.avg(Violation.price_difference))
        )
        avg_percentage_diff = await db.scalar(
            select(func.avg(Violation.percentage_difference))
        )

        return {
            "total_violations": total or 0,
            "registered_violations": registered or 0,
            "discovered_violations": discovered or 0,
            "average_price_difference": (
                float(avg_price_diff) if avg_price_diff is not None else 0.0
            ),
            "average_percentage_difference": (
                float(avg_percentage_diff) if avg_percentage_diff is not None else 0.0
            ),
        }


