Source code for arviz_stats.bayes_factor
"""Bayes Factor using Savage-Dickey density ratio."""
import warnings
from numbers import Real

from arviz_base import convert_to_datatree
from numpy import finfo
from scipy.interpolate import interp1d
[docs]
def bayes_factor(data, var_names, ref_vals=0, return_ref_vals=False, prior=None):
    """
    Compute Bayes factor using Savage–Dickey ratio.

    For each variable, the prior and posterior densities are estimated with a KDE
    and evaluated at the reference value. The Bayes factors are
    ``BF10 = prior_density / posterior_density`` and ``BF01 = 1 / BF10``.

    Parameters
    ----------
    data : DataTree, or InferenceData
        The data object containing the posterior and optionally the prior distributions.
    var_names : str or list of str
        Names of the variables for which the bayes factor should be computed.
    ref_vals : float or list of float, default 0
        Reference value for each variable. Must match var_names in length if list.
        Any real number (including NumPy scalars) is accepted.
    return_ref_vals : bool, default False
        If True, also return the density values at the reference value for the
        posterior and prior.
    prior : dict, optional
        Dictionary with prior distributions for each variable in `var_names`. If not
        provided, the prior will be taken from the `prior` group in the data object.
        Each value is used exactly like ``data.prior[var]`` would be (``.max()``,
        ``.min()`` and the ``azstats.kde`` accessor), so it should be the same kind
        of object (an xarray DataArray of samples).

    Returns
    -------
    dict or tuple of (dict, dict)
        Dictionary with Bayes Factor values: ``{"BF10": ..., "BF01": ...}`` per variable.
        If `return_ref_vals` is True, a tuple is returned whose second element maps each
        variable to ``{"prior": ..., "posterior": ...}`` density values at its reference value.

    Raises
    ------
    ValueError
        If `var_names` and `ref_vals` differ in length, if a reference value is not a
        real number, or if an estimated density at the reference value is not strictly
        positive (including NaN).

    Warns
    -----
    UserWarning
        If a reference value lies outside the range of the posterior or prior samples.

    References
    ----------
    .. [1] Heck DW. *A caveat on the Savage-Dickey density ratio:
       The case of computing Bayes factors for regression parameters.*
       Br J Math Stat Psychol, 72. (2019) https://doi.org/10.1111/bmsp.12150

    Examples
    --------
    Compute Bayes factor for a home and intercept variable in a rugby dataset
    using a reference value of 0.15 for home and 3 for intercept.

    .. ipython::

        In [1]: from arviz_base import load_arviz_data
           ...: from arviz_stats import bayes_factor
           ...: dt = load_arviz_data("rugby")
           ...: bayes_factor(dt, var_names=["home", "intercept"], ref_vals=[0.15, 3])
    """
    data = convert_to_datatree(data)
    if isinstance(var_names, str):
        var_names = [var_names]
    # numbers.Real also matches NumPy scalars such as np.int64, which are not
    # subclasses of the builtin int and were previously rejected.
    if isinstance(ref_vals, Real):
        ref_vals = [ref_vals] * len(var_names)
    if len(var_names) != len(ref_vals):
        raise ValueError("Length of var_names and ref_vals must match.")

    results = {}
    ref_density_vals = {}
    for var, ref_val in zip(var_names, ref_vals):
        posterior_samples = data.posterior[var]
        # Keep the user-supplied `prior` dict intact: rebinding the `prior` name here
        # would silently ignore it.
        prior_samples = data.prior[var] if prior is None else prior[var]
        if not isinstance(ref_val, Real):
            raise ValueError(f"Reference value for variable '{var}' must be numerical")
        if ref_val > posterior_samples.max() or ref_val < posterior_samples.min():
            warnings.warn(
                f"Reference value {ref_val} for '{var}' is outside the posterior range. "
                "This may overstate evidence in favor of H1.",
                stacklevel=2,
            )
        if ref_val > prior_samples.max() or ref_val < prior_samples.min():
            warnings.warn(
                f"Reference value {ref_val} for '{var}' is outside the prior range. "
                "Bayes factor computation is not reliable.",
                stacklevel=2,
            )

        posterior_val = _density_at(posterior_samples, ref_val)
        prior_val = _density_at(prior_samples, ref_val)
        # Written as `not (x > 0)` so NaN densities are rejected too; `x <= 0` is
        # False for NaN and would let a NaN Bayes factor through.
        if not (prior_val > 0 and posterior_val > 0):
            raise ValueError(
                f"Invalid KDE values at ref_val={ref_val}: "
                f"prior={prior_val}, posterior={posterior_val}"
            )

        # Savage–Dickey: BF10 is the prior density over the posterior density at ref_val.
        bf_10 = prior_val / posterior_val
        bf_01 = 1 / bf_10
        results[var] = {"BF10": bf_10, "BF01": bf_01}
        if return_ref_vals:
            ref_density_vals[var] = {"prior": prior_val, "posterior": posterior_val}

    if return_ref_vals:
        return results, ref_density_vals
    return results


def _density_at(samples, ref_val):
    """Return the KDE of ``samples`` evaluated at ``ref_val`` as a Python float.

    The KDE uses a 512-point, non-circular grid. The density is linearly
    interpolated on that grid. Points outside the grid get machine epsilon instead
    of 0, so the later density ratio never divides by zero.
    """
    kde = samples.azstats.kde(grid_len=512, circular=False)
    # Row 0 of the KDE values is the evaluation grid, row 1 the density on it.
    return interp1d(
        kde.values[0],
        kde.values[1],
        bounds_error=False,
        fill_value=finfo("float").eps,
    )(ref_val).item()