Source code for arviz_stats.loo.loo_moment_match

"""Compute moment matching for problematic observations in PSIS-LOO-CV."""

import warnings
from collections import namedtuple
from copy import deepcopy

import arviz_base as azb
import numpy as np
import xarray as xr
from arviz_base import dataset_to_dataarray, rcParams
from xarray_einstats.stats import logsumexp

from arviz_stats.loo.helper_loo import (
    _get_log_likelihood_i,
    _get_r_eff,
    _get_r_eff_i,
    _get_weights_and_k_i,
    _prepare_loo_inputs,
    _shift,
    _shift_and_cov,
    _shift_and_scale,
    _warn_pareto_k,
)
from arviz_stats.sampling_diagnostics import ess
from arviz_stats.utils import ELPDData

SplitMomentMatch = namedtuple("SplitMomentMatch", ["lwi", "lwfi", "log_liki", "reff"])
UpdateQuantities = namedtuple("UpdateQuantities", ["lwi", "lwfi", "ki", "kfi", "log_liki"])
LooMomentMatchResult = namedtuple(
    "LooMomentMatchResult",
    ["final_log_liki", "final_lwi", "final_ki", "kfs_i", "reff_i", "original_ki", "i"],
)


[docs] def loo_moment_match( data, loo_orig, log_prob_upars_fn, log_lik_i_upars_fn, upars=None, var_name=None, reff=None, max_iters=30, k_threshold=None, split=True, cov=False, pointwise=None, ): r"""Compute moment matching for problematic observations in PSIS-LOO-CV. Adjusts the results of a previously computed Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV) object by applying a moment matching algorithm to observations with high Pareto k diagnostic values. The moment matching algorithm iteratively adjusts the posterior draws in the unconstrained parameter space to better approximate the leave-one-out posterior. The moment matching algorithm is described in [1]_ and the PSIS-LOO-CV method is described in [2]_ and [3]_. Parameters ---------- data : DataTree or InferenceData Input data. It should contain the posterior and the log_likelihood groups. loo_orig : ELPDData An existing ELPDData object from a previous `loo` result. Must contain pointwise Pareto k values (`pointwise=True` must have been used). log_prob_upars_fn : callable Function that computes the log probability density of the full posterior distribution evaluated at unconstrained parameter draws. The function signature is ``log_prob_upars_fn(upars)`` where ``upars`` is a :class:`~xarray.DataArray` of unconstrained parameter draws with dimensions ``chain``, ``draw``, and a parameter dimension. It should return a :class:`~xarray.DataArray` with dimensions ``chain``, ``draw``. log_lik_i_upars_fn : callable Function that computes the log-likelihood of a single left-out observation evaluated at unconstrained parameter draws. The function signature is ``log_lik_i_upars_fn(upars, i)`` where ``upars`` is a :class:`~xarray.DataArray` of unconstrained parameter draws and ``i`` is the integer index of the left-out observation. It should return a :class:`~xarray.DataArray` with dimensions ``chain``, ``draw``. upars : DataArray, optional Posterior draws transformed to the unconstrained parameter space. Must have ``chain`` and ``draw`` dimensions, plus one additional dimension containing all parameters. Parameter names can be provided as coordinate values on this dimension. If not provided, will attempt to use the ``unconstrained_posterior`` group from the input data if available. var_name : str, optional The name of the variable in log_likelihood group storing the pointwise log likelihood data to use for loo computation. reff: float, optional Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number of actual samples. Computed from trace by default. max_iters : int, default 30 Maximum number of moment matching iterations for each problematic observation. k_threshold : float, optional Threshold value for Pareto k values above which moment matching is applied. Defaults to :math:`\min(1 - 1/\log_{10}(S), 0.7)`, where S is the number of samples. split : bool, default True If True, only transform half of the draws and use multiple importance sampling to combine them with untransformed draws. cov : bool, default False If True, match the covariance structure during the transformation, in addition to the mean and marginal variances. Ignored if ``split=False``. pointwise: bool, optional If True, the pointwise predictive accuracy will be returned. Defaults to ``rcParams["stats.ic_pointwise"]``. Moment matching always requires pointwise data from ``loo_orig``. This argument controls whether the returned object includes pointwise data. Returns ------- ELPDData Object with the following attributes: - **kind**: "loo" - **elpd**: expected log pointwise predictive density - **se**: standard error of the elpd - **p**: effective number of parameters - **n_samples**: number of samples - **n_data_points**: number of data points - **scale**: "log" - **warning**: True if the estimated shape parameter of Pareto distribution is greater than ``good_k``. - **good_k**: For a sample size S, the threshold is computed as ``min(1 - 1/log10(S), 0.7)`` - **elpd_i**: :class:`~xarray.DataArray` with the pointwise predictive accuracy, only if ``pointwise=True``. - **pareto_k**: :class:`~xarray.DataArray` with Pareto shape values, only if ``pointwise=True``. - **approx_posterior**: False (not used for standard LOO) - **log_weights**: Smoothed log weights. Examples -------- Moment matching can improve PSIS-LOO-CV estimates for observations with high Pareto k values without having to refit the model for each problematic observation. We will use the non-centered eight schools data which has 1 problematic observation. In practice, moment matching is useful when you have a potentially large number of problematic observations: .. ipython:: :okwarning: In [1]: import arviz_base as azb ...: import numpy as np ...: import xarray as xr ...: from scipy import stats ...: from arviz_stats import loo ...: ...: idata = azb.load_arviz_data("non_centered_eight") ...: posterior = idata.posterior ...: schools = posterior.theta_t.coords["school"].values ...: y_obs = idata.observed_data.obs ...: obs_dim = y_obs.dims[0] ...: ...: loo_orig = loo(idata, pointwise=True, var_name="obs") ...: loo_orig The moment matching algorithm applies affine transformations to posterior draws in unconstrained parameter space. To enable this, we need to collect the posterior parameters from their original space, transform them to unconstrained space if needed, and stack them into a single :class:`xarray.DataArray` that matches the expected ``(chain, draw, param)`` structure. Some parameters may already be in unconstrained space, so we don't need to transform them. This will depend on the model and the choice of parameterization: .. ipython:: :okwarning: In [2]: upars_ds = xr.Dataset( ...: { ...: **{ ...: f"theta_t_{school}": posterior.theta_t.sel(school=school, drop=True) ...: for school in schools ...: }, ...: "mu": posterior.mu, ...: "log_tau": xr.apply_ufunc(np.log, posterior.tau), ...: } ...: ) ...: upars = azb.dataset_to_dataarray( ...: upars_ds, sample_dims=["chain", "draw"], new_dim="upars_dim" ...: ) Moment matching requires two functions: one for the joint log probability (likelihood + priors) and another for the pointwise log-likelihood of a single observation. We first define functions that accept the data they need as keyword-only arguments: .. ipython:: :okwarning: In [3]: sigmas = xr.DataArray( ...: [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0], ...: dims=[obs_dim], ...: ) ...: ...: def log_prob_upars(upars, *, sigmas, y, schools, obs_dim): ...: theta_t = xr.concat( ...: [upars.sel(upars_dim=f"theta_t_{school}") for school in schools], ...: dim=obs_dim, ...: ) ...: mu = upars.sel(upars_dim="mu") ...: log_tau = upars.sel(upars_dim="log_tau") ...: tau = xr.apply_ufunc(np.exp, log_tau) ...: theta = mu + tau * theta_t ...: ...: log_prior = xr.apply_ufunc(stats.norm(0, 5).logpdf, mu) ...: log_prior = log_prior + xr.apply_ufunc( ...: stats.halfcauchy(0, 5).logpdf, ...: tau, ...: ) ...: log_prior = log_prior + log_tau ...: log_prior = log_prior + xr.apply_ufunc( ...: stats.norm(0, 1).logpdf, ...: theta_t, ...: ).sum(obs_dim) ...: ...: const = -0.5 * np.log(2 * np.pi) ...: log_like = const - np.log(sigmas) - 0.5 * ((y - theta) / sigmas) ** 2 ...: log_like = log_like.sum(obs_dim) ...: return log_prior + log_like ...: ...: def log_lik_i_upars(upars, i, *, sigmas, y, schools, obs_dim): ...: mu = upars.sel(upars_dim="mu") ...: log_tau = upars.sel(upars_dim="log_tau") ...: tau = xr.apply_ufunc(np.exp, log_tau) ...: ...: theta_t_i = upars.sel(upars_dim=f"theta_t_{schools[i]}") ...: theta_i = mu + tau * theta_t_i ...: ...: sigma_i = sigmas.isel({obs_dim: i}) ...: y_i = y.isel({obs_dim: i}) ...: const = -0.5 * np.log(2 * np.pi) ...: return const - np.log(sigma_i) - 0.5 * ((y_i - theta_i) / sigma_i) ** 2 Now, we can specialise these functions with :func:`functools.partial` so the resulting functions match the signature expected by :func:`loo_moment_match()`: .. ipython:: :okwarning: In [4]: from functools import partial ...: log_prob_fn = partial( ...: log_prob_upars, ...: sigmas=sigmas, ...: y=y_obs, ...: schools=schools, ...: obs_dim=obs_dim, ...: ) ...: log_lik_i_fn = partial( ...: log_lik_i_upars, ...: sigmas=sigmas, ...: y=y_obs, ...: schools=schools, ...: obs_dim=obs_dim, ...: ) Finally, we can run moment matching using the prepared inputs. Now, we have no problematic observations anymore: .. ipython:: :okwarning: In [5]: from arviz_stats import loo_moment_match ...: loo_mm = loo_moment_match( ...: idata, ...: loo_orig, ...: upars=upars, ...: log_prob_upars_fn=log_prob_fn, ...: log_lik_i_upars_fn=log_lik_i_fn, ...: var_name="obs", ...: split=True, ...: ) ...: loo_mm Notes ----- The moment matching algorithm considers three affine transformations of the posterior draws: For a specific draw :math:`\theta^{(s)}`, a generic affine transformation includes a square matrix :math:`\mathbf{A}` representing a linear map and a vector :math:`\mathbf{b}` representing a translation such that .. math:: T : \theta^{(s)} \mapsto \mathbf{A}\theta^{(s)} + \mathbf{b} =: \theta^{*{(s)}}. The first transformation, :math:`T_1`, is a translation that matches the mean of the sample to its importance weighted mean given by .. math:: \mathbf{\theta^{*{(s)}}} = T_1(\mathbf{\theta^{(s)}}) = \mathbf{\theta^{(s)}} - \bar{\theta} + \bar{\theta}_w, where :math:`\bar{\theta}` is the mean of the sample and :math:`\bar{\theta}_w` is the importance weighted mean of the sample. The second transformation, :math:`T_2`, is a scaling that matches the marginal variances in addition to the means given by .. math:: \mathbf{\theta^{*{(s)}}} = T_2(\mathbf{\theta^{(s)}}) = \mathbf{v}^{1/2}_w \circ \mathbf{v}^{-1/2} \circ (\mathbf{\theta^{(s)}} - \bar{\theta}) + \bar{\theta}_w, where :math:`\mathbf{v}` and :math:`\mathbf{v}_w` are the sample and weighted variances, and :math:`\circ` denotes the pointwise product of the elements of two vectors. The third transformation, :math:`T_3`, is a covariance transformation that matches the covariance matrix of the sample to its importance weighted covariance matrix given by .. math:: \mathbf{\theta^{*{(s)}}} = T_3(\mathbf{\theta^{(s)}}) = \mathbf{L}_w \mathbf{L}^{-1} (\mathbf{\theta^{(s)}} - \bar{\theta}) + \bar{\theta}_w, where :math:`\mathbf{L}` and :math:`\mathbf{L}_w` are the Cholesky decompositions of the covariance matrix and the weighted covariance matrix, respectively, e.g., .. math:: \mathbf{LL}^T = \mathbf{\Sigma} = \frac{1}{S} \sum_{s=1}^S (\mathbf{\theta^{(s)}} - \bar{\theta}) (\mathbf{\theta^{(s)}} - \bar{\theta})^T and .. math:: \mathbf{L}_w \mathbf{L}_w^T = \mathbf{\Sigma}_w = \frac{\frac{1}{S} \sum_{s=1}^S w^{(s)} (\mathbf{\theta^{(s)}} - \bar{\theta}_w) (\mathbf{\theta^{(s)}} - \bar{\theta}_w)^T}{\sum_{s=1}^S w^{(s)}}. We iterate on :math:`T_1` repeatedly and move onto :math:`T_2` and :math:`T_3` only if :math:`T_1` fails to yield a Pareto-k statistic below the threshold. See Also -------- loo : Standard PSIS-LOO-CV. reloo : Exact re-fitting for problematic observations. References ---------- .. [1] Paananen, T., Piironen, J., Buerkner, P.-C., Vehtari, A. (2021). Implicitly Adaptive Importance Sampling. Statistics and Computing. 31(2) (2021) https://doi.org/10.1007/s11222-020-09982-2 arXiv preprint https://arxiv.org/abs/1906.08850. .. [2] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC*. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544. .. [3] Vehtari et al. *Pareto Smoothed Importance Sampling*. Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646 """ if not isinstance(loo_orig, ELPDData): raise TypeError("loo_orig must be an ELPDData object.") if loo_orig.pareto_k is None or loo_orig.elpd_i is None: raise ValueError( "Moment matching requires pointwise LOO results with Pareto k values. " "Please compute the initial LOO with pointwise=True." ) sample_dims = ["chain", "draw"] if upars is None: if hasattr(data, "unconstrained_posterior"): upars_ds = azb.get_unconstrained_samples(data, return_dataset=True) upars = dataset_to_dataarray( upars_ds, sample_dims=sample_dims, new_dim="unconstrained_parameter" ) else: raise ValueError( "upars must be provided or data must contain an 'unconstrained_posterior' group." ) if not isinstance(upars, xr.DataArray): raise TypeError("upars must be a DataArray.") if not all(dim_name in upars.dims for dim_name in sample_dims): raise ValueError(f"upars must have dimensions {sample_dims}.") param_dim_list = [dim for dim in upars.dims if dim not in sample_dims] if len(param_dim_list) == 0: param_dim_name = "upars_dim" upars = upars.expand_dims(dim={param_dim_name: 1}) elif len(param_dim_list) == 1: param_dim_name = param_dim_list[0] else: raise ValueError("upars must have at most one dimension besides 'chain' and 'draw'.") loo_data = deepcopy(loo_orig) loo_data.method = "loo_moment_match" pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise loo_inputs = _prepare_loo_inputs(data, var_name) log_likelihood = loo_inputs.log_likelihood obs_dims = loo_inputs.obs_dims n_samples = loo_inputs.n_samples var_name = loo_inputs.var_name n_params = upars.sizes[param_dim_name] n_data_points = loo_orig.n_data_points if reff is None: reff = _get_r_eff(data, n_samples) try: orig_log_prob = log_prob_upars_fn(upars) if not isinstance(orig_log_prob, xr.DataArray): raise TypeError("log_prob_upars_fn must return a DataArray.") if not all(dim in orig_log_prob.dims for dim in sample_dims): raise ValueError(f"Original log probability must have dimensions {sample_dims}.") if len(orig_log_prob.dims) != len(sample_dims): raise ValueError( f"Original log probability should only have dimensions {sample_dims}, " f"found {orig_log_prob.dims}" ) except Exception as e: raise ValueError(f"Error executing log_prob_upars_fn: {e}") from e if k_threshold is None: k_threshold = min(1 - 1 / np.log10(n_samples), 0.7) if n_samples > 1 else 0.7 ks = loo_data.pareto_k.stack(__obs__=obs_dims).transpose("__obs__").values bad_obs_indices = np.where(ks > k_threshold)[0] if len(bad_obs_indices) == 0: warnings.warn("No Pareto k values exceed the threshold. Returning original LOO data.") if not pointwise: loo_data.elpd_i = None loo_data.pareto_k = None if hasattr(loo_data, "p_loo_i"): loo_data.p_loo_i = None return loo_data lpd = logsumexp(log_likelihood, dims=sample_dims, b=1 / n_samples) loo_data.p_loo_i = lpd - loo_data.elpd_i kfs = np.zeros(n_data_points) log_weights = getattr(loo_data, "log_weights", None) r_eff_data = getattr(loo_data, "r_eff", reff) # Moment matching algorithm for i in bad_obs_indices: mm_result = _loo_moment_match_i( i=i, upars=upars, log_likelihood=log_likelihood, log_prob_upars_fn=log_prob_upars_fn, log_lik_i_upars_fn=log_lik_i_upars_fn, max_iters=max_iters, k_threshold=k_threshold, split=split, cov=cov, orig_log_prob=orig_log_prob, ks=ks, log_weights=log_weights, pareto_k=loo_data.pareto_k, r_eff=r_eff_data, sample_dims=sample_dims, obs_dims=obs_dims, n_samples=n_samples, n_params=n_params, param_dim_name=param_dim_name, var_name=var_name, ) kfs[i] = mm_result.kfs_i if mm_result.final_ki < mm_result.original_ki: new_elpd_i = logsumexp( mm_result.final_log_liki + mm_result.final_lwi, dims=sample_dims ).item() original_log_liki = _get_log_likelihood_i(log_likelihood, i, obs_dims) _update_loo_data_i( loo_data, i, new_elpd_i, mm_result.final_ki, mm_result.final_log_liki, sample_dims, obs_dims, n_samples, original_log_liki, suppress_warnings=True, ) else: warnings.warn( f"Observation {i}: Moment matching did not improve k " f"({mm_result.original_ki:.2f} -> {mm_result.final_ki:.2f}). Reverting.", UserWarning, stacklevel=2, ) if hasattr(loo_orig, "p_loo_i") and loo_orig.p_loo_i is not None: if len(obs_dims) == 1: idx_dict = {obs_dims[0]: i} else: coords = np.unravel_index(i, tuple(loo_data.elpd_i.sizes[d] for d in obs_dims)) idx_dict = dict(zip(obs_dims, coords)) loo_data.p_loo_i[idx_dict] = loo_orig.p_loo_i[idx_dict] final_ks = loo_data.pareto_k.stack(__obs__=obs_dims).transpose("__obs__").values if np.any(final_ks[bad_obs_indices] > k_threshold): warnings.warn( f"After Moment Matching, {np.sum(final_ks > k_threshold)} observations still have " f"Pareto k > {k_threshold:.2f}.", UserWarning, stacklevel=2, ) if not split and np.any(kfs > k_threshold): warnings.warn( "The accuracy of self-normalized importance sampling may be bad. " "Setting the argument 'split' to 'True' will likely improve accuracy.", UserWarning, stacklevel=2, ) elpd_raw = logsumexp(log_likelihood, dims=sample_dims, b=1 / n_samples).sum().values loo_data.p = elpd_raw - loo_data.elpd if not pointwise: loo_data.elpd_i = None loo_data.pareto_k = None if hasattr(loo_data, "p_loo_i"): loo_data.p_loo_i = None return loo_data
def _split_moment_match( upars, cov, total_shift, total_scaling, total_mapping, i, reff, log_prob_upars_fn, log_lik_i_upars_fn, ): r"""Split moment matching importance sampling for PSIS-LOO-CV. Applies affine transformations based on the total moment matching transformation to half of the posterior draws, leaving the other half unchanged. These approximations to the leave-one-out posterior are then combined using multiple importance sampling. Based on the implicit adaptive importance sampling algorithm of [1]_ and the PSIS-LOO-CV method of [2]_ and [3]_. Parameters ---------- upars : DataArray A DataArray representing the posterior draws of the model parameters in the unconstrained space. Must contain the dimensions `chain` and `draw` and a final dimension representing the different unconstrained parameters. cov : bool Whether to match the full covariance matrix of the samples (True) or just the marginal variances (False). Using the full covariance is more computationally expensive. total_shift : ndarray Vector containing the total shift (translation) applied to the parameters. Shape should match the parameter dimension of ``upars``. total_scaling : ndarray Vector containing the total scaling factors for the marginal variances. Shape should match the parameter dimension of ``upars``. total_mapping : ndarray Square matrix representing the linear transformation applied to the covariance matrix. Shape should be (d, d) where d is the parameter dimension. i : int Index of the specific observation to be left out for computing leave-one-out likelihood. reff : float Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number of actual samples. log_prob_upars_fn : callable Function that computes the log probability density of the *full posterior* distribution evaluated at unconstrained parameter draws. The function signature is ``log_prob_upars_fn(upars)`` where ``upars`` is a :class:`~xarray.DataArray` of unconstrained parameter draws. It should return a :class:`~xarray.DataArray` with dimensions ``chain``, ``draw``. log_lik_i_upars_fn : callable Function that computes the log-likelihood of the *left-out observation* ``i`` evaluated at unconstrained parameter draws. The function signature is ``log_lik_i_upars_fn(upars, i)`` where ``upars`` is a :class:`~xarray.DataArray` of unconstrained parameter draws and ``i`` is the integer index of the observation. It should return a :class:`~xarray.DataArray` with dimensions ``chain``, ``draw``. Returns ------- SplitMomentMatch A namedtuple containing: - lwi: Updated log importance weights for each sample - lwfi: Updated log importance weights for full distribution - log_liki: Updated log likelihood values for the specific observation - reff: Relative MCMC efficiency (updated based on the split samples) References ---------- .. [1] Paananen, T., Piironen, J., Buerkner, P.-C., Vehtari, A. (2021). *Implicitly Adaptive Importance Sampling*. Statistics and Computing. 31(2) (2021) https://doi.org/10.1007/s11222-020-09982-2 arXiv preprint https://arxiv.org/abs/1906.08850. .. [2] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC*. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544. .. [3] Vehtari et al. *Pareto Smoothed Importance Sampling*. Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646 """ if not isinstance(upars, xr.DataArray): raise TypeError("upars must be a DataArray.") sample_dims = ["chain", "draw"] param_dim_list = [dim for dim in upars.dims if dim not in sample_dims] if len(param_dim_list) != 1: raise ValueError("upars must have exactly one dimension besides chain and draw.") param_dim = param_dim_list[0] if not all(dim in upars.dims for dim in sample_dims): raise ValueError( f"Required sample dimensions {sample_dims} not found in upars dimensions {upars.dims}" ) dim = upars.sizes[param_dim] n_chains = upars.sizes["chain"] n_draws = upars.sizes["draw"] n_samples = n_chains * n_draws n_samples_half = n_samples // 2 stack_dims = ["draw", "chain"] upars_stacked = upars.stack(__sample__=stack_dims).transpose("__sample__", param_dim) mean_original = upars_stacked.mean(dim="__sample__") if total_shift is None or total_shift.size == 0: total_shift = np.zeros(dim) if total_scaling is None or total_scaling.size == 0: total_scaling = np.ones(dim) if total_mapping is None or total_mapping.size == 0: total_mapping = np.eye(dim) # Forward transformation upars_trans = upars_stacked - mean_original upars_trans = upars_trans * xr.DataArray(total_scaling, dims=param_dim) if cov and dim > 0: upars_trans = xr.DataArray( upars_trans.data @ total_mapping.T, coords=upars_trans.coords, dims=upars_trans.dims, ) # Inverse Transformation upars_trans = upars_trans + (xr.DataArray(total_shift, dims=param_dim) + mean_original) upars_trans_inv = upars_stacked - (xr.DataArray(total_shift, dims=param_dim) + mean_original) if cov and dim > 0: try: inv_mapping_t = np.linalg.inv(total_mapping.T) upars_trans_inv = xr.DataArray( upars_trans_inv.data @ inv_mapping_t, coords=upars_trans_inv.coords, dims=upars_trans_inv.dims, ) except np.linalg.LinAlgError: warnings.warn("Could not invert mapping matrix. Using identity.", UserWarning) upars_trans_inv = upars_trans_inv / xr.DataArray(total_scaling, dims=param_dim) upars_trans_inv = upars_trans_inv + (mean_original - xr.DataArray(total_shift, dims=param_dim)) upars_trans_half_stacked = upars_stacked.copy(deep=True) upars_trans_half_stacked.data[:n_samples_half, :] = upars_trans.data[:n_samples_half, :] upars_trans_half = upars_trans_half_stacked.unstack("__sample__").transpose( *reversed(stack_dims), param_dim ) upars_trans_half_inv_stacked = upars_stacked.copy(deep=True) upars_trans_half_inv_stacked.data[n_samples_half:, :] = upars_trans_inv.data[n_samples_half:, :] upars_trans_half_inv = upars_trans_half_inv_stacked.unstack("__sample__").transpose( *reversed(stack_dims), param_dim ) try: log_prob_half_trans = log_prob_upars_fn(upars_trans_half) log_prob_half_trans_inv = log_prob_upars_fn(upars_trans_half_inv) except Exception as e: raise ValueError( f"Could not compute log probabilities for transformed parameters: {e}" ) from e try: log_liki_half = log_lik_i_upars_fn(upars_trans_half, i) if not all(dim in log_liki_half.dims for dim in sample_dims) or len( log_liki_half.dims ) != len(sample_dims): raise ValueError( f"log_lik_i_upars_fn must return a DataArray with dimensions {sample_dims}" ) if ( log_liki_half.sizes["chain"] != upars.sizes["chain"] or log_liki_half.sizes["draw"] != upars.sizes["draw"] ): raise ValueError( "log_lik_i_upars_fn output shape does not match input sample dimensions" ) except Exception as e: raise ValueError(f"Could not compute log likelihood for observation {i}: {e}") from e log_jacobian_det = 0.0 if dim > 0: log_jacobian_det = -np.sum(np.log(total_scaling)) try: log_jacobian_det -= np.log(np.linalg.det(total_mapping)) except np.linalg.LinAlgError: log_jacobian_det -= np.inf log_prob_half_trans_inv_adj = log_prob_half_trans_inv + log_jacobian_det # Multiple importance sampling use_forward_log_prob = log_prob_half_trans > log_prob_half_trans_inv_adj raw_log_weights_half = -log_liki_half + log_prob_half_trans log_sum_terms = xr.where( use_forward_log_prob, log_prob_half_trans + xr.ufuncs.log1p(np.exp(log_prob_half_trans_inv_adj - log_prob_half_trans)), log_prob_half_trans_inv_adj + xr.ufuncs.log1p(np.exp(log_prob_half_trans - log_prob_half_trans_inv_adj)), ) raw_log_weights_half -= log_sum_terms raw_log_weights_half = xr.where(np.isnan(raw_log_weights_half), -np.inf, raw_log_weights_half) raw_log_weights_half = xr.where( np.isposinf(raw_log_weights_half), -np.inf, raw_log_weights_half ) # PSIS smoothing for half posterior lwi_psis_da, _ = _wrap__psislw(raw_log_weights_half, sample_dims, reff) lr_full = lwi_psis_da + log_liki_half lr_full = xr.where(np.isnan(lr_full) | (np.isinf(lr_full) & (lr_full > 0)), -np.inf, lr_full) # PSIS smoothing for full posterior lwfi_psis_da, _ = _wrap__psislw(lr_full, sample_dims, reff) n_chains = upars.sizes["chain"] if n_chains == 1: reff_updated = reff else: log_liki_half_1 = log_liki_half.isel( chain=slice(None), draw=slice(0, n_samples_half // n_chains) ) log_liki_half_2 = log_liki_half.isel( chain=slice(None), draw=slice(n_samples_half // n_chains, None) ) liki_half_1 = np.exp(log_liki_half_1) liki_half_2 = np.exp(log_liki_half_2) ess_1 = liki_half_1.azstats.ess(method="mean") ess_2 = liki_half_2.azstats.ess(method="mean") ess_1_value = ess_1.values if hasattr(ess_1, "values") else ess_1 ess_2_value = ess_2.values if hasattr(ess_2, "values") else ess_2 n_samples_1 = log_liki_half_1.size n_samples_2 = log_liki_half_2.size r_eff_1 = ess_1_value / n_samples_1 r_eff_2 = ess_2_value / n_samples_2 reff_updated = min(r_eff_1, r_eff_2) return SplitMomentMatch( lwi=lwi_psis_da, lwfi=lwfi_psis_da, log_liki=log_liki_half, reff=reff_updated, ) def _loo_moment_match_i( i, upars, log_likelihood, log_prob_upars_fn, log_lik_i_upars_fn, max_iters, k_threshold, split, cov, orig_log_prob, ks, log_weights, pareto_k, r_eff, sample_dims, obs_dims, n_samples, n_params, param_dim_name, var_name, ): """Compute moment matching for a single observation.""" n_chains = upars.sizes["chain"] n_draws = upars.sizes["draw"] log_liki = _get_log_likelihood_i(log_likelihood, i, obs_dims).squeeze(drop=True) if isinstance(r_eff, xr.DataArray): reff_i = _get_r_eff_i(r_eff, i, obs_dims) elif r_eff is not None: reff_i = r_eff else: liki = np.exp(log_liki) liki_reshaped = liki.values.reshape(n_chains, n_draws).T ess_val = ess(liki_reshaped, method="mean").item() reff_i = ess_val / n_samples if n_samples > 0 else 1.0 original_ki = ks[i] if log_weights is not None: log_weights_i, ki = _get_weights_and_k_i( log_weights=log_weights, pareto_k=pareto_k, i=i, obs_dims=obs_dims, sample_dims=sample_dims, data=log_likelihood, n_samples=n_samples, reff=reff_i, log_lik_i=log_liki, var_name=var_name, ) lwi = log_weights_i.squeeze(drop=True).transpose(*sample_dims).astype(np.float64) else: log_ratio_i_init = -log_liki lwi, ki = _wrap__psislw(log_ratio_i_init, sample_dims, reff_i) upars_i = upars.copy(deep=True) total_shift = np.zeros(upars_i.sizes[param_dim_name]) total_scaling = np.ones(upars_i.sizes[param_dim_name]) total_mapping = np.eye(upars_i.sizes[param_dim_name]) iterind = 1 transformations_applied = False kfs_i = 0 while iterind <= max_iters and ki > k_threshold: if iterind == max_iters: warnings.warn( f"Maximum number of moment matching iterations ({max_iters}) reached " f"for observation {i}. Final Pareto k is {ki:.2f}.", UserWarning, stacklevel=2, ) break # Try Mean Shift try: shift_res = _shift(upars_i, lwi) quantities_i = _update_quantities_i( shift_res.upars, i, orig_log_prob, log_prob_upars_fn, log_lik_i_upars_fn, reff_i, sample_dims, ) if quantities_i.ki < ki: ki = quantities_i.ki lwi = quantities_i.lwi log_liki = quantities_i.log_liki kfs_i = quantities_i.kfi upars_i = shift_res.upars total_shift = total_shift + shift_res.shift transformations_applied = True iterind += 1 continue # Restart, try mean shift again except RuntimeError as e: warnings.warn( f"Error during mean shift calculation for observation {i}: {e}. " "Stopping moment matching for this observation.", UserWarning, stacklevel=2, ) break # Try Scale Shift try: scale_res = _shift_and_scale(upars_i, lwi) quantities_i = _update_quantities_i( scale_res.upars, i, orig_log_prob, log_prob_upars_fn, log_lik_i_upars_fn, reff_i, sample_dims, ) if quantities_i.ki < ki: ki = quantities_i.ki lwi = quantities_i.lwi log_liki = quantities_i.log_liki kfs_i = quantities_i.kfi upars_i = scale_res.upars total_shift = total_shift + scale_res.shift total_scaling = total_scaling * scale_res.scaling transformations_applied = True iterind += 1 continue # Restart, try mean shift again except RuntimeError as e: warnings.warn( f"Error during scale shift calculation for observation {i}: {e}. " "Stopping moment matching for this observation.", UserWarning, stacklevel=2, ) break # Try Covariance Shift if cov and n_samples >= 10 * n_params: try: cov_res = _shift_and_cov(upars_i, lwi) quantities_i = _update_quantities_i( cov_res.upars, i, orig_log_prob, log_prob_upars_fn, log_lik_i_upars_fn, reff_i, sample_dims, ) if quantities_i.ki < ki: ki = quantities_i.ki lwi = quantities_i.lwi log_liki = quantities_i.log_liki kfs_i = quantities_i.kfi upars_i = cov_res.upars total_shift = total_shift + cov_res.shift total_mapping = cov_res.mapping @ total_mapping transformations_applied = True iterind += 1 continue # Restart, try mean shift again except RuntimeError as e: warnings.warn( f"Error during covariance shift calculation for observation {i}: {e}. " "Stopping moment matching for this observation.", UserWarning, stacklevel=2, ) break break if split and transformations_applied: try: split_res = _split_moment_match( upars=upars, cov=cov, total_shift=total_shift, total_scaling=total_scaling, total_mapping=total_mapping, i=i, reff=reff_i, log_prob_upars_fn=log_prob_upars_fn, log_lik_i_upars_fn=log_lik_i_upars_fn, ) final_log_liki = split_res.log_liki final_lwi = split_res.lwi final_ki = ki reff_i = split_res.reff except RuntimeError as e: warnings.warn( f"Error during split moment matching for observation {i}: {e}. " "Using non-split transformation result.", UserWarning, stacklevel=2, ) final_log_liki = log_liki final_lwi = lwi final_ki = ki else: final_log_liki = log_liki final_lwi = lwi final_ki = ki if transformations_applied: liki_final = np.exp(final_log_liki) liki_final_reshaped = liki_final.values.reshape(n_chains, n_draws).T ess_val_final = ess(liki_final_reshaped, method="mean").item() reff_i = ess_val_final / n_samples if n_samples > 0 else 1.0 return LooMomentMatchResult( final_log_liki=final_log_liki, final_lwi=final_lwi, final_ki=final_ki, kfs_i=kfs_i, reff_i=reff_i, original_ki=original_ki, i=i, ) def _update_loo_data_i( loo_data, i, new_elpd_i, new_pareto_k, log_liki, sample_dims, obs_dims, n_samples, original_log_liki=None, suppress_warnings=False, ): """Update the ELPDData object for a single observation.""" if loo_data.elpd_i is None or loo_data.pareto_k is None: raise ValueError("loo_data must contain pointwise elpd_i and pareto_k values.") lpd_i_log_lik = original_log_liki if original_log_liki is not None else log_liki lpd_i = logsumexp(lpd_i_log_lik, dims=sample_dims, b=1 / n_samples).item() p_loo_i = lpd_i - new_elpd_i if len(obs_dims) == 1: idx_dict = {obs_dims[0]: i} else: coords = np.unravel_index(i, tuple(loo_data.elpd_i.sizes[d] for d in obs_dims)) idx_dict = dict(zip(obs_dims, coords)) loo_data.elpd_i[idx_dict] = new_elpd_i loo_data.pareto_k[idx_dict] = new_pareto_k if not hasattr(loo_data, "p_loo_i") or loo_data.p_loo_i is None: loo_data.p_loo_i = xr.full_like(loo_data.elpd_i, np.nan) loo_data.p_loo_i[idx_dict] = p_loo_i loo_data.elpd = np.nansum(loo_data.elpd_i.values) loo_data.se = np.sqrt(loo_data.n_data_points * np.nanvar(loo_data.elpd_i.values, ddof=1)) loo_data.warning, loo_data.good_k = _warn_pareto_k( loo_data.pareto_k.values[~np.isnan(loo_data.pareto_k.values)], loo_data.n_samples, suppress=suppress_warnings, ) def _update_quantities_i( upars, i, orig_log_prob, log_prob_upars_fn, log_lik_i_upars_fn, reff_i, sample_dims, ): """Update the moment matching quantities for a single observation.""" log_prob_new = log_prob_upars_fn(upars) log_liki_new = log_lik_i_upars_fn(upars, i) log_ratio_i = -log_liki_new + log_prob_new - orig_log_prob lwi_new, ki_new = _wrap__psislw(log_ratio_i, sample_dims, reff_i) log_ratio_full = log_prob_new - orig_log_prob lwfi_new, kfi_new = _wrap__psislw(log_ratio_full, sample_dims, reff_i) return UpdateQuantities( lwi=lwi_new, lwfi=lwfi_new, ki=ki_new, kfi=kfi_new, log_liki=log_liki_new, ) def _wrap__psislw(log_weights, sample_dims, r_eff): """Apply PSIS smoothing over sample dimensions.""" if not isinstance(log_weights, xr.DataArray): raise TypeError("log_weights must be an xarray.DataArray") missing_dims = [dim for dim in sample_dims if dim not in log_weights.dims] if missing_dims: raise ValueError( f"All sample dimensions must be present in the input; missing {missing_dims}." ) other_dims = [dim for dim in log_weights.dims if dim not in sample_dims] if other_dims: raise ValueError( "_wrap__psislw expects `log_weights` to include only sample dimensions; " f"found extra dims {other_dims}." ) stacked = log_weights.stack(__sample__=sample_dims) stacked_for_psis = -stacked try: lw_stacked, k = stacked_for_psis.azstats.psislw(dim="__sample__", r_eff=r_eff) except ValueError as err: err_message = str(err) fallback_errors = ("All tail values are the same", "n_draws_tail must be at least 5") if not any(msg in err_message for msg in fallback_errors): raise log_norm = logsumexp(stacked, dims="__sample__") lw_stacked = stacked - log_norm k = np.inf lw = lw_stacked.unstack("__sample__").transpose(*log_weights.dims) if isinstance(k, xr.DataArray): if k.dims: raise ValueError("Unexpected dimensions on Pareto k output; expected scalar result.") k_val = k.item() elif isinstance(k, np.ndarray): if k.ndim != 0: raise ValueError("Unexpected array shape for Pareto k; expected scalar result.") k_val = k.item() else: try: k_val = k except (TypeError, ValueError) as exc: raise TypeError("Unable to convert PSIS tail index to float") from exc return lw, k_val