import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Tuple
from sqlalchemy import select, func, and_, desc, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from app.models.violation import Violation
from app.models.product import Product
from app.models.vendor import Vendor
from app.schemas.dashboard import (
    DashboardResponse,
    KPISummary,
    ActiveViolationsKPI,
    MonitoredProductsKPI,
    AveragePriceDiscountKPI,
    ViolationTrendChart,
    TrendDataPoint,
    ViolationAnalysis,
    MarketplaceAnalysis,
    VendorAnalysis,
    ProductAnalysis,
    HighestDiscountAnalysis,
    FinancialImpactAnalysis,
    RecentViolationsPanel,
    RecentViolationItem,
)

logger = logging.getLogger(__name__)


class DashboardService:
    """Service for dashboard data aggregation and analytics."""

    @staticmethod
    async def get_active_violations_kpi(db: AsyncSession) -> ActiveViolationsKPI:
        """
        Get KPI data for active violations.

        A violation counts as "active" while its notification_sent_at is NULL.

        Args:
            db: Async database session used to run the aggregate queries.

        Returns:
            ActiveViolationsKPI with:
            - total_active_violations: Count of violations without notification sent
            - total_detections: Sum of confirmation_count over those violations
            - offending_vendors_count: Distinct vendor_id values, plus distinct
              vendor_name values of "discovered" violations that have no vendor_id
            - affected_products_count: Distinct product_name values
            - status: "critical" (>= 10 active), "active" (>= 5) or "low"
        """
        is_active = Violation.notification_sent_at.is_(None)

        # Every aggregate that shares the plain "active" filter is computed in a
        # single statement: one round-trip, and all figures describe the same
        # snapshot of the table. COUNT(DISTINCT col) ignores NULLs, so the
        # vendor_id / product_name counts need no extra IS NOT NULL predicate.
        totals_result = await db.execute(
            select(
                func.count(Violation.id),
                func.sum(Violation.confirmation_count),
                func.count(func.distinct(Violation.vendor_id)),
                func.count(func.distinct(Violation.product_name)),
            ).where(is_active)
        )
        # An aggregate SELECT without GROUP BY always yields exactly one row.
        (
            total_active,
            total_detections,
            offending_vendors,
            affected_products,
        ) = totals_result.one()

        total_active = total_active or 0
        # SUM() is NULL when no row matches.
        total_detections = total_detections or 0
        offending_vendors = offending_vendors or 0
        affected_products = affected_products or 0

        # Discovered vendors have no vendor_id, so they are identified by name.
        # This needs its own filter and therefore its own statement.
        discovered_vendors_result = await db.execute(
            select(func.count(func.distinct(Violation.vendor_name)))
            .where(
                and_(
                    is_active,
                    Violation.vendor_name.isnot(None),
                    Violation.vendor_id.is_(None),
                    Violation.source_type == "discovered",
                )
            )
        )
        discovered_vendors = discovered_vendors_result.scalar() or 0
        total_offending_vendors = offending_vendors + discovered_vendors

        # Determine status based on violation count
        if total_active >= 10:
            status = "critical"
        elif total_active >= 5:
            status = "active"
        else:
            status = "low"

        return ActiveViolationsKPI(
            total_active_violations=total_active,
            total_detections=total_detections,
            offending_vendors_count=total_offending_vendors,
            affected_products_count=affected_products,
            status=status,
        )

    @staticmethod
    async def get_monitored_products_kpi(db: AsyncSession) -> MonitoredProductsKPI:
        """
        Build the "monitored products" KPI card.

        Counts the Product rows whose status flag is TRUE.

        Args:
            db: Async database session used to run the count query.

        Returns:
            MonitoredProductsKPI carrying the number of such products
            (0 when the query yields no value)
        """
        monitored_stmt = select(func.count(Product.id)).where(Product.status.is_(True))
        monitored_count = (await db.execute(monitored_stmt)).scalar() or 0
        return MonitoredProductsKPI(total_monitored_products=monitored_count)

    @staticmethod
    async def get_average_discount_kpi(db: AsyncSession) -> AveragePriceDiscountKPI:
        """
        Get KPI data for average discount percentage.

        Calculation: sum of percentage_difference divided by the number of
        active violations (notification_sent_at IS NULL) that have a non-null
        percentage_difference. Violations without a percentage are ignored in
        both the numerator and the denominator.

        Args:
            db: Async database session used to run the aggregate query.

        Returns:
            AveragePriceDiscountKPI with the average rounded to 2 decimals,
            or 0.0 when no violation qualifies
        """
        # Sum and count come from one statement so they are always computed
        # over the same set of rows (two separate queries could disagree if
        # violations change in between) and only one round-trip is needed.
        result = await db.execute(
            select(
                func.sum(Violation.percentage_difference),
                func.count(Violation.id),
            ).where(
                and_(
                    Violation.notification_sent_at.is_(None),
                    Violation.percentage_difference.isnot(None),
                )
            )
        )
        # An aggregate SELECT without GROUP BY always yields exactly one row.
        sum_percentage, count = result.one()
        sum_percentage = sum_percentage or 0  # SUM() is NULL when nothing matches
        count = count or 0

        # Guard against division by zero when there are no qualifying rows.
        average_discount = (sum_percentage / count) if count > 0 else 0.0

        return AveragePriceDiscountKPI(average_discount_percentage=round(average_discount, 2))

    @staticmethod
    async def get_kpi_summary(db: AsyncSession) -> KPISummary:
        """
        Assemble the KPI summary section of the dashboard.

        The three KPI cards are computed one after another on the given
        session, in the order: active violations, monitored products,
        average discount.

        Args:
            db: Async database session passed through to each KPI query.

        Returns:
            KPISummary holding all three KPI cards
        """
        return KPISummary(
            active_violations=await DashboardService.get_active_violations_kpi(db),
            monitored_products=await DashboardService.get_monitored_products_kpi(db),
            average_discount=await DashboardService.get_average_discount_kpi(db),
        )

    @staticmethod
    async def get_30day_violation_trend(db: AsyncSession) -> ViolationTrendChart:
        """
        Get 30-day violation trend data.

        The window covers today (UTC) and the 29 days before it. Days with no
        violations are reported with a count of 0, so the chart always has
        exactly 30 points in ascending date order.

        Args:
            db: Async database session used to run the grouped count query.

        Returns:
            ViolationTrendChart with daily violation counts for last 30 days
        """
        window_days = 30

        # Read the clock exactly once. Re-evaluating "now" for the query bound
        # and for every generated data point could straddle midnight and
        # produce a duplicated or missing day.
        today = datetime.utcnow().date()
        first_day = today - timedelta(days=window_days - 1)

        # Start at midnight of the first displayed day so the query returns
        # exactly the days that are charted (no partial extra day).
        window_start = datetime.combine(first_day, datetime.min.time())

        violation_day = func.date(Violation.violation_date)
        result = await db.execute(
            select(
                violation_day.label("date"),
                func.count(Violation.id).label("count"),
            )
            .where(Violation.violation_date >= window_start)
            .group_by(violation_day)
            .order_by(violation_day)
        )

        # Key by the string form of the day; str() is applied to whatever the
        # database driver returns for DATE(...) so lookups below match on
        # the ISO "YYYY-MM-DD" text.
        counts_by_day = {str(day): count for day, count in result.all()}

        # Emit one point per calendar day, filling gaps with zero.
        trend_data = []
        for offset in range(window_days):
            day_str = str(first_day + timedelta(days=offset))
            trend_data.append(
                TrendDataPoint(
                    date=day_str,
                    violation_count=counts_by_day.get(day_str, 0),
                )
            )

        return ViolationTrendChart(trend_data=trend_data)

    @staticmethod
    async def get_marketplace_with_most_violations(
        db: AsyncSession,
    ) -> Optional[MarketplaceAnalysis]:
        """
        Find the marketplace that currently has the most active violations.

        Only violations with no notification sent and a non-null marketplace
        are considered; they are grouped by marketplace and the largest group
        is picked.

        Args:
            db: Async database session used to run the grouped query.

        Returns:
            MarketplaceAnalysis with the marketplace name ("Unknown" if the
            name is empty) and its violation count, or None when no row matches
        """
        top_marketplace_stmt = (
            select(
                Violation.marketplace,
                func.count(Violation.id).label("violation_count"),
            )
            .where(
                Violation.notification_sent_at.is_(None),
                Violation.marketplace.isnot(None),
            )
            .group_by(Violation.marketplace)
            .order_by(desc("violation_count"))
            .limit(1)
        )
        top_row = (await db.execute(top_marketplace_stmt)).first()

        # Nothing to report when there are no active violations at all.
        if not top_row:
            return None

        marketplace, violations = top_row
        return MarketplaceAnalysis(
            marketplace_name=marketplace or "Unknown",
            violation_count=violations,
        )

    @staticmethod
    async def get_most_offending_vendor(db: AsyncSession) -> Optional[VendorAnalysis]:
        """
        Find the vendor with the most active violations.

        Registered vendors take precedence: the top registered vendor (joined
        through Violation.vendor_id, source_type "registered") is returned
        whenever one exists. Discovered vendors (identified only by
        Violation.vendor_name, source_type "discovered") are consulted only
        when no registered vendor has an active violation.

        Args:
            db: Async database session used to run the grouped queries.

        Returns:
            VendorAnalysis with vendor name and violation count, or None when
            neither query finds a vendor
        """
        not_notified = Violation.notification_sent_at.is_(None)

        # First choice: registered vendors, named via the Vendor table.
        registered_stmt = (
            select(
                Vendor.name,
                func.count(Violation.id).label("violation_count"),
            )
            .join(Violation, Violation.vendor_id == Vendor.id)
            .where(
                not_notified,
                Violation.source_type == "registered",
            )
            .group_by(Vendor.name)
            .order_by(desc("violation_count"))
            .limit(1)
        )
        registered_top = (await db.execute(registered_stmt)).first()
        if registered_top:
            registered_name, registered_count = registered_top
            return VendorAnalysis(
                vendor_name=registered_name,
                violation_count=registered_count,
            )

        # Fallback: discovered vendors, named by the text stored on the violation.
        discovered_stmt = (
            select(
                Violation.vendor_name,
                func.count(Violation.id).label("violation_count"),
            )
            .where(
                not_notified,
                Violation.vendor_name.isnot(None),
                Violation.source_type == "discovered",
            )
            .group_by(Violation.vendor_name)
            .order_by(desc("violation_count"))
            .limit(1)
        )
        discovered_top = (await db.execute(discovered_stmt)).first()
        if not discovered_top:
            return None

        discovered_name, discovered_count = discovered_top
        return VendorAnalysis(
            vendor_name=discovered_name or "Unknown",
            violation_count=discovered_count,
        )

    @staticmethod
    async def get_most_affected_product(db: AsyncSession) -> Optional[ProductAnalysis]:
        """
        Get product with most active violations.

        Active violations (notification_sent_at IS NULL) are grouped by
        product_name and the largest group wins. Violations whose
        product_name is NULL are excluded so that an unnamed group can never
        be reported as the "most affected product".

        Args:
            db: Async database session used to run the grouped query.

        Returns:
            ProductAnalysis with product name and violation count, or None
            when there is no active violation with a product name
        """
        result = await db.execute(
            select(
                Violation.product_name,
                func.count(Violation.id).label("violation_count"),
            )
            .where(
                and_(
                    Violation.notification_sent_at.is_(None),
                    # Without this, all NULL names collapse into one group that
                    # could rank first and yield product_name=None.
                    Violation.product_name.isnot(None),
                )
            )
            .group_by(Violation.product_name)
            .order_by(desc("violation_count"))
            .limit(1)
        )

        row = result.first()
        if row:
            return ProductAnalysis(
                product_name=row[0],
                violation_count=row[1],
            )
        return None

    @staticmethod
    async def get_highest_discount_detected(
        db: AsyncSession,
    ) -> Optional[HighestDiscountAnalysis]:
        """
        Find the active violation with the largest percentage_difference.

        Only violations with no notification sent and a non-null
        percentage_difference are considered.

        Args:
            db: Async database session used to run the query.

        Returns:
            HighestDiscountAnalysis with the product name, the percentage
            (rounded to 2 decimals), the MSP and the scraped price, or None
            when no violation qualifies
        """
        highest_stmt = (
            select(
                Violation.product_name,
                Violation.percentage_difference,
                Violation.msp,
                Violation.scraped_price,
            )
            .where(
                Violation.notification_sent_at.is_(None),
                Violation.percentage_difference.isnot(None),
            )
            .order_by(desc(Violation.percentage_difference))
            .limit(1)
        )
        worst = (await db.execute(highest_stmt)).first()
        if not worst:
            return None

        product_name, percentage, msp, scraped_price = worst
        return HighestDiscountAnalysis(
            product_name=product_name,
            highest_discount_percentage=round(float(percentage), 2),
            msp=float(msp),
            scraped_price=float(scraped_price),
        )

    @staticmethod
    async def get_financial_impact(db: AsyncSession) -> FinancialImpactAnalysis:
        """
        Estimate the financial impact of active violations.

        Sums Violation.price_difference over violations that have no
        notification sent and a non-null price_difference, and counts those
        same violations.

        Args:
            db: Async database session used to run the aggregate query.

        Returns:
            FinancialImpactAnalysis with the total (rounded to 2 decimals,
            0.0 when nothing matches), currency "USD" and the violation count
        """
        impact_stmt = select(
            func.sum(Violation.price_difference),
            func.count(Violation.id),
        ).where(
            Violation.notification_sent_at.is_(None),
            Violation.price_difference.isnot(None),
        )
        summed_difference, matched_rows = (await db.execute(impact_stmt)).first()

        # SUM() is NULL when no row matches; report that as zero impact.
        if summed_difference:
            impact = float(summed_difference)
        else:
            impact = 0.0

        return FinancialImpactAnalysis(
            total_estimated_impact=round(impact, 2),
            currency="USD",
            violation_count=matched_rows or 0,
        )

    @staticmethod
    async def get_violation_analysis(db: AsyncSession) -> ViolationAnalysis:
        """
        Assemble the violation analysis section of the dashboard.

        Each insight is computed in turn on the given session: top
        marketplace, top vendor, top product, highest discount, and
        financial impact.

        Args:
            db: Async database session passed through to each analysis query.

        Returns:
            ViolationAnalysis with all analysis insights; the first four may
            be None when there is no matching data
        """
        return ViolationAnalysis(
            marketplace_with_most_violations=(
                await DashboardService.get_marketplace_with_most_violations(db)
            ),
            most_offending_vendor=await DashboardService.get_most_offending_vendor(db),
            most_affected_product=await DashboardService.get_most_affected_product(db),
            highest_discount_detected=(
                await DashboardService.get_highest_discount_detected(db)
            ),
            financial_impact=await DashboardService.get_financial_impact(db),
        )

    @staticmethod
    async def get_recent_violations(
        db: AsyncSession, limit: int = 10
    ) -> RecentViolationsPanel:
        """
        Get recent violations panel.

        Lists the most recent active violations (notification_sent_at IS
        NULL), newest violation_date first, together with the total number of
        active violations.

        Args:
            db: Database session
            limit: Number of recent violations to return (default: 10)

        Returns:
            RecentViolationsPanel with recent violations and total count
        """

        def _optional_float(value):
            # Compare against None explicitly: a difference of exactly 0 is a
            # real value and must not be reported as "missing".
            return float(value) if value is not None else None

        # Get recent violations; the vendor relationship is loaded eagerly
        # because it is read below as a fallback for the vendor name.
        result = await db.execute(
            select(Violation)
            .options(joinedload(Violation.vendor))
            .where(Violation.notification_sent_at.is_(None))
            .order_by(desc(Violation.violation_date))
            .limit(limit)
        )
        violations = result.scalars().unique().all()

        # Convert to response items
        recent_items = [
            RecentViolationItem(
                id=violation.id,
                product_name=violation.product_name,
                marketplace=violation.marketplace,
                # Prefer the name stored on the violation, else the linked vendor's.
                vendor_name=violation.vendor_name
                or (violation.vendor.name if violation.vendor else None),
                current_price=float(violation.scraped_price),
                target_msp=float(violation.msp),
                price_difference=_optional_float(violation.price_difference),
                percentage_difference=_optional_float(violation.percentage_difference),
                status="open" if not violation.notification_sent_at else "notified",
                violation_date=violation.violation_date,
                url=violation.url,
            )
            for violation in violations
        ]

        # Get total violations count (all active violations, not just the page above)
        total_result = await db.execute(
            select(func.count(Violation.id)).where(Violation.notification_sent_at.is_(None))
        )
        total_count = total_result.scalar() or 0

        return RecentViolationsPanel(
            recent_violations=recent_items,
            total_violations=total_count,
        )

    @staticmethod
    async def get_complete_dashboard(db: AsyncSession) -> DashboardResponse:
        """
        Build the full dashboard payload.

        Computes, in order, the KPI summary, the 30-day violation trend, the
        violation analysis and the 10 most recent violations, then stamps the
        response with the current UTC time.

        Args:
            db: Async database session passed through to every section.

        Returns:
            DashboardResponse with all dashboard data
        """
        return DashboardResponse(
            kpi_summary=await DashboardService.get_kpi_summary(db),
            violation_trend=await DashboardService.get_30day_violation_trend(db),
            violation_analysis=await DashboardService.get_violation_analysis(db),
            recent_violations=await DashboardService.get_recent_violations(db, limit=10),
            # Evaluated last, i.e. after all sections have been computed.
            last_refreshed_at=datetime.utcnow(),
        )
