from datetime import datetime
from typing import Any, Dict, Optional

from bson import ObjectId
from fastapi import APIRouter, HTTPException, Path, status
from pydantic import BaseModel, Field, validator

from app.db.db import get_col


# Every cart endpoint below is mounted under the "/cart" prefix and tagged "cart".
router = APIRouter(tags=["cart"], prefix="/cart")


class CartItemCreate(BaseModel):
    # Request body for POST /cart/{user_id}/add/items.
    # NOTE(review): ``price`` is accepted from the client as-is and is the value
    # the cart totals are computed from; confirm it is re-checked server-side.
    productId: str  # must be a valid ObjectId string (see validator below)
    quantity: int = Field(..., gt=0)  # units to add; strictly positive
    variantId: Optional[str] = None  # optional; validated like productId when non-empty
    variantType: Optional[str] = None
    offer_price: Optional[float] = None
    offer_percentage: Optional[float] = None
    variantValue: Optional[str] = None
    price: float = Field(..., ge=0)  # unit price; non-negative

    @validator("productId")
    def _validate_product(cls, value: str) -> str:
        """Accept *value* only if it parses as an ObjectId."""
        if ObjectId.is_valid(value):
            return value
        raise ValueError("Invalid productId")

    @validator("variantId")
    def _validate_variant(cls, value: Optional[str]) -> Optional[str]:
        """Accept a missing/empty variantId, or one that parses as an ObjectId."""
        if not value or ObjectId.is_valid(value):
            return value
        raise ValueError("Invalid variantId")


class CartItemUpdate(BaseModel):
    # Request body for PUT /cart/{user_id}/update/items/{item_id}.
    # Every field is optional; the update handler only applies fields that are
    # supplied (non-None). The offer fields are declared because that handler
    # reads ``body.offer_price`` / ``body.offer_percentage``.
    quantity: Optional[int] = Field(default=None, gt=0)  # new quantity; strictly positive
    price: Optional[float] = Field(default=None, ge=0)  # new unit price; non-negative
    offer_price: Optional[float] = None
    offer_percentage: Optional[float] = None


def _ensure_object_id(identifier: str) -> ObjectId:
    """Parse *identifier* into an ``ObjectId`` or reject the request.

    Args:
        identifier: Id string taken from a path parameter or request body.

    Returns:
        The parsed ``ObjectId``.

    Raises:
        HTTPException: 400 "Invalid identifier" when *identifier* is not a
            valid ObjectId.
    """
    # Validate up front (the same check the request models use) rather than
    # wrapping the constructor in a blanket ``except Exception``. Any malformed
    # user-supplied id reaches this branch, so it is not merely defensive.
    if not ObjectId.is_valid(identifier):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid identifier")
    return ObjectId(identifier)


def _recalculate_totals(cart_doc: Dict[str, Any]) -> Dict[str, Any]:
    total_amount = 0.0
    for item in cart_doc.get("items", []):
        total_amount += float(item.get("price", 0)) * int(item.get("quantity", 0))
    cart_doc["totalAmount"] = round(total_amount, 2)
    platform_fee = float(cart_doc.get("platformFee", 8.0))
    shipping = float(cart_doc.get("shippingCharges", 0))
    cart_doc["finalAmount"] = round(total_amount + platform_fee + shipping, 2)
    cart_doc["updatedAt"] = datetime.utcnow()
    return cart_doc


def _serialize_cart(cart_doc: Dict[str, Any]) -> Dict[str, Any]:
    return {
        "id": str(cart_doc["_id"]),
        "user": str(cart_doc["user"]),
        "items": [
            {
                "id": str(item["_id"]),
                "product": str(item["product"]),
                "quantity": item["quantity"],
                "variantId": str(item["variantId"]) if item.get("variantId") else None,
                "variantType": item.get("variantType"),
                "variantValue": item.get("variantValue"),
                "price": item.get("price", 0),
                "offer_price": item.get("offer_price"),
                "offer_percentage": item.get("offer_percentage"),
                "createdAt": item.get("createdAt"),
                "updatedAt": item.get("updatedAt"),
            }
            for item in cart_doc.get("items", [])
        ],
        "totalAmount": cart_doc.get("totalAmount", 0),
        "platformFee": cart_doc.get("platformFee", 8.0),
        "shippingCharges": cart_doc.get("shippingCharges", 0),
        "finalAmount": cart_doc.get("finalAmount", 0),
        "createdAt": cart_doc.get("createdAt"),
        "updatedAt": cart_doc.get("updatedAt"),
    }


def _get_or_create_cart(user_oid: ObjectId) -> Dict[str, Any]:
    """Return the cart document for *user_oid*, inserting an empty one if none exists.

    A newly inserted cart has no items, zero amounts, ``platformFee`` 8.0,
    ``shippingCharges`` 0.0 and identical ``createdAt``/``updatedAt`` stamps.
    """
    collection = get_col("carts")
    existing = collection.find_one({"user": user_oid})
    if existing:
        return existing

    # NOTE(review): find-then-insert is not atomic; two concurrent first
    # requests for the same user could each insert a cart unless a unique
    # index on ``user`` exists — confirm.
    created_at = datetime.utcnow()
    fresh_cart = {
        "user": user_oid,
        "items": [],
        "totalAmount": 0.0,
        "platformFee": 8.0,
        "shippingCharges": 0.0,
        "finalAmount": 0.0,
        "createdAt": created_at,
        "updatedAt": created_at,
    }
    fresh_cart["_id"] = collection.insert_one(fresh_cart).inserted_id
    return fresh_cart


def _persist_cart(cart_doc: Dict[str, Any]) -> Dict[str, Any]:
    """Write every field of *cart_doc* except ``_id`` via ``$set``, then re-read it.

    Returns:
        The ``find_one`` result for the cart's ``_id`` after the update.
        NOTE(review): this would be ``None`` if the cart were deleted between
        the two calls; callers do not check for that.
    """
    collection = get_col("carts")
    cart_id = cart_doc["_id"]
    fields = {key: value for key, value in cart_doc.items() if key != "_id"}
    collection.update_one({"_id": cart_id}, {"$set": fields})
    return collection.find_one({"_id": cart_id})


@router.get("/{user_id}", status_code=status.HTTP_200_OK)
def get_cart(user_id: str = Path(..., description="User identifier")):
    # Return the user's serialized cart. Reading a cart that does not exist yet
    # creates (and stores) an empty one. 400 on a malformed user_id.
    cart_doc = _get_or_create_cart(_ensure_object_id(user_id))
    return _serialize_cart(cart_doc)


@router.post("/{user_id}/add/items", status_code=status.HTTP_200_OK)
def add_to_cart(user_id: str, body: CartItemCreate):
    # Add ``body.quantity`` units of a product (optionally a specific variant)
    # to the user's cart, creating the cart if needed.
    # - A line with the same product and variantId is merged: its quantity
    #   grows and its price is replaced by ``body.price``.
    # - Otherwise a new line is appended.
    # Totals are then recomputed, the cart is saved, and the stored cart is
    # returned serialized. 400 on malformed ids.
    user_oid = _ensure_object_id(user_id)
    cart = _get_or_create_cart(user_oid)
    product_oid = _ensure_object_id(body.productId)
    variant_oid = _ensure_object_id(body.variantId) if body.variantId else None
    # One timestamp per request so a new line gets createdAt == updatedAt.
    now = datetime.utcnow()
    items = cart.setdefault("items", [])
    for item in items:
        if item["product"] == product_oid and item.get("variantId") == variant_oid:
            # NOTE(review): on merge only quantity/price are refreshed; offer and
            # variant descriptor fields keep their first-added values — confirm.
            item["quantity"] += body.quantity
            item["price"] = body.price
            item["updatedAt"] = now
            break
    else:
        items.append(
            {
                "_id": ObjectId(),
                "product": product_oid,
                "quantity": body.quantity,
                "variantId": variant_oid,
                "variantType": body.variantType,
                "variantValue": body.variantValue,
                "price": body.price,
                "offer_price": body.offer_price,
                "offer_percentage": body.offer_percentage,
                "createdAt": now,
                "updatedAt": now,
            }
        )
    updated = _persist_cart(_recalculate_totals(cart))
    return _serialize_cart(updated)


@router.put("/{user_id}/update/items/{item_id}", status_code=status.HTTP_200_OK)
def update_cart_item(user_id: str, item_id: str, body: CartItemUpdate):
    # Patch one line of the user's cart: each supported field present in
    # ``body`` with a non-None value overwrites the stored one. Totals are then
    # recomputed and the cart is saved. 404 if the cart or the line is missing;
    # 400 on malformed ids.
    user_oid = _ensure_object_id(user_id)
    item_oid = _ensure_object_id(item_id)
    carts = get_col("carts")
    cart = carts.find_one({"user": user_oid})
    if not cart:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cart not found")
    for item in cart.get("items", []):
        if item["_id"] == item_oid:
            # getattr with a default: a field the request model does not
            # declare is simply skipped instead of raising AttributeError (500).
            for field in ("quantity", "price", "offer_price", "offer_percentage"):
                value = getattr(body, field, None)
                if value is not None:
                    item[field] = value
            item["updatedAt"] = datetime.utcnow()
            break
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cart item not found")
    cart = _persist_cart(_recalculate_totals(cart))
    return _serialize_cart(cart)


@router.delete("/{user_id}/delete/items/{item_id}", status_code=status.HTTP_200_OK)
def remove_cart_item(user_id: str, item_id: str):
    # Remove the line with ``_id == item_id`` from the user's cart, recompute
    # totals and save. 404 if the cart or the line is missing; 400 on malformed ids.
    user_oid = _ensure_object_id(user_id)
    item_oid = _ensure_object_id(item_id)
    cart = get_col("carts").find_one({"user": user_oid})
    if not cart:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cart not found")
    current_items = cart.get("items", [])
    kept = [line for line in current_items if line["_id"] != item_oid]
    cart["items"] = kept
    # Nothing was filtered out -> the requested line was not in the cart.
    if len(kept) == len(current_items):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cart item not found")
    return _serialize_cart(_persist_cart(_recalculate_totals(cart)))


@router.delete("/{user_id}", status_code=status.HTTP_200_OK)
def clear_cart(user_id: str):
    # Remove every line from the user's cart and save it. 404 if the cart does
    # not exist; 400 on a malformed user_id.
    # Totals go through the shared ``_recalculate_totals`` (as in the other
    # mutating endpoints) instead of a hand-rolled copy. With no items this gives
    # totalAmount 0.0 and finalAmount = platformFee + shippingCharges, coerced to
    # float and rounded the same way, and it also stamps updatedAt.
    user_oid = _ensure_object_id(user_id)
    carts = get_col("carts")
    cart = carts.find_one({"user": user_oid})
    if not cart:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cart not found")
    cart["items"] = []
    cart = _persist_cart(_recalculate_totals(cart))
    return _serialize_cart(cart)


