from typing import List, Optional, Tuple
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import HTTPException, status

from app.models.product import Product
from app.models.violation import Violation
from app.models.scraping_result import ScrapingResult
from app.schemas.product import ProductCreate, ProductUpdate


class ProductService:
    """Async query/CRUD helpers for ``Product`` rows.

    Every method is a ``@staticmethod`` that receives the caller's
    ``AsyncSession``. Write methods commit on that session. Missing products
    and duplicate barcodes are reported as ``HTTPException`` so route handlers
    can let them propagate.
    """

    @staticmethod
    async def _count_violations_for_name(
        db: AsyncSession, product_name: Optional[str]
    ) -> int:
        """Return the number of ``Violation`` rows whose ``product_name`` equals *product_name*."""
        count_result = await db.execute(
            select(func.count(Violation.id)).where(
                Violation.product_name == product_name
            )
        )
        return count_result.scalar() or 0

    @staticmethod
    async def _get_product_or_404(db: AsyncSession, product_id: int) -> Product:
        """Load a product by id, or raise ``HTTPException(404)`` if it does not exist."""
        result = await db.execute(select(Product).where(Product.id == product_id))
        db_product = result.scalars().first()
        if db_product is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Product not found",
            )
        return db_product

    @staticmethod
    async def get_violation_count(db: AsyncSession, product_id: int) -> int:
        """Count violations for a specific product.

        Violations are matched by name (``Violation.product_name`` equal to the
        product's ``product_name``), not by foreign key.

        Returns:
            The violation count, or 0 when no product has id *product_id*.
        """
        # Only the name is needed, so avoid loading the full entity.
        result = await db.execute(
            select(Product.product_name).where(Product.id == product_id)
        )
        row = result.first()
        if row is None:
            return 0
        return await ProductService._count_violations_for_name(db, row.product_name)

    @staticmethod
    async def enrich_product_with_violations(db: AsyncSession, product: Product) -> Product:
        """Set ``product.violation_count`` from violations matching the product name.

        The same object is mutated and returned.
        """
        product.violation_count = await ProductService._count_violations_for_name(  # type: ignore
            db, product.product_name
        )
        return product

    @staticmethod
    async def get_products(
        db: AsyncSession,
        page: int = 1,
        limit: int = 10,
        sort_by: str = "product_name",
        search: Optional[str] = None,
    ) -> Tuple[List[Product], int]:
        """Return one page of products plus the total number of matching products.

        Args:
            db: Session used for all queries.
            page: 1-based page number; values below 1 are treated as page 1.
            limit: Page size.
            sort_by: ``"msp"``, ``"product_name"`` or ``"last_scraped_date"``;
                any other value falls back to ``"product_name"``.
            search: Optional case-insensitive substring filter on
                ``product_name``. ``%`` and ``_`` are matched literally.

        Returns:
            ``(products, total)``. Each product has ``violation_count`` set.
        """
        # Clamp so page <= 0 does not produce a negative OFFSET (a DB error).
        offset = max(page - 1, 0) * limit
        query = select(Product)

        if search:
            # Escape LIKE metacharacters so user input is matched literally.
            escaped = (
                search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            )
            query = query.where(
                Product.product_name.ilike(f"%{escaped}%", escape="\\")
            )

        # Count before ordering/pagination; ORDER BY is irrelevant to a count.
        count_query = select(func.count()).select_from(query.subquery())
        total = await db.scalar(count_query) or 0

        sort_orders = {
            "msp": Product.msp,
            "product_name": Product.product_name,
            # Placeholder: sorting by updated_at until last_scraped_date exists.
            "last_scraped_date": Product.updated_at.desc(),
        }
        primary_order = sort_orders.get(sort_by, Product.product_name)

        # Compute violation counts in the same round trip instead of one query
        # per product (N+1). NOTE(review): with `==`, a product whose
        # product_name is NULL gets 0 here — confirm product_name is non-null.
        violation_count = (
            select(func.count(Violation.id))
            .where(Violation.product_name == Product.product_name)
            .correlate(Product)
            .scalar_subquery()
        )
        page_query = (
            query.add_columns(violation_count)
            # Product.id tiebreaker keeps page boundaries stable when the
            # primary sort key has duplicate values.
            .order_by(primary_order, Product.id)
            .offset(offset)
            .limit(limit)
        )
        result = await db.execute(page_query)

        enriched_products: List[Product] = []
        for product, count in result.all():
            product.violation_count = count or 0  # type: ignore
            enriched_products.append(product)

        return enriched_products, total

    @staticmethod
    async def create_product(db: AsyncSession, product_in: ProductCreate) -> Product:
        """Create and commit a new product.

        Raises:
            HTTPException: 400 if another product already has this barcode.
        """
        # NOTE(review): check-then-insert is not atomic; concurrent creates can
        # still race unless the DB enforces a unique constraint on barcode.
        existing_product = await db.execute(
            select(Product.id).where(Product.barcode == product_in.barcode)
        )
        if existing_product.first() is not None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Barcode already exists",
            )

        db_product = Product(**product_in.model_dump())
        db.add(db_product)
        await db.commit()
        await db.refresh(db_product)
        return db_product

    @staticmethod
    async def update_product(
        db: AsyncSession, product_id: int, product_in: ProductUpdate
    ) -> Product:
        """Apply the explicitly-set fields of *product_in* to a product and commit.

        Raises:
            HTTPException: 404 if the product does not exist; 400 if the update
                would give it a barcode already used by another product.
        """
        db_product = await ProductService._get_product_or_404(db, product_id)

        # Only fields the client actually sent are applied (partial update).
        update_data = product_in.model_dump(exclude_unset=True)

        # Enforce the same barcode uniqueness check as create_product.
        new_barcode = update_data.get("barcode")
        if new_barcode is not None and new_barcode != db_product.barcode:
            clash = await db.execute(
                select(Product.id).where(
                    Product.barcode == new_barcode, Product.id != product_id
                )
            )
            if clash.first() is not None:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Barcode already exists",
                )

        for field, value in update_data.items():
            setattr(db_product, field, value)

        await db.commit()
        await db.refresh(db_product)
        return db_product

    @staticmethod
    async def delete_product(db: AsyncSession, product_id: int) -> None:
        """Delete a product and commit.

        Raises:
            HTTPException: 404 if the product does not exist.
        """
        db_product = await ProductService._get_product_or_404(db, product_id)
        await db.delete(db_product)
        await db.commit()

    @staticmethod
    async def get_product_by_id(db: AsyncSession, product_id: int) -> Product:
        """Return a product with ``violation_count`` populated.

        Raises:
            HTTPException: 404 if the product does not exist.
        """
        db_product = await ProductService._get_product_or_404(db, product_id)
        return await ProductService.enrich_product_with_violations(db, db_product)
