import base64
import hashlib
import hmac
import json
import os
import time
from typing import Any, Dict, Optional, Tuple
from dotenv import load_dotenv

# Runs at import time, before the environment-backed settings below are read.
# NOTE(review): presumably this pulls variables from a local .env file into
# the process environment (python-dotenv) — confirm against deployment setup.
load_dotenv()

# Signing key used by the JWT helpers when the caller does not pass `secret`.
# NOTE(review): the fallback literal is a placeholder value; confirm that every
# real deployment sets the DEFAULT_JWT_SECRET environment variable.
DEFAULT_JWT_SECRET = os.environ.get("DEFAULT_JWT_SECRET", "dev-secret-change-me")

# Token lifetime, in seconds, applied when create_jwt() receives no exp_seconds.
DEFAULT_JWT_EXP_SECONDS = int(os.environ.get("JWT_EXP_SECONDS", "86400"))

# PBKDF2 iteration count that hash_password() uses for newly created hashes.
PASSWORD_HASH_ITERATIONS = int(os.environ.get("PASSWORD_HASH_ITERATIONS", "310000"))


def _b64url_encode(raw: bytes) -> str:
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def _b64url_decode(b64: str) -> bytes:
    padding = "=" * (-len(b64) % 4)
    return base64.urlsafe_b64decode(b64 + padding)


def hash_password(plain_password: str) -> str:
    """Hash *plain_password* with PBKDF2-HMAC-SHA256 and a fresh random salt.

    Returns:
        A ``$``-separated string of the form
        ``pbkdf2_sha256$<iterations>$<salt>$<hash>``, where salt and hash are
        unpadded URL-safe base64 and ``<iterations>`` is the current
        ``PASSWORD_HASH_ITERATIONS`` value.
    """
    salt_bytes = os.urandom(16)  # new 16-byte salt for every call
    rounds = PASSWORD_HASH_ITERATIONS
    derived_key = hashlib.pbkdf2_hmac(
        "sha256",
        plain_password.encode("utf-8"),
        salt_bytes,
        rounds,
        dklen=32,
    )
    salt_text = _b64url_encode(salt_bytes)
    key_text = _b64url_encode(derived_key)
    return "pbkdf2_sha256${}${}${}".format(rounds, salt_text, key_text)


def verify_password(plain_password: str, stored_hash: str) -> bool:
    """Check *plain_password* against a ``pbkdf2_sha256$...`` hash string.

    The iteration count, salt, and derived-key length are all taken from
    *stored_hash*, and the final comparison is constant-time.

    Returns:
        ``True`` only when the re-derived key equals the stored one. Any
        other outcome — unknown scheme, malformed string, or any exception
        raised while parsing or hashing — yields ``False``.
    """
    try:
        scheme, rounds_field, salt_field, digest_field = stored_hash.split("$", 3)
        if scheme == "pbkdf2_sha256":
            rounds = int(rounds_field)
            salt_bytes = _b64url_decode(salt_field)
            reference = _b64url_decode(digest_field)
            candidate = hashlib.pbkdf2_hmac(
                "sha256",
                plain_password.encode("utf-8"),
                salt_bytes,
                rounds,
                dklen=len(reference),
            )
            return hmac.compare_digest(candidate, reference)
    except Exception:
        # Every failure mode is reported to the caller as "does not match".
        return False
    # Scheme prefix was not the one this module understands.
    return False


def create_jwt(payload: Dict[str, Any], exp_seconds: Optional[int] = None, secret: Optional[str] = None) -> str:
    """Build a compact HS256-signed JWT from *payload*.

    Args:
        payload: Claims to embed. The mapping is copied, and ``iat`` / ``exp``
            are then set on the copy (replacing any values supplied under
            those keys).
        exp_seconds: Lifetime in seconds; ``DEFAULT_JWT_EXP_SECONDS`` is used
            when this is ``None``.
        secret: Signing key; ``DEFAULT_JWT_SECRET`` is used when this is
            ``None`` or empty.

    Returns:
        ``<header>.<payload>.<signature>`` with each part unpadded URL-safe
        base64.
    """
    issued_at = int(time.time())
    if exp_seconds is not None:
        lifetime = exp_seconds
    else:
        lifetime = DEFAULT_JWT_EXP_SECONDS

    body = dict(payload)
    body.update({"iat": issued_at, "exp": issued_at + lifetime})

    key = (secret or DEFAULT_JWT_SECRET).encode("utf-8")

    # Compact separators keep the serialized JSON free of whitespace.
    compact = (",", ":")
    head_part = _b64url_encode(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=compact).encode("utf-8"))
    body_part = _b64url_encode(json.dumps(body, separators=compact).encode("utf-8"))

    unsigned = ".".join((head_part, body_part))
    mac = hmac.new(key, unsigned.encode("ascii"), hashlib.sha256).digest()
    return ".".join((unsigned, _b64url_encode(mac)))


def _decode_jwt_json_segment(segment: str) -> Any:
    """Decode one unpadded base64url JWT segment and parse it as UTF-8 JSON."""
    padded = segment + "=" * (-len(segment) % 4)
    return json.loads(base64.urlsafe_b64decode(padded).decode("utf-8"))


def verify_jwt(token: str, secret: Optional[str] = None) -> Tuple[bool, Optional[Dict[str, Any]]]:
    """Validate an HS256 compact JWT and return its claims.

    Args:
        token: Compact token of the form ``<header>.<payload>.<signature>``.
        secret: HMAC key; ``DEFAULT_JWT_SECRET`` is used when this is ``None``
            or empty.

    Returns:
        ``(True, claims)`` when the token has exactly three segments, carries
        the canonical unpadded base64url HMAC-SHA256 signature for its first
        two segments, declares ``"alg": "HS256"`` in its header, has a JSON
        object as payload, and is not expired. ``(False, None)`` in every
        other case, including any exception raised while parsing.

    Notes:
        A token is treated as expired once the current time is equal to or
        later than its ``exp`` claim. Tokens without ``exp`` never expire.
    """
    try:
        # No maxsplit: a token with extra dots must fail to unpack instead of
        # smuggling the surplus into the signature segment.
        header_b64, payload_b64, sig_b64 = token.split(".")

        signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
        secret_bytes = (secret or DEFAULT_JWT_SECRET).encode("utf-8")
        expected_sig = hmac.new(secret_bytes, signing_input, hashlib.sha256).digest()

        # Compare in *encoded* form. Decoding the presented signature first
        # would silently drop characters outside the base64 alphabet, so many
        # different token strings would verify as the same token.
        expected_b64 = base64.urlsafe_b64encode(expected_sig).rstrip(b"=")
        if not hmac.compare_digest(expected_b64, sig_b64.encode("ascii")):
            return False, None

        # Only the algorithm this function actually implements is acceptable.
        header = _decode_jwt_json_segment(header_b64)
        if not isinstance(header, dict) or header.get("alg") != "HS256":
            return False, None

        # Claims must be a JSON object to honour the declared return type.
        payload = _decode_jwt_json_segment(payload_b64)
        if not isinstance(payload, dict):
            return False, None

        # Reject on or after the expiration instant, not only strictly after.
        if "exp" in payload and int(time.time()) >= int(payload["exp"]):
            return False, None
        return True, payload
    except Exception:
        # Deliberate boundary: any malformed input is simply "not valid".
        return False, None

def get_jwt_expiration(token: str) -> Optional[int]:
    """Read the ``exp`` claim from *token* without checking its signature.

    Returns:
        The ``exp`` claim converted to ``int``, or ``None`` when the claim is
        absent or when the token cannot be split, decoded, or parsed.
    """
    try:
        _header_part, body_part, _sig_part = token.split(".", 2)
        claims = json.loads(_b64url_decode(body_part).decode("utf-8"))
        if "exp" not in claims:
            return None
        return int(claims.get("exp"))
    except Exception:
        return None

