"""Custom renderer to wrap all API responses in a standard format."""

from rest_framework.renderers import JSONRenderer
from rest_framework_simplejwt.tokens import (
    AccessToken,
    RefreshToken,
    Token,
    UntypedToken,
)

from apps.core.messages import get_message
from apps.core.middleware import get_current_language
from apps.core.tokens import TOKEN_TYPE, format_token_expiry


def _load_token(raw, token_classes):
    """Return a SimpleJWT token object for *raw*, or ``None`` if none loads.

    If *raw* is already a ``Token`` instance (e.g. a view returned the token
    object rather than its encoded string) it is used directly. Otherwise each
    class in *token_classes* is tried in order; the first that constructs
    successfully wins.
    """
    if isinstance(raw, Token):
        return raw
    for token_class in token_classes:
        try:
            return token_class(raw)
        except Exception:
            # Best-effort: the expiry is optional metadata, so an undecodable
            # token must not break the response. Try the next class instead.
            continue
    return None


def _normalize_jwt_payload(data):
    """Promote SimpleJWT's raw `access` / `refresh` keys to the richer shape.

    Any response data that contains an `access` and/or `refresh` token from
    SimpleJWT is rewritten in place so the public API always returns
    `access_token`, `refresh_token`, `token_type`, `expires_at`, and
    `refresh_expires_at`.

    Only values that can actually be tokens (an encoded ``str`` or a
    ``Token`` instance) are rewritten. Unrelated `access` / `refresh` fields
    holding booleans, numbers, lists, dicts or ``None`` are left untouched.

    This makes /jwt/create/, /jwt/refresh/, and any custom auth endpoint
    return a consistent shape without per-endpoint code.

    Args:
        data: The response payload. Non-dict payloads are returned unchanged.

    Returns:
        The same object that was passed in (mutated when it is a dict).
    """
    if not isinstance(data, dict):
        return data

    raw_access = data.get("access")
    if isinstance(raw_access, (str, Token)):
        del data["access"]
        access = _load_token(raw_access, (AccessToken,))
        # str() yields the encoded JWT for both strings and Token objects.
        data["access_token"] = str(raw_access)
        data["token_type"] = TOKEN_TYPE
        data["expires_at"] = format_token_expiry(access) if access is not None else ""

    raw_refresh = data.get("refresh")
    if isinstance(raw_refresh, (str, Token)):
        del data["refresh"]
        # UntypedToken is a fallback for when RefreshToken construction fails,
        # so an expiry can still be reported when the token decodes.
        refresh = _load_token(raw_refresh, (RefreshToken, UntypedToken))
        data["refresh_token"] = str(raw_refresh)
        data["refresh_expires_at"] = (
            format_token_expiry(refresh) if refresh is not None else ""
        )

    return data


class StandardResponseRenderer(JSONRenderer):
    """Wrap all responses in a standard envelope.

    Success:
        {"status": 1, "code": 200, "message": "...", "data": {...}, "errors": []}

    Error:
        {"status": 0, "code": 4xx, "message": "...", "data": null,
         "errors": [{"field": "...", "message": "..."}]}

    Views can set the message key by including `_message` in their response
    payload (it will be popped before wrapping). Otherwise "success" / "error"
    is used. For dict error payloads without `_message`, the `detail` value
    (or its first item, if it is a list) is used as the message key.

    Nested error structures are flattened into one entry per message. This
    covers dicts inside dicts and lists of dicts/lists. Each entry gets a
    dotted `field` path such as `address.city` or `items.0.name`, instead of
    a Python repr of the nested object.

    SimpleJWT `access` / `refresh` keys are automatically normalized to
    `access_token` / `refresh_token` with `token_type`, `expires_at`, and
    `refresh_expires_at` added.

    Note: the `data` dict passed to `render` is mutated (`_message` popped,
    JWT keys rewritten) rather than copied.
    """

    def render(self, data, accepted_media_type=None, renderer_context=None):
        response = renderer_context.get("response") if renderer_context else None
        request = renderer_context.get("request") if renderer_context else None

        # Without a response in the context there is no status code to build
        # an envelope from, so fall back to plain JSON rendering.
        if response is None:
            return super().render(data, accepted_media_type, renderer_context)

        status_code = response.status_code
        is_success = 200 <= status_code < 400

        # Resolve language: request attr (set by middleware) > thread-local default
        lang = getattr(request, "language", None) if request else None
        if not lang:
            lang = get_current_language()

        if is_success:
            message_key = (
                data.pop("_message", "success") if isinstance(data, dict) else "success"
            )
            data = _normalize_jwt_payload(data)
            wrapped = {
                "status": 1,
                "code": status_code,
                "message": get_message(message_key, lang),
                "data": data,
                "errors": [],
            }
        else:
            message_key, errors = self._parse_error_payload(data)
            wrapped = {
                "status": 0,
                "code": status_code,
                "message": get_message(str(message_key), lang),
                "data": None,
                "errors": errors,
            }

        return super().render(wrapped, accepted_media_type, renderer_context)

    @classmethod
    def _parse_error_payload(cls, data):
        """Return ``(message_key, errors)`` extracted from an error body."""
        if isinstance(data, dict):
            message_key = data.pop("_message", None)
            if not message_key:
                detail = data.get("detail")
                # A list-valued detail would otherwise render as a Python repr.
                if isinstance(detail, (list, tuple)):
                    detail = detail[0] if detail else None
                # Empty/missing detail falls back to the generic key.
                message_key = detail or "error"
            errors = [
                entry
                for key, value in data.items()
                if key != "detail"
                for entry in cls._flatten_errors(value, key)
            ]
            return message_key, errors

        if isinstance(data, (list, tuple)):
            errors = []
            for index, item in enumerate(data):
                if isinstance(item, (dict, list, tuple)):
                    # e.g. top-level many=True serializer errors.
                    errors.extend(cls._flatten_errors(item, str(index)))
                else:
                    errors.append({"message": str(item)})
            return "error", errors

        # Scalar body (typically a plain string). An empty body must not
        # surface as the literal message key "None".
        if data is None or data == "":
            return "error", []
        return str(data), []

    @classmethod
    def _flatten_errors(cls, value, field):
        """Yield ``{"field", "message"}`` entries for a possibly nested error value.

        Leaf messages in a flat list share the parent's ``field``. Nested dict
        keys and list-of-container indexes extend the dotted path.
        """
        if isinstance(value, dict):
            for key, child in value.items():
                yield from cls._flatten_errors(child, f"{field}.{key}")
        elif isinstance(value, (list, tuple)):
            for index, child in enumerate(value):
                if isinstance(child, (dict, list, tuple)):
                    yield from cls._flatten_errors(child, f"{field}.{index}")
                else:
                    yield {"field": field, "message": str(child)}
        else:
            yield {"field": field, "message": str(value)}
