"""Continuously ranked probability scores with PSIS-LOO-CV weights."""
from collections import namedtuple
import numpy as np
import xarray as xr
from arviz_base import convert_to_datatree, extract
from xarray_einstats.stats import logsumexp
from arviz_stats.loo.helper_loo import (
    _get_r_eff,
    _prepare_loo_inputs,
    _validate_crps_input,
    _warn_pareto_k,
)
from arviz_stats.utils import round_num
[docs]
def loo_score(
    data,
    var_name=None,
    log_weights=None,
    pareto_k=None,
    kind="crps",
    pointwise=False,
    round_to="2g",
):
    r"""Compute PWM-based CRPS/SCRPS with PSIS-LOO-CV weights.

    Implements the probability-weighted-moment (PWM) identity for the continuous ranked
    probability score (CRPS) with Pareto-smoothed importance sampling leave-one-out (PSIS-LOO-CV)
    weights, but returns its negative as a maximization score (larger is better). This assumes
    that the PSIS-LOO-CV approximation is working well.

    Specifically, the PWM identity used here is

    .. math::

        \operatorname{CRPS}_{\text{loo}}(F, y)
        = E_{\text{loo}}\left[|X - y|\right]
        + E_{\text{loo}}[X]
        - 2\cdot E_{\text{loo}} \left[X\,F_{\text{loo}}(X') \right].

    The PWM identity is described in [3]_, traditional CRPS and SCRPS are described in
    [1]_ and [2]_, and the PSIS-LOO-CV method is described in [4]_ and [5]_.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the ``posterior_predictive``, ``observed_data`` and
        ``log_likelihood`` groups.
    var_name : str, optional
        The name of the variable in the log_likelihood group to use. If None, the first
        variable in ``observed_data`` is used and assumed to match ``log_likelihood`` and
        ``posterior_predictive`` names.
    log_weights : DataArray, optional
        Smoothed log weights for PSIS-LOO-CV. Must have the same shape as the log-likelihood data.
        Defaults to None. If not provided, they will be computed via PSIS-LOO-CV. Must be provided
        together with ``pareto_k`` or both must be None.
    pareto_k : DataArray, optional
        Pareto tail indices corresponding to the PSIS smoothing. Same shape as the log-likelihood
        data. If not provided, they will be computed via PSIS-LOO-CV. Must be provided together with
        ``log_weights`` or both must be None.
    kind : str, default "crps"
        The kind of score to compute. Available options are:

        - 'crps': continuous ranked probability score. Default.
        - 'scrps': scale-invariant continuous ranked probability score.
    pointwise : bool, default False
        If True, include per-observation score values in the return object.
    round_to : int or str, default "2g"
        If integer, number of decimal places to round the result. If string of the form ``"2g"``,
        number of significant digits to round the result. Use None to return raw numbers.

    Returns
    -------
    namedtuple
        If ``pointwise`` is False (default), a namedtuple named ``CRPS`` or ``SCRPS`` with fields
        ``mean`` and ``se``. If ``pointwise`` is True, the namedtuple also includes a ``pointwise``
        field with per-observation values.

    Raises
    ------
    ValueError
        If ``kind`` is not ``"crps"`` or ``"scrps"``, or if only one of ``log_weights`` and
        ``pareto_k`` is provided.

    Examples
    --------
    Compute scores and return the mean and standard error:

    .. ipython::
        :okwarning:

        In [1]: from arviz_stats import loo_score
           ...: from arviz_base import load_arviz_data
           ...: dt = load_arviz_data("centered_eight")
           ...: loo_score(dt, kind="crps")

    .. ipython::
        :okwarning:

        In [2]: loo_score(dt, kind="scrps")

    We can also pass previously computed PSIS-LOO weights and return the pointwise values:

    .. ipython::
        :okwarning:

        In [3]: from arviz_stats import loo
           ...: loo_data = loo(dt, pointwise=True)
           ...: loo_score(dt, kind="crps",
           ...:           log_weights=loo_data.log_weights,
           ...:           pareto_k=loo_data.pareto_k,
           ...:           pointwise=True)

    Notes
    -----
    For a single observation with posterior-predictive draws :math:`x_1, \ldots, x_S`
    and PSIS-LOO-CV weights :math:`w_i \propto \exp(\ell_i)` normalized so that
    :math:`\sum_{i=1}^S w_i = 1`, define the PSIS-LOO-CV expectation and the left-continuous
    weighted CDF as

    .. math::

        E_{\text{loo}}[g(X)] := \sum_{i=1}^S w_i\, g(x_i), \quad
        F_{\text{loo}}(x') := \sum_{i: x_i < x'} w_i.

    The first probability-weighted moment is
    :math:`b_1 := E_{\text{loo}}\left[X\,F_{\text{loo}}(X')\right]`.
    With this, the nonnegative CRPS under PSIS-LOO-CV is

    .. math::

        \operatorname{CRPS}_{\text{loo}}(F, y)
        = E_{\text{loo}}\left[\,|X-y|\,\right]
        + E_{\text{loo}}[X] - 2\,b_1.

    For the scale term for the SCRPS, we use the PSIS-LOO-CV weighted Gini mean difference given by
    :math:`\Delta_{\text{loo}} := E_{\text{loo}}\left[\,|X - X'|\,\right]`.
    This admits the PWM representation given by

    .. math::

        \Delta_{\text{loo}} =
        2\,E_{\text{loo}}\left[\,X\,\left(2F_{\text{loo}}(X') - 1\right)\,\right].

    A finite-sample weighted order-statistic version of this is used in the function and is given by

    .. math::

        \Delta_{\text{loo}} =
        2 \sum_{i=1}^S w_{(i)}\, x_{(i)} \left\{\,2 F^-_{(i)} + w_{(i)} - 1\,\right\},

    where :math:`x_{(i)}` are the values sorted increasingly, :math:`w_{(i)}` are the
    corresponding normalized weights, and :math:`F^-_{(i)} = \sum_{j<i} w_{(j)}`.

    The locally scale-invariant score returned for ``kind="scrps"`` is

    .. math::

        S_{\text{SCRPS}}(F, y)
        = -\frac{E_{\text{loo}}\left[\,|X-y|\,\right]}{\Delta_{\text{loo}}}
        - \frac{1}{2}\log \Delta_{\text{loo}}.

    When PSIS weights are highly variable (large Pareto :math:`k`), Monte-Carlo noise can
    increase. This function surfaces PSIS-LOO-CV diagnostics via ``pareto_k`` and warns when
    tail behavior suggests unreliability.

    References
    ----------
    .. [1] Bolin, D., & Wallin, J. (2023). *Local scale invariance and robustness of
       proper scoring rules*. Statistical Science, 38(1), 140–159. https://doi.org/10.1214/22-STS864
       arXiv preprint https://arxiv.org/abs/1912.05642
    .. [2] Gneiting, T., & Raftery, A. E. (2007). *Strictly Proper Scoring Rules,
       Prediction, and Estimation*. Journal of the American Statistical Association,
       102(477), 359–378. https://doi.org/10.1198/016214506000001437
    .. [3] Taillardat M, Mestre O, Zamo M, Naveau P (2016). *Calibrated ensemble forecasts using
       quantile regression forests and ensemble model output statistics*. Mon Weather Rev
       144(6):2375–2393. https://doi.org/10.1175/MWR-D-15-0260.1
    .. [4] Vehtari, A., Gelman, A., & Gabry, J. (2017). *Practical Bayesian model
       evaluation using leave-one-out cross-validation and WAIC*. Statistics and Computing,
       27(5), 1413–1432. https://doi.org/10.1007/s11222-016-9696-4
       arXiv preprint https://arxiv.org/abs/1507.04544
    .. [5] Vehtari, A., et al. (2024). *Pareto Smoothed Importance Sampling*. Journal of
       Machine Learning Research, 25(72). https://jmlr.org/papers/v25/19-556.html
       arXiv preprint https://arxiv.org/abs/1507.02646
    """
    if kind not in {"crps", "scrps"}:
        raise ValueError(f"kind must be either 'crps' or 'scrps'. Got {kind}")

    data = convert_to_datatree(data)
    loo_inputs = _prepare_loo_inputs(data, var_name)
    var_name = loo_inputs.var_name
    log_likelihood = loo_inputs.log_likelihood

    # combined=False keeps the sample dimensions separate so they line up with sample_dims.
    y_pred = extract(data, group="posterior_predictive", var_names=var_name, combined=False)
    y_obs = extract(data, group="observed_data", var_names=var_name, combined=False)

    n_samples = loo_inputs.n_samples
    sample_dims = loo_inputs.sample_dims
    obs_dims = loo_inputs.obs_dims
    r_eff = _get_r_eff(data, n_samples)

    _validate_crps_input(y_pred, y_obs, log_likelihood, sample_dims=sample_dims, obs_dims=obs_dims)

    # The weights and their diagnostics only make sense as a pair.
    if (log_weights is None) != (pareto_k is None):
        raise ValueError(
            "Both log_weights and pareto_k must be provided together or both must be None. "
            "Only one was provided."
        )

    if log_weights is None:
        log_weights_da, pareto_k = log_likelihood.azstats.psislw(r_eff=r_eff, dim=sample_dims)
    else:
        log_weights_da = log_weights

    # E_loo[|X - y|] is shared by both score kinds.
    abs_error = np.abs(y_pred - y_obs)
    loo_weighted_abs_error = _loo_weighted_mean(abs_error, log_weights_da, sample_dims)

    if kind == "crps":
        # E_loo[X] and b1 are only needed for the CRPS, so they are computed only here.
        loo_weighted_mean_prediction = _loo_weighted_mean(y_pred, log_weights_da, sample_dims)
        pwm_first_moment_b1 = _apply_pointwise_weighted_statistic(
            y_pred, log_weights_da, sample_dims, _compute_pwm_first_moment_b1
        )
        # Negate the nonnegative CRPS so that larger values are better.
        pointwise_scores = -(
            loo_weighted_abs_error + loo_weighted_mean_prediction - 2.0 * pwm_first_moment_b1
        )
    else:
        gini_mean_difference = _apply_pointwise_weighted_statistic(
            y_pred, log_weights_da, sample_dims, _compute_weighted_gini_mean_difference
        )
        pointwise_scores = -(loo_weighted_abs_error / gini_mean_difference) - 0.5 * np.log(
            gini_mean_difference
        )

    _warn_pareto_k(pareto_k, n_samples)

    # Summaries over all observations: mean and standard error of the mean (population std).
    n_pts = int(pointwise_scores.size)
    mean = pointwise_scores.mean().values.item()
    se = (pointwise_scores.std(ddof=0).values / (n_pts**0.5)).item()

    name = "SCRPS" if kind == "scrps" else "CRPS"
    fields = ["mean", "se"]
    values = [round_num(mean, round_to), round_num(se, round_to)]
    if pointwise:
        fields.append("pointwise")
        values.append(pointwise_scores)
    return namedtuple(name, fields)(*values)
def _compute_pwm_first_moment_b1(values_sorted, weights):
    """Return the first probability-weighted moment ``sum(w_i * x_i * F^-_i)``.

    Despite its name, ``values_sorted`` does not need to arrive sorted: the
    values are ordered (and the weights normalized to sum to one) here before
    the moment is accumulated.

    Parameters
    ----------
    values_sorted : numpy.ndarray
        One-dimensional array of draws.
    weights : numpy.ndarray
        Weights matching ``values_sorted`` element by element; they may be
        unnormalized.

    Returns
    -------
    float
        The first PWM computed with the left-continuous weighted CDF.
    """
    x, w = _sort_values_and_normalize_weights(values_sorted, weights)
    # Left-continuous CDF: total weight strictly before each ordered draw.
    cdf_before = np.cumsum(w) - w
    b1 = np.sum(w * x * cdf_before)
    return b1.item()
def _compute_weighted_gini_mean_difference(values, weights):
    """Return the weighted Gini mean difference of ``values``.

    Uses the order-statistic form
    ``2 * sum(w_(i) * x_(i) * (2 * F^-_(i) + w_(i) - 1))`` where the draws are
    sorted increasingly, the weights are normalized to sum to one, and
    ``F^-_(i)`` is the total weight strictly before the i-th ordered draw.

    Parameters
    ----------
    values : numpy.ndarray
        One-dimensional array of draws.
    weights : numpy.ndarray
        Weights matching ``values`` element by element; they may be
        unnormalized.

    Returns
    -------
    float
        The weighted Gini mean difference.
    """
    x, w = _sort_values_and_normalize_weights(values, weights)
    weight_before = np.cumsum(w) - w
    order_term = 2.0 * weight_before + w - 1.0
    total = np.sum(w * x * order_term)
    return (2.0 * total).item()
def _loo_weighted_mean(values, log_weights, dim):
    """Compute the PSIS-LOO-CV weighted mean of ``values`` over ``dim``.

    The mean is ``sum(w * values) / sum(w)`` with ``w = exp(log_weights - max)``.
    Shifting by the maximum log weight keeps the exponentials in range, and
    working with the weighted sum directly (rather than exponentiating a
    log-sum-exp of the numerator) keeps the result valid when the weighted sum
    is negative: the logarithm of a negative sum is NaN, which would turn the
    mean of predominantly negative ``values`` into NaN.

    Parameters
    ----------
    values : DataArray
        Quantity to average; it may take negative values.
    log_weights : DataArray
        Log weights; they do not need to be normalized.
    dim : str or sequence of str
        Dimension(s) to reduce over.

    Returns
    -------
    DataArray
        Weighted mean with ``dim`` reduced.
    """
    weights = np.exp(log_weights - log_weights.max(dim=dim))
    # skipna=False so that NaNs in the inputs propagate instead of being silently dropped.
    weighted_sum = (values * weights).sum(dim=dim, skipna=False)
    total_weight = weights.sum(dim=dim, skipna=False)
    return weighted_sum / total_weight
def _apply_pointwise_weighted_statistic(x, log_weights, sample_dims, stat_func):
    """Evaluate ``stat_func(values, weights)`` across the stacked sample dims.

    The log weights are shifted by their maximum over ``sample_dims`` before
    being exponentiated, so the weights handed to ``stat_func`` are
    unnormalized. Both inputs then have ``sample_dims`` stacked into a single
    core dimension that ``stat_func`` consumes.

    Parameters
    ----------
    x : DataArray
        Values to summarize.
    log_weights : DataArray
        Log weights associated with ``x``.
    sample_dims : sequence of str
        Sample dimensions to collapse.
    stat_func : callable
        Function of ``(values, weights)`` reducing the stacked sample
        dimension to a scalar.

    Returns
    -------
    DataArray
        Float result of ``stat_func`` with the sample dimensions removed.
    """
    sample_axis = "__sample__"

    # Subtract the per-observation maximum before exponentiating for numerical stability.
    shifted_log_weights = log_weights - log_weights.max(dim=sample_dims)
    flat_weights = np.exp(shifted_log_weights).stack({sample_axis: sample_dims})
    flat_values = x.stack({sample_axis: sample_dims})

    core_dims = [[sample_axis], [sample_axis]]
    return xr.apply_ufunc(
        stat_func,
        flat_values,
        flat_weights,
        input_core_dims=core_dims,
        output_core_dims=[[]],
        vectorize=True,
        output_dtypes=[float],
    )
def _sort_values_and_normalize_weights(values, weights):
    """Sort values by ascending order and normalize weights."""
    idx = np.argsort(values, kind="mergesort")
    values_sorted = values[idx]
    weights_sorted = weights[idx]
    weights_sorted = weights_sorted / np.sum(weights_sorted)
    return values_sorted, weights_sorted