"""Collection of metrics for evaluating the performance of probabilistic models."""
from collections import namedtuple
import numpy as np
from arviz_base import convert_to_datatree, dataset_to_dataarray, extract, rcParams
from scipy.spatial import cKDTree
from scipy.stats import wasserstein_distance, wasserstein_distance_nd
from arviz_stats.base import array_stats
from arviz_stats.utils import round_num
[docs]
def r2_score(
    data,
    var_name=None,
    data_pair=None,
    summary=True,
    point_estimate=None,
    ci_kind=None,
    ci_prob=None,
    round_to="2g",
):
    """R² for Bayesian regression models.

    The R², or coefficient of determination, is defined as the proportion of variance
    in the data that is explained by the model. It is computed as the variance of the
    predicted values divided by the variance of the predicted values plus the variance
    of the residuals. For details of the Bayesian R² see [1]_.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the ``observed_data`` and ``posterior_predictive``
        groups.
    var_name : str, optional
        Name of the variable to compute the R² for. The same name is used to look up
        the observed data and the posterior predictive. Ignored if `data_pair` is given.
    data_pair : dict, optional
        Dictionary mapping the observed data variable name (key) to the posterior
        predictive variable name (value). Only the first item is used.
    summary: bool
        Whether to return a summary (default) or an array of R² samples.
        The summary is a namedtuple with a point estimate and a credible interval.
    point_estimate: str
        The point estimate to compute. It is looked up as a NumPy function
        (e.g. "mean" or "median"). If None, the default value is used.
        Defaults values are defined in rcParams["stats.point_estimate"]. Ignored if
        summary is False.
    ci_kind: str
        The kind of credible interval to compute. If None, the default value is used.
        Defaults values are defined in rcParams["stats.ci_kind"]. Ignored if
        summary is False.
    ci_prob: float
        The probability for the credible interval. If None, the default value is used.
        Defaults values are defined in rcParams["stats.ci_prob"]. Ignored if
        summary is False.
    round_to: int or str, optional
        If integer, number of decimal places to round the result. If string of the
        form '2g' number of significant digits to round the result. Defaults to '2g'.
        Use None to return raw numbers.

    Returns
    -------
    namedtuple or array
        If `summary` is True, a namedtuple named ``R2`` with fields
        ``<point_estimate>``, ``<ci_kind>_lb`` and ``<ci_kind>_ub`` holding Python
        floats. Otherwise, the R² samples returned by ``array_stats.r2_score``.

    Examples
    --------
    Calculate R² samples for Bayesian regression models :

    .. ipython::

        In [1]: from arviz_stats import r2_score
           ...: from arviz_base import load_arviz_data
           ...: data = load_arviz_data('regression1d')
           ...: r2_score(data)

    References
    ----------

    .. [1] Gelman et al. *R-squared for Bayesian regression models*.
        The American Statistician. 73(3) (2019). https://doi.org/10.1080/00031305.2018.1549100
        preprint http://www.stat.columbia.edu/~gelman/research/published/bayes_R2_v3.pdf.
    """
    if point_estimate is None:
        point_estimate = rcParams["stats.point_estimate"]
    if ci_kind is None:
        ci_kind = rcParams["stats.ci_kind"]
    if ci_prob is None:
        ci_prob = rcParams["stats.ci_prob"]

    if data_pair is None:
        obs_var_name = var_name
        pred_var_name = var_name
    else:
        # only the first (observed -> predicted) pair is used
        obs_var_name = list(data_pair.keys())[0]
        pred_var_name = list(data_pair.values())[0]

    y_true = extract(data, group="observed_data", var_names=obs_var_name, combined=False).values
    # NOTE(review): the transpose presumably moves the combined sample dimension to the
    # first axis, as array_stats.r2_score expects — confirm against array_stats.
    y_pred = extract(data, group="posterior_predictive", var_names=pred_var_name).values.T

    r_squared = array_stats.r2_score(y_true, y_pred)

    if not summary:
        return r_squared

    # NOTE(review): point_estimate is resolved on NumPy, so values without a NumPy
    # function of the same name (e.g. "mode") raise AttributeError here.
    estimate = getattr(np, point_estimate)(r_squared).item()
    c_i = getattr(array_stats, ci_kind)(r_squared, ci_prob)
    # Convert the bounds to Python scalars so the output type does not depend on round_to.
    lower, upper = c_i[0].item(), c_i[1].item()

    r2_summary = namedtuple("R2", [point_estimate, f"{ci_kind}_lb", f"{ci_kind}_ub"])
    if (round_to is not None) and (round_to not in ("None", "none")):
        estimate = round_num(estimate, round_to)
        lower = round_num(lower, round_to)
        upper = round_num(upper, round_to)
    return r2_summary(estimate, lower, upper)
[docs]
def metrics(data, kind="rmse", var_name=None, sample_dims=None, round_to="2g"):
    """
    Compute performance metrics.

    Currently supported metrics are mean absolute error, mean squared error and
    root mean squared error.
    For classification problems, accuracy and balanced accuracy are also supported.

    Parameters
    ----------
    data: DataTree or InferenceData
        It should contain groups `observed_data` and `posterior_predictive`.
    kind: str
        The kind of metric to compute. Available options are:

        - 'mae': mean absolute error.
        - 'mse': mean squared error.
        - 'rmse': root mean squared error. Default.
        - 'acc': classification accuracy.
        - 'acc_balanced': balanced classification accuracy.

    var_name: str, optional
        The name of the observed and predicted variable. Defaults to the first
        variable in the `observed_data` group.
    sample_dims: iterable of hashable, optional
        Dimensions to be considered sample dimensions and are to be reduced.
        Default ``rcParams["data.sample_dims"]``.
    round_to: int or str, optional
        If integer, number of decimal places to round the result. If string of the
        form '2g' number of significant digits to round the result. Defaults to '2g'.
        Use None to return raw numbers.

    Returns
    -------
    estimate: namedtuple
        A namedtuple with the mean of the computed metric and its standard error.

    Raises
    ------
    ValueError
        If `kind` is not one of the supported metrics.

    Examples
    --------
    Calculate root mean squared error

    .. ipython::

        In [1]: from arviz_stats import metrics
           ...: from arviz_base import load_arviz_data
           ...: dt = load_arviz_data("radon")
           ...: metrics(dt, kind="rmse")

    Calculate accuracy of a logistic regression model

    .. ipython::

        In [1]: dt = load_arviz_data("anes")
           ...: metrics(dt, kind="acc")

    Notes
    -----
    The computation of the metrics is done by first reducing the posterior predictive
    samples, this is done to mirror the computation of the metrics by the
    :func:`arviz_stats.loo_metrics` function, and hence make comparison easier to perform.
    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if var_name is None:
        var_name = list(data.observed_data.data_vars.keys())[0]

    observed = data.observed_data[var_name]
    # Average over the sample dimensions first (see Notes) so the metric compares the
    # observations against a single point prediction.
    predicted = data.posterior_predictive[var_name].mean(dim=sample_dims)
    return _metrics(observed, predicted, kind, round_to)
[docs]
def kl_divergence(
    data1,
    data2,
    group="posterior",
    var_names=None,
    sample_dims=None,
    num_samples=500,
    round_to="2g",
    random_seed=212480,
):
    """Compute the Kullback-Leibler (KL) divergence.

    The KL-divergence is a measure of how different two probability distributions are.
    It represents how much extra uncertainty is introduced when one distribution is
    used to approximate another. The KL-divergence is not symmetric, thus changing
    the order of the `data1` and `data2` arguments will change the result.

    For details of the approximation used to compute the KL-divergence see [1]_.

    Parameters
    ----------
    data1, data2 : DataArray, Dataset, DataTree, or InferenceData
        The two sets of samples to compare. `data1` plays the role of the first
        argument of the divergence and `data2` the role of the approximating
        distribution.
    group : hashable, default "posterior"
        Group on which to compute the kl-divergence.
    var_names : str or list of str, optional
        Names of the variables for which the KL-divergence should be computed.
        Defaults to the variables shared by both inputs.
    sample_dims : iterable of hashable, optional
        Dimensions to be considered sample dimensions and are to be reduced.
        Default ``rcParams["data.sample_dims"]``.
    num_samples : int
        Number of samples to use for the distance calculation. Default is 500.
    round_to: int or str, optional
        If integer, number of decimal places to round the result. If string of the
        form '2g' number of significant digits to round the result. Defaults to '2g'.
        Use None to return raw numbers.
    random_seed : int
        Random seed for reproducibility. Use None for no seed.

    Returns
    -------
    KL-divergence : float

    Raises
    ------
    ValueError
        If the two inputs share no variable names.

    Examples
    --------
    Calculate the KL-divergence between the posterior distributions
    for the variable mu in the centered and non-centered eight schools models

    .. ipython::

        In [1]: from arviz_stats import kl_divergence
           ...: from arviz_base import load_arviz_data
           ...: data1 = load_arviz_data('centered_eight')
           ...: data2 = load_arviz_data('non_centered_eight')
           ...: kl_divergence(data1, data2, var_names="mu")

    References
    ----------

    .. [1] F. Perez-Cruz, *Kullback-Leibler divergence estimation of continuous distributions*
        IEEE International Symposium on Information Theory. (2008)
        https://doi.org/10.1109/ISIT.2008.4595271.
        preprint https://www.tsc.uc3m.es/~fernando/bare_conf3.pdf
    """
    dist1, dist2 = _prepare_distribution_pair(
        data1,
        data2,
        group=group,
        var_names=var_names,
        sample_dims=sample_dims,
        num_samples=num_samples,
        random_seed=random_seed,
    )

    kl_d = _kld(dist1, dist2)

    if round_to is not None and round_to not in ("None", "none"):
        kl_d = round_num(kl_d, round_to)

    return kl_d
[docs]
def wasserstein(
    data1,
    data2,
    group="posterior",
    var_names=None,
    sample_dims=None,
    joint=True,
    num_samples=500,
    round_to="2g",
    random_seed=212480,
):
    """Compute the Wasserstein-1 distance.

    The Wasserstein distance, also called the Earth mover’s distance or the optimal transport
    distance, is a similarity metric between two probability distributions [1]_.

    Parameters
    ----------
    data1, data2 : DataArray, Dataset, DataTree, or InferenceData
        The two sets of samples to compare.
    group : hashable, default "posterior"
        Group on which to compute the Wasserstein distance.
    var_names : str or list of str, optional
        Names of the variables for which the Wasserstein distance should be computed.
        Defaults to the variables shared by both inputs.
    sample_dims : iterable of hashable, optional
        Dimensions to be considered sample dimensions and are to be reduced.
        Default ``rcParams["data.sample_dims"]``.
    joint : bool, default True
        Whether to compute Wasserstein distance for the joint distribution (True)
        or over the marginals (False)
    num_samples : int
        Number of samples to use for the distance calculation. Default is 500.
    round_to: int or str, optional
        If integer, number of decimal places to round the result. If string of the
        form '2g' number of significant digits to round the result. Defaults to '2g'.
        Use None to return raw numbers.
    random_seed : int
        Random seed for reproducibility. Use None for no seed.

    Returns
    -------
    wasserstein_distance : float

    Raises
    ------
    ValueError
        If the two inputs share no variable names.

    Notes
    -----
    The computation is faster for the marginals (`joint=False`). In that case the result
    is the sum of the one-dimensional Wasserstein distances of each marginal, which
    ignores any dependence between the variables; such dependence is usually present.

    This function uses the :func:`scipy.stats.wasserstein_distance` for the computation of the
    marginals and :func:`scipy.stats.wasserstein_distance_nd` for the joint distribution.

    Examples
    --------
    Calculate the Wasserstein distance between the posterior distributions
    for the variable mu in the centered and non-centered eight schools models

    .. ipython::

        In [1]: from arviz_stats import wasserstein
           ...: from arviz_base import load_arviz_data
           ...: data1 = load_arviz_data('centered_eight')
           ...: data2 = load_arviz_data('non_centered_eight')
           ...: wasserstein(data1, data2, var_names="mu")

    References
    ----------

    .. [1] "Wasserstein metric",
           https://en.wikipedia.org/wiki/Wasserstein_metric
    """
    dist1, dist2 = _prepare_distribution_pair(
        data1,
        data2,
        group=group,
        var_names=var_names,
        sample_dims=sample_dims,
        num_samples=num_samples,
        random_seed=random_seed,
    )

    if joint:
        distance = wasserstein_distance_nd(dist1, dist2)
    else:
        # Transposing makes iteration yield one column (one marginal) at a time.
        distance = sum(
            wasserstein_distance(marginal1, marginal2)
            for marginal1, marginal2 in zip(dist1.T, dist2.T)
        )
    # Return a plain Python float from both branches, whatever scalar type scipy returns.
    distance = float(distance)

    if round_to is not None and round_to not in ("None", "none"):
        distance = round_num(distance, round_to)

    return distance
def _prepare_distribution_pair(
    data1, data2, group, var_names, sample_dims, num_samples, random_seed
):
    """Prepare the distribution pair for metric calculations.

    Both inputs are converted to DataTree, subsampled with the same ``num_samples`` and
    ``random_seed``, and restricted to the variables they have in common. The variables
    are then stacked with ``dataset_to_dataarray`` using ``sample`` as the sample dimension.

    Parameters
    ----------
    data1, data2 : DataArray, Dataset, DataTree, or InferenceData
        Inputs holding the samples to compare.
    group : hashable
        Group from which the samples are extracted.
    var_names : str or list of str or None
        Variables to extract. If None, all variables shared by both inputs are used.
    sample_dims : iterable of hashable or None
        Sample dimensions. Defaults to ``rcParams["data.sample_dims"]``.
    num_samples : int
        Number of samples to extract from each input.
    random_seed : int or None
        Seed passed to the extraction of both inputs.

    Returns
    -------
    dist1, dist2 : DataArray
        Stacked samples of ``data1`` and ``data2`` with the variables in the same order.

    Raises
    ------
    ValueError
        If the two inputs share no variable names.
    """
    data1 = convert_to_datatree(data1)
    data2 = convert_to_datatree(data2)
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]

    extract_kwargs = {
        "group": group,
        "var_names": var_names,
        "sample_dims": sample_dims,
        "num_samples": num_samples,
        "random_seed": random_seed,
    }
    dist1 = _extract_and_reindex(data1, **extract_kwargs)
    dist2 = _extract_and_reindex(data2, **extract_kwargs)

    # Follow dist1's variable order instead of iterating a set: set order for string
    # names changes between interpreter runs (hash randomization), which would make
    # the column order of the stacked arrays non-reproducible.
    shared_var_names = [name for name in dist1.data_vars if name in dist2.data_vars]
    if not shared_var_names:
        raise ValueError(
            "No shared variable names found between the two datasets. "
            "Ensure that both datasets contain variables with matching names."
        )

    # Subset both to the same ordered list so the stacked columns line up.
    dist1, dist2 = dist1[shared_var_names], dist2[shared_var_names]

    dist1 = dataset_to_dataarray(dist1, sample_dims=["sample"])
    dist2 = dataset_to_dataarray(dist2, sample_dims=["sample"])

    return dist1, dist2
def _extract_and_reindex(data, group, var_names, sample_dims, num_samples, random_seed):
    """Extract samples as a Dataset indexed by a plain integer ``sample`` coordinate.

    ``extract`` combines ``sample_dims`` into a ``sample`` dimension. The ``sample``
    index is then reset, the ``sample_dims`` coordinates are dropped, and ``sample``
    is re-labelled ``0..S-1``. Inputs processed this way therefore share the same
    ``sample`` coordinate.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data.
    group : hashable
        Group to extract from.
    var_names : str or list of str or None
        Variables to extract.
    sample_dims : iterable of hashable
        Dimensions combined into ``sample``.
    num_samples : int or None
        Number of samples to draw, forwarded to ``extract``.
    random_seed : int or None
        Seed forwarded to ``extract``.

    Returns
    -------
    Dataset
    """
    dataset = (
        extract(
            data,
            group=group,
            sample_dims=sample_dims,
            var_names=var_names,
            num_samples=num_samples,
            random_seed=random_seed,
            keep_dataset=True,
        )
        .reset_index("sample")
        .drop_vars(sample_dims)
    )
    # Take the length from the data rather than from num_samples so that
    # num_samples=None (keep all samples) does not fail on range(None).
    return dataset.assign_coords(sample=range(dataset.sizes["sample"]))
def _kld(ary0, ary1):
    """Kullback-Leibler divergence approximation.

    Estimate ``KL(P || Q)`` with the 1-nearest-neighbour estimator of equation 14
    in [1]_, where ``ary0`` holds samples of ``P`` and ``ary1`` holds samples of ``Q``.
    The two inputs may have different numbers of samples.

    Parameters
    ----------
    ary0 : (N, M) or (N,) array-like
        Samples of the first distribution. ``N`` represents the number of samples (e.g.
        posterior samples) and ``M`` the number of outputs (e.g. number of variables in
        the posterior). One-dimensional input is treated as ``M = 1``.
    ary1 : (K, M) or (K,) array-like
        Samples of the second distribution. ``K`` may differ from ``N``, but the number
        of outputs ``M`` must match.

    Returns
    -------
    float
        The estimated Kullback-Leibler divergence between the two distributions,
        clipped at zero.

    Raises
    ------
    ValueError
        If the number of outputs differs, if ``ary0`` has fewer than two samples
        (a second nearest neighbour within ``ary0`` is required), or if ``ary1``
        is empty.

    References
    ----------

    .. [1] F. Perez-Cruz, *Kullback-Leibler divergence estimation of continuous distributions*
        IEEE International Symposium on Information Theory. (2008)
        https://doi.org/10.1109/ISIT.2008.4595271.
        preprint https://www.tsc.uc3m.es/~fernando/bare_conf3.pdf
    """
    ary0 = np.asarray(ary0, dtype=float)
    ary1 = np.asarray(ary1, dtype=float)
    if ary0.ndim == 1:
        ary0 = ary0[:, np.newaxis]
    if ary1.ndim == 1:
        ary1 = ary1[:, np.newaxis]

    n_samples0, dim = ary0.shape
    n_samples1 = ary1.shape[0]
    if ary1.shape[1] != dim:
        raise ValueError(
            f"Both inputs must have the same number of outputs, got {dim} and {ary1.shape[1]}."
        )
    if n_samples0 < 2:
        raise ValueError("The first input must contain at least two samples.")
    if n_samples1 < 1:
        raise ValueError("The second input must contain at least one sample.")

    # For discrete data we need to smooth the samples to avoid numerical errors
    # (zero nearest-neighbour distances). Adding a tiny noise relative to each column's
    # standard deviation should have a negligible effect, but we may want to do
    # something more sophisticated in the future.
    rng = np.random.default_rng(0)
    ary0 = ary0 + rng.normal(0, ary0.std(axis=0) / 1e6, size=ary0.shape)
    ary1 = ary1 + rng.normal(0, ary1.std(axis=0) / 1e6, size=ary1.shape)

    kd_tree_ary0 = cKDTree(ary0)
    kd_tree_ary1 = cKDTree(ary1)

    # first nearest neighbour distances of X to Y
    r_k, _ = kd_tree_ary1.query(ary0)
    # second nearest neighbour distances of X to X
    # (the first one is the trivial zero distance of each point to itself)
    s_k = kd_tree_ary0.query(ary0, k=2)[0][:, 1]

    # Eq. 14: the bias-correction term is log(K / (N - 1)), where K is the size of
    # the second sample and N the size of the first.
    kl_div = (dim / n_samples0) * np.sum(np.log(r_k / s_k)) + np.log(
        n_samples1 / (n_samples0 - 1)
    )
    # Due to numerical errors and for very similar samples we can get negative values
    return max(0.0, float(kl_div))
def _metrics(observed, predicted, kind, round_to):
    """Compute a performance metric together with its standard error.

    Parameters
    ----------
    observed: DataArray
        Observed data.
    predicted: DataArray
        Predicted data.
    kind: str
        Name of the metric. Available options are:

        - 'mae': mean absolute error.
        - 'mse': mean squared error.
        - 'rmse': root mean squared error. Default.
        - 'acc': classification accuracy.
        - 'acc_balanced': balanced classification accuracy.

    round_to: int or str, optional
        If integer, number of decimal places to round the result. If string of the
        form '2g' number of significant digits to round the result. Defaults to '2g'.
        Use None to return raw numbers.

    Returns
    -------
    estimate: namedtuple
        A namedtuple whose type is named after `kind`, with fields ``mean`` (the
        computed metric) and ``se`` (its standard error).

    Raises
    ------
    ValueError
        If `kind` is not one of the supported metrics.
    """
    supported_kinds = ["mae", "rmse", "mse", "acc", "acc_balanced"]
    if kind in supported_kinds:
        value, std_err = array_stats.metrics(observed, predicted, kind=kind)
        result_cls = namedtuple(kind, ["mean", "se"])
        rounded_value = round_num(value, round_to)
        rounded_std_err = round_num(std_err, round_to)
        return result_cls(rounded_value, rounded_std_err)
    raise ValueError(f"kind must be one of {supported_kinds}")