Source code for arviz_stats.loo.loo_pit
"""Compute leave one out (PSIS-LOO) probability integral transform (PIT) values."""
import numpy as np
import xarray as xr
from arviz_base import convert_to_datatree, extract
from xarray_einstats.stats import logsumexp
from arviz_stats.loo.helper_loo import _get_r_eff
from arviz_stats.utils import ELPDData, get_log_likelihood_dataset
[docs]
def loo_pit(
    data,
    var_names=None,
    log_weights=None,
):
    r"""Compute leave one out (PSIS-LOO) probability integral transform (PIT) values.
    The LOO-PIT values are :math:`p(\tilde{y}_i \le y_i \mid y_{-i})`, where :math:`y_i`
    represents the observed data for index :math:`i` and :math:`\tilde y_i` represents the
    posterior predictive sample at index :math:`i`. Note that :math:`y_{-i}` indicates we have
    left out the :math:`i`-th observation. LOO-PIT values are computed using the PSIS-LOO-CV
    method described in [1]_ and [2]_.
    Parameters
    ----------
    data : DataTree or InferenceData
        It should contain posterior, posterior_predictive and log_likelihood groups.
    var_names : str or list of str, optional
        Names of the variables to be used to compute the LOO-PIT values. If None, all
        variables are used. The function assumes that the observed and log_likelihood
        variables share the same names.
    log_weights: DataArray or ELPDData, optional
        Smoothed log weights. Can be either:
        - A DataArray with the same shape as ``y_pred``
        - An ELPDData object from a previous :func:`arviz_stats.loo` call.
        Defaults to None. If not provided, it will be computed using the PSIS-LOO method.
    Returns
    -------
    loo_pit: array or DataArray
        Value of the LOO-PIT at each observed data point.
    Examples
    --------
    Calculate LOO-PIT values using as test quantity the observed values themselves.
    .. ipython::
        In [1]: from arviz_stats import loo_pit
           ...: from arviz_base import load_arviz_data
           ...: dt = load_arviz_data("centered_eight")
           ...: loo_pit(dt)
    Calculate LOO-PIT values using as test quantity the square of the difference between
    each observation and `mu`. For this we create a new DataTree, copying the posterior and
    log_likelihood groups and creating new observed and posterior_predictive groups.
    .. ipython::
        In [1]: from arviz_base import from_dict
           ...: new_dt = from_dict({"posterior": dt.posterior,
           ...:                 "log_likelihood": dt.log_likelihood,
           ...:                 "observed_data": {
           ...:                     "obs": (dt.observed_data.obs
           ...:                            - dt.posterior.mu.median(dim=("chain", "draw")))**2},
           ...:                 "posterior_predictive": {
           ...:                     "obs": (dt.posterior_predictive.obs - dt.posterior.mu)**2}})
           ...: loo_pit(new_dt)
    References
    ----------
    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
        and WAIC*. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4
        arXiv preprint https://arxiv.org/abs/1507.04544.
    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
    """
    data = convert_to_datatree(data)
    rng = np.random.default_rng(214)
    if var_names is None:
        var_names = list(data.observed_data.data_vars.keys())
    elif isinstance(var_names, str):
        var_names = [var_names]
    log_likelihood = get_log_likelihood_dataset(data, var_names=var_names)
    if log_weights is None:
        n_samples = log_likelihood.chain.size * log_likelihood.draw.size
        reff = _get_r_eff(data, n_samples)
        log_weights, _ = log_likelihood.azstats.psislw(r_eff=reff)
    if isinstance(log_weights, ELPDData):
        if log_weights.log_weights is None:
            raise ValueError("ELPDData object does not contain log_weights")
        log_weights = log_weights.log_weights
    posterior_predictive = extract(
        data,
        group="posterior_predictive",
        combined=False,
        var_names=var_names,
        keep_dataset=True,
    )
    observed_data = extract(
        data,
        group="observed_data",
        combined=False,
        var_names=var_names,
        keep_dataset=True,
    )
    sel_min = {}
    sel_sup = {}
    for var in var_names:
        pred = posterior_predictive[var]
        obs = observed_data[var]
        sel_min[var] = pred < obs
        sel_sup[var] = pred == obs
    sel_min = xr.Dataset(sel_min)
    sel_sup = xr.Dataset(sel_sup)
    pit = np.exp(logsumexp(log_weights.where(sel_min, -np.inf), dims=["chain", "draw"]))
    loo_pit_values = xr.Dataset(coords=observed_data.coords)
    for var in var_names:
        pit_lower = pit[var].values
        if sel_sup[var].any():
            pit_sup_addition = np.exp(
                logsumexp(log_weights.where(sel_sup[var], -np.inf), dims=["chain", "draw"])
            )
            pit_upper = pit_lower + pit_sup_addition[var].values
            random_value = rng.uniform(pit_lower, pit_upper)
            loo_pit_values[var] = observed_data[var].copy(data=random_value)
        else:
            loo_pit_values[var] = observed_data[var].copy(data=pit_lower)
    return loo_pit_values