"""Power-scaling sensitivity diagnostics."""
import logging
import numpy as np
import xarray as xr
from arviz_base import convert_to_datatree, dataset_to_dataframe, extract
from arviz_base.labels import BaseLabeller
from arviz_stats.utils import get_log_likelihood_dataset, get_log_prior
from arviz_stats.validate import validate_dims
# Module-level logger named after this module; ``psense_summary`` uses it to
# emit a warning when a potential sensitivity issue is detected.
_log = logging.getLogger(__name__)
# Module-level ``BaseLabeller`` instance.
# NOTE(review): not referenced by any function visible in this part of the
# file — confirm whether it is still needed.
labeller = BaseLabeller()
# Names exported by ``from <module> import *``; ``power_scale_dataset`` is
# defined in this module but is not listed here.
__all__ = ["psense", "psense_summary"]
[docs]
def psense(
    data,
    var_names=None,
    filter_vars=None,
    group="prior",
    coords=None,
    sample_dims=None,
    alphas=(0.99, 1.01),
    group_var_names=None,
    group_coords=None,
):
    """
    Compute power-scaling sensitivity values for posterior variables.

    The posterior variables are extracted from `data`, importance weights for
    the lower and upper power-scaling of the chosen component (prior or
    likelihood) are computed, and both are handed to the
    ``azstats.power_scale_sense`` accessor.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the posterior and the log_likelihood and/or
        log_prior groups.
    var_names : list of str, optional
        Names of posterior variables to include in the power scaling sensitivity
        diagnostic.
    filter_vars : {None, "like", "regex"}, default None
        Used for `var_names` only.
        If ``None`` (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real
        variables names.
    group : {"prior", "likelihood"}, default "prior"
        Component to power-scale.
        If "likelihood", the pointwise log likelihood values are retrieved
        from the ``log_likelihood`` group and added together.
        If "prior", the log prior values are retrieved from the ``log_prior`` group.
    coords : dict, optional
        Coordinates defining a subset over the posterior. Applied with ``.sel``
        to the extracted posterior variables before computing the sensitivity.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce. Defaults to ``rcParams["data.sample_dims"]``.
    alphas : tuple
        Lower and upper alpha values for gradient calculation.
        Defaults to (0.99, 1.01).
    group_var_names : str, optional
        Name of the log-prior or log-likelihood variables to use.
    group_coords : dict, optional
        Coordinates defining a subset over the `group` element for which to
        compute the sensitivity diagnostic.

    Returns
    -------
    xarray object
        Power-scaling sensitivity values for the selected posterior variables,
        as returned by the ``azstats.power_scale_sense`` accessor.
        Higher sensitivity values indicate greater sensitivity.
        Prior sensitivity above 0.05 indicates informative prior.
        Likelihood sensitivity below 0.05 indicates weak or non-informative
        likelihood.

    Raises
    ------
    ValueError
        If `group` is neither "prior" nor "likelihood".

    Notes
    -----
    The diagnostic is computed by power-scaling either the prior or likelihood
    and determining the degree to which the posterior changes as described in [1]_.
    It uses Pareto-smoothed importance sampling to avoid refitting the model.

    References
    ----------
    .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
       power-scaling*, Stat Comput 34, 57 (2024), https://doi.org/10.1007/s11222-023-10366-5
    """
    dt = convert_to_datatree(data)

    # Posterior variables whose sensitivity is measured; kept as a dataset with
    # the sample dimensions left unstacked (``combined=False``).
    posterior = extract(
        dt,
        group="posterior",
        var_names=var_names,
        filter_vars=filter_vars,
        combined=False,
        keep_dataset=True,
    )
    posterior = posterior if coords is None else posterior.sel(coords)

    # Importance weights for the lower- and upper-alpha power-scaled component.
    weights_lower, weights_upper = _get_power_scale_weights(
        dt,
        group=group,
        alphas=alphas,
        sample_dims=sample_dims,
        group_var_names=group_var_names,
        group_coords=group_coords,
    )

    sense_kwargs = {
        "lower_w": weights_lower,
        "upper_w": weights_upper,
        "lower_alpha": alphas[0],
        "upper_alpha": alphas[1],
        "sample_dims": sample_dims,
    }
    return posterior.azstats.power_scale_sense(**sense_kwargs)
[docs]
def psense_summary(
    data,
    var_names=None,
    filter_vars=None,
    coords=None,
    sample_dims=None,
    threshold=0.05,
    alphas=(0.99, 1.01),
    prior_var_names=None,
    likelihood_var_names=None,
    prior_coords=None,
    likelihood_coords=None,
    round_to=3,
):
    """
    Compute the prior/likelihood sensitivity based on power-scaling perturbations.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the posterior and the log_likelihood and/or
        log_prior groups.
    var_names : list of str, optional
        Names of posterior variables to include in the power scaling sensitivity
        diagnostic.
    filter_vars : {None, "like", "regex"}, default None
        Used for `var_names` only.
        If ``None`` (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real
        variables names.
    coords : dict, optional
        Coordinates defining a subset over the posterior. Only these variables will
        be used when computing the sensitivity.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce. Defaults to ``rcParams["data.sample_dims"]``.
    threshold : float, optional
        Threshold value to determine the sensitivity diagnosis. Default is 0.05.
    alphas : tuple
        Lower and upper alpha values for gradient calculation.
        Defaults to (0.99, 1.01).
    prior_var_names : str, optional
        Name of the log-prior variables to include in the power scaling
        sensitivity diagnostic.
    likelihood_var_names : str, optional
        Name of the log-likelihood variables to include in the power scaling
        sensitivity diagnostic.
    prior_coords : dict, optional
        Coordinates defining a subset over the group element for which to
        compute the log-prior sensitivity diagnostic.
    likelihood_coords : dict, optional
        Coordinates defining a subset over the group element for which to
        compute the log-likelihood sensitivity diagnostic.
    round_to : int, optional
        Number of decimal places to round the sensitivity values. Default is 3.

    Returns
    -------
    psense_df : DataFrame
        One row per posterior variable with the columns ``"prior"`` and
        ``"likelihood"`` (sensitivity values, rounded to `round_to` decimals)
        and a ``"diagnosis"`` column holding one of:

        - ``"potential prior-data conflict"`` if both the prior and the
          likelihood sensitivity are greater than or equal to `threshold`
        - ``"potential strong prior / weak likelihood"`` if the prior
          sensitivity is strictly greater than `threshold` and the likelihood
          sensitivity is strictly lower than `threshold`
        - ``"✓"`` otherwise

        The diagnosis is computed on the unrounded values. A warning with
        pointers to further reading is logged when at least one row is flagged.

    Examples
    --------
    .. ipython::

        In [1]: from arviz_base import load_arviz_data
           ...: from arviz_stats import psense_summary
           ...: rugby = load_arviz_data("rugby")
           ...: psense_summary(rugby, var_names="atts")

    Notes
    -----
    The diagnostic is computed by power-scaling either the prior or likelihood
    and determining the degree to which the posterior changes as described in [1]_.
    It uses Pareto-smoothed importance sampling to avoid refitting the model.

    References
    ----------
    .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
       power-scaling*, Stat Comput 34, 57 (2024), https://doi.org/10.1007/s11222-023-10366-5
    """
    ok_label = "✓"

    # Arguments that are identical for the prior and the likelihood run.
    shared_kwargs = {
        "var_names": var_names,
        "filter_vars": filter_vars,
        "coords": coords,
        "sample_dims": sample_dims,
        "alphas": alphas,
    }
    prior_sense = psense(
        data,
        group="prior",
        group_var_names=prior_var_names,
        group_coords=prior_coords,
        **shared_kwargs,
    )
    likelihood_sense = psense(
        data,
        group="likelihood",
        group_var_names=likelihood_var_names,
        group_coords=likelihood_coords,
        **shared_kwargs,
    )

    # Stack both results along a new "component" dimension, then transpose so
    # that variables become rows and "prior"/"likelihood" become columns.
    joined = xr.concat([prior_sense, likelihood_sense], dim="component").assign_coords(
        component=["prior", "likelihood"]
    )
    psense_df = dataset_to_dataframe(joined, sample_dims=["component"]).T

    def _diagnose(row):
        """Map one row of prior/likelihood sensitivities to a diagnosis label."""
        # NOTE(review): the first test is inclusive (>=) while the second is
        # strict (>), so a prior sensitivity exactly equal to the threshold with
        # a low likelihood sensitivity is reported as fine. NaN sensitivities
        # fail both comparisons and are reported as fine as well.
        if row["prior"] >= threshold and row["likelihood"] >= threshold:
            return "potential prior-data conflict"
        if row["prior"] > threshold > row["likelihood"]:
            return "potential strong prior / weak likelihood"
        return ok_label

    psense_df["diagnosis"] = psense_df.apply(_diagnose, axis=1)

    # Check each label directly rather than searching a concatenation of all of
    # them for a substring.
    if (psense_df["diagnosis"] != ok_label).any():
        _log.warning(
            "We detected potential issues. For more information on how to interpret the results, "
            "please check\n"
            "https://arviz-devs.github.io/EABM/Chapters/"
            "Sensitivity_checks.html#interpreting-sensitivity-diagnostics-summary\n"
            "or read original paper https://doi.org/10.1007/s11222-023-10366-5"
        )

    return psense_df.round(round_to)
def power_scale_dataset(data, group, alphas, sample_dims, group_var_names, group_coords):
    """Resample posterior based on power-scaled weights.

    The posterior is resampled twice, once with the importance weights of the
    lower-alpha and once with those of the upper-alpha power-scaling, and both
    are concatenated with the unweighted posterior along a new ``alpha``
    dimension.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the posterior and the log_likelihood and/or
        log_prior groups.
    group : str
        Group to power-scale. Either "prior" or "likelihood".
    alphas : tuple of float
        Lower and upper alpha values for power scaling.
    sample_dims : str or sequence of hashable
        Sample dimensions of the posterior. A single name, a list or a tuple
        are all accepted.
    group_var_names : str
        Name of the log-prior or log-likelihood variables to use.
    group_coords : dict
        Coordinates defining a subset over the group element for which to
        compute the sensitivity diagnostic.

    Returns
    -------
    xarray.Dataset
        Resampled posterior with an added ``alpha`` dimension whose coordinate
        values are ``[alphas[0], 1, alphas[1]]``; the entry at ``alpha=1`` is
        the posterior extracted without weights.

    Raises
    ------
    ValueError
        If `group` is neither "prior" nor "likelihood".
    """
    dt = convert_to_datatree(data)

    # Normalise ``sample_dims`` once, with the same helper
    # ``_get_power_scale_weights`` applies, and coerce it to a list. The length
    # check and the set arithmetic below would otherwise misbehave for a plain
    # string (``len("draw") == 4``) or fail for a tuple.
    # NOTE(review): relies on ``validate_dims`` resolving ``None`` to the
    # configured default sample dims — confirm.
    sample_dims = validate_dims(sample_dims)
    sample_dims = [sample_dims] if isinstance(sample_dims, str) else list(sample_dims)

    lower_w, upper_w = _get_power_scale_weights(
        dt,
        alphas,
        group=group,
        sample_dims=sample_dims,
        group_var_names=group_var_names,
        group_coords=group_coords,
    )
    # ``extract`` takes one weight per sample, hence the flattening.
    # NOTE(review): assumes the flattened order matches the order in which
    # ``extract`` stacks ``sample_dims`` — confirm.
    lower_w = lower_w.values.flatten()
    upper_w = upper_w.values.flatten()
    s_size = len(lower_w)

    # Indexes tied to the sample dimensions are dropped from every piece
    # before the pieces are concatenated along "alpha". With anything other
    # than exactly one sample dimension the "sample" index is dropped as well.
    idxs_to_drop = set(sample_dims)
    if len(sample_dims) != 1:
        idxs_to_drop.add("sample")
    posterior = dt["posterior"]
    idxs_to_drop.update(
        idx
        for idx in posterior.xindexes
        if any(dim in posterior[idx].dims for dim in sample_dims)
    )

    resampled = [
        extract(
            dt,
            group="posterior",
            sample_dims=sample_dims,
            num_samples=s_size,
            weights=weights,
            random_seed=42,  # fixed seed keeps the resampling reproducible
            resampling_method="stratified",
        ).drop_indexes(idxs_to_drop)
        for weights in (lower_w, upper_w)
    ]
    # The unweighted posterior goes in the middle, matching alpha == 1.
    resampled.insert(
        1, extract(dt, group="posterior", sample_dims=sample_dims).drop_indexes(idxs_to_drop)
    )

    return xr.concat(resampled, dim="alpha").assign_coords(alpha=[alphas[0], 1, alphas[1]])
def _get_power_scale_weights(
    dt, alphas=None, group=None, sample_dims=None, group_var_names=None, group_coords=None
):
    """Compute normalized importance weights for lower and upper power-scaling.

    Parameters
    ----------
    dt : DataTree
        Data containing the ``log_likelihood`` and/or ``log_prior`` groups.
    alphas : sequence of float
        Lower (``alphas[0]``) and upper (``alphas[1]``) power-scaling values.
    group : {"prior", "likelihood"}
        Component whose log density is power-scaled.
    sample_dims : str or sequence of hashable, optional
        Sample dimensions; passed through ``validate_dims`` before use.
    group_var_names : str, optional
        Name of the log-prior or log-likelihood variables to use.
    group_coords : dict, optional
        Coordinates used to subset the selected log-density variables.

    Returns
    -------
    lower_w, upper_w : DataArray
        Weights for the lower and upper alpha, each divided by its sum over
        `sample_dims`.

    Raises
    ------
    ValueError
        If `group` is neither "prior" nor "likelihood".
    """
    sample_dims = validate_dims(sample_dims)

    if group == "likelihood":
        fetch_log_density = get_log_likelihood_dataset
    elif group == "prior":
        fetch_log_density = get_log_prior
    else:
        raise ValueError("Value for `group` argument not recognized")

    log_density = fetch_log_density(dt, var_names=group_var_names)
    if group_coords is not None:
        log_density = log_density.sel(group_coords)

    # Flatten every variable, and every non-sample dimension within each
    # variable, into a single "latent-obs_var" dimension and sum over it. What
    # remains is one total log density per sample, i.e. a DataArray whose only
    # dimensions are the sample dimensions.
    stacked = log_density.to_stacked_array("latent-obs_var", sample_dims=sample_dims)
    total_log_density = stacked.sum("latent-obs_var")

    def _normalized_weights(alpha):
        """Exponentiate the power-scale log weights for `alpha` and normalize them."""
        log_weights = total_log_density.azstats.power_scale_lw(alpha=alpha, dim=sample_dims)
        weights = np.exp(log_weights)
        return weights / weights.sum(sample_dims)

    lower_w = _normalized_weights(alphas[0])
    upper_w = _normalized_weights(alphas[1])
    return lower_w, upper_w