from fastapi import APIRouter, Depends, HTTPException, Path, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from typing import List
from bson import ObjectId
from datetime import datetime

from app.utility.security import verify_jwt
from app.db.db import get_col
from app.models.productVarient import ProductVariant

# Router for the product-variant endpoints defined in this module; every
# route registered on it is served under the "/product-variants" prefix.
router = APIRouter(prefix="/product-variants", tags=["Product Variants"])
# Bearer-token scheme, injected through Depends() to obtain the request's
# Authorization credentials.
security = HTTPBearer()


# ---------------------------
# ⛔ Role-based Access Control
# ---------------------------
def admin_required(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Dependency that only admits requests whose token carries the admin role.

    Args:
        credentials: Bearer credentials supplied by the module-level
            ``HTTPBearer`` scheme.

    Returns:
        The payload returned by ``verify_jwt`` for the supplied token.

    Raises:
        HTTPException: 401 when ``verify_jwt`` reports failure or yields an
            empty payload; 403 when the payload's ``role`` is not ``"admin"``.
    """
    ok, payload = verify_jwt(credentials.credentials)

    if not ok or not payload:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            # A 401 response must tell the client which auth scheme to use.
            headers={"WWW-Authenticate": "Bearer"},
        )

    if payload.get("role") != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required",
        )
    return payload


# ---------------------------
# 1️⃣ Create Variant (Admin Only)
# ---------------------------
@router.post("/create-variant/{product_id}", status_code=201)
async def create_variant(
    product_id: str,
    variant: ProductVariant,
    user=Depends(admin_required)
):
    """Create a variant document linked to an existing product (admin only).

    Args:
        product_id: String form of the parent product's ObjectId.
        variant: Request body describing the variant.
        user: Payload returned by ``admin_required``; present only to
            enforce the admin check.

    Returns:
        A dict with ``success``, ``message`` and the new ``variant_id``.

    Raises:
        HTTPException: 400 if ``product_id`` is not a valid ObjectId string;
            404 if no product with that id exists.
    """
    # NOTE(review): the collection calls below are not awaited, so they look
    # synchronous; inside ``async def`` that would block the event loop.
    # Confirm which driver ``get_col`` returns.

    # ObjectId() raises on malformed input, which would otherwise surface
    # to the client as a 500 instead of a client error.
    if not ObjectId.is_valid(product_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid product ID",
        )
    product_oid = ObjectId(product_id)

    # Verify product exists
    if not get_col("products").find_one({"_id": product_oid}):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Product not found")

    now = datetime.utcnow()
    new_variant = variant.dict()
    new_variant["product"] = product_oid  # Use "product" field name to match productRouter.py
    new_variant["createdAt"] = now
    new_variant["updatedAt"] = now

    # Set default values if not provided. ``.dict()`` emits unset optional
    # fields as None rather than omitting them, so test for None instead of
    # key absence (``.get`` also covers a genuinely missing key).
    if new_variant.get("total_sold") is None:
        new_variant["total_sold"] = 0
    if new_variant.get("total_stock") is None:
        stock = new_variant.get("stock")
        new_variant["total_stock"] = 0 if stock is None else stock

    result = get_col("product_variants").insert_one(new_variant)

    return {
        "success": True,
        "message": "Variant created",
        "variant_id": str(result.inserted_id)
    }


# ---------------------------
# 2️⃣ Update Variant (Admin Only)
# ---------------------------
@router.put("/update-variant/{variant_id}", status_code=200)
async def update_variant(
    variant_id: str,
    variant: ProductVariant,
    user=Depends(admin_required)
):
    """Apply the non-None fields of ``variant`` to an existing variant (admin only).

    Args:
        variant_id: String form of the variant's ObjectId.
        variant: Request body; fields whose value is None are left untouched
            in the stored document.
        user: Payload returned by ``admin_required``; present only to
            enforce the admin check.

    Returns:
        A dict with ``success`` and ``message``.

    Raises:
        HTTPException: 400 if ``variant_id`` is not a valid ObjectId string;
            404 if no variant matches it.
    """
    # ObjectId() raises on malformed input, which would otherwise surface
    # to the client as a 500 instead of a client error.
    if not ObjectId.is_valid(variant_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid variant ID",
        )

    update_data = {k: v for k, v in variant.dict().items() if v is not None}
    update_data["updatedAt"] = datetime.utcnow()

    result = get_col("product_variants").update_one(
        {"_id": ObjectId(variant_id)},
        {"$set": update_data},
    )

    # matched_count (not modified_count): a no-op update on an existing
    # document still counts as found.
    if result.matched_count == 0:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Variant not found")

    return {"success": True, "message": "Variant updated"}


# ---------------------------
# 3️⃣ Delete Variant (Admin Only)
# ---------------------------
@router.delete("/delete-variant/{variant_id}", status_code=200)
async def delete_variant(
    variant_id: str,
    user=Depends(admin_required)
):
    """Delete a single variant by id (admin only).

    Args:
        variant_id: String form of the variant's ObjectId.
        user: Payload returned by ``admin_required``; present only to
            enforce the admin check.

    Returns:
        A dict with ``success`` and ``message``.

    Raises:
        HTTPException: 400 if ``variant_id`` is not a valid ObjectId string;
            404 if no variant was deleted.
    """
    # ObjectId() raises on malformed input, which would otherwise surface
    # to the client as a 500 instead of a client error.
    if not ObjectId.is_valid(variant_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid variant ID",
        )

    result = get_col("product_variants").delete_one({"_id": ObjectId(variant_id)})

    if result.deleted_count == 0:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Variant not found")

    return {"success": True, "message": "Variant deleted"}


# ---------------------------
# Variant serialization helpers (shared by the read endpoints)
# ---------------------------
def _serialize_datetime(value):
    """Return ``value.isoformat()`` for datetimes; any other value unchanged."""
    if isinstance(value, datetime):
        return value.isoformat()
    return value


def _serialize_variant(doc):
    """Convert a stored variant document into its JSON-friendly response dict.

    ``_id`` and ``product`` are stringified and exposed as ``id`` and
    ``productId``; timestamps become ISO-8601 strings. Absent optional
    fields come back as None, except ``total_sold`` which defaults to 0.
    The layout is meant to match the one used in productRouter.py.
    """
    return {
        "id": str(doc["_id"]),
        "productId": str(doc["product"]),
        "variant_Type": doc.get("variant_Type"),
        "variant_Values": doc.get("variant_Values"),
        "price": doc.get("price"),
        "offer_price": doc.get("offer_price"),
        "offer_percentage": doc.get("offer_percentage"),
        "sku": doc.get("sku"),
        "stock": doc.get("stock"),
        "total_stock": doc.get("total_stock"),
        "total_sold": doc.get("total_sold", 0),
        "variant_url": doc.get("variant_url"),
        "schema_markup": doc.get("schema_markup"),
        "createdAt": _serialize_datetime(doc.get("createdAt")),
        "updatedAt": _serialize_datetime(doc.get("updatedAt")),
    }


# ---------------------------
# 4️⃣ Get All Variants by Product ID
# ---------------------------
@router.get("/product/{product_id}")
async def get_variants_by_product(product_id: str):
    """List every variant whose ``product`` field equals the given product id.

    Args:
        product_id: String form of the parent product's ObjectId.

    Returns:
        A list of serialized variant dicts (empty when nothing matches).

    Raises:
        HTTPException: 400 if ``product_id`` is not a valid ObjectId string.
    """
    # ObjectId() raises on malformed input, which would otherwise surface
    # to the client as a 500 instead of a client error.
    if not ObjectId.is_valid(product_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid product ID",
        )

    # Use "product" field name (the key written by create_variant).
    variants_cursor = get_col("product_variants").find({"product": ObjectId(product_id)})
    return [_serialize_variant(v) for v in variants_cursor]


# ---------------------------
# 5️⃣ Get Single Variant by ID
# ---------------------------
@router.get("/single/{variant_id}")
async def get_one_variant(variant_id: str):
    """Fetch one variant by its id.

    Args:
        variant_id: String form of the variant's ObjectId.

    Returns:
        The serialized variant dict.

    Raises:
        HTTPException: 400 if ``variant_id`` is not a valid ObjectId string;
            404 if no variant matches it.
    """
    # ObjectId() raises on malformed input, which would otherwise surface
    # to the client as a 500 instead of a client error.
    if not ObjectId.is_valid(variant_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid variant ID",
        )

    variant = get_col("product_variants").find_one({"_id": ObjectId(variant_id)})

    if not variant:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Variant not found")

    return _serialize_variant(variant)
