aif360.sklearn.metrics.one_vs_rest

aif360.sklearn.metrics.one_vs_rest(func, y_true, y_pred=None, prot_attr=None, return_groups=False, **kwargs)

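Compute an arbitrary difference or ratio metric for each intersectional group of the given protected attributes in a one-vs-rest manner. Each group in turn is treated as the privileged group, and all remaining samples form the unprivileged group.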
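The one-vs-rest procedure can be sketched in plain Python. This is a hypothetical re-implementation for illustration, not the aif360 code. `one_vs_rest_sketch` and `statistical_parity_difference_sketch` are made-up names, the protected attributes are passed as plain lists instead of a pandas index, and statistical parity difference is assumed to be P(pos | unprivileged) - P(pos | privileged).

```python
def positive_rate(labels, pos_label=1):
    """Fraction of samples whose label equals pos_label."""
    return sum(1 for y in labels if y == pos_label) / len(labels)


def statistical_parity_difference_sketch(unpriv_labels, priv_labels, pos_label=1):
    """Hypothetical SPD: positive rate of the unprivileged side minus the privileged side."""
    return positive_rate(unpriv_labels, pos_label) - positive_rate(priv_labels, pos_label)


def one_vs_rest_sketch(metric, labels, *attr_columns, **kwargs):
    """Hypothetical one-vs-rest loop over intersectional groups.

    Assumes at least two distinct groups, so neither side of a split is empty.
    """
    # Intersectional group of each sample: the tuple of its protected-attribute values
    groups = list(zip(*attr_columns))
    unique_groups = sorted(set(groups))
    vals = []
    for g in unique_groups:
        # Group g is "privileged"; every other sample is "unprivileged"
        priv = [y for grp, y in zip(groups, labels) if grp == g]
        unpriv = [y for grp, y in zip(groups, labels) if grp != g]
        vals.append(metric(unpriv, priv, **kwargs))
    return vals, unique_groups


sex = [0, 0, 0, 1, 1, 1]
age = [0, 0, 1, 0, 1, 1]
labels = [1, 0, 1, 1, 0, 0]
vals, groups = one_vs_rest_sketch(statistical_parity_difference_sketch, labels, sex, age)
print(dict(zip(groups, vals)))
```

Each group is compared against everyone else, so four intersectional groups produce four metric values.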

Parameters:
  • func (function) – A difference or ratio metric function from aif360.sklearn.metrics.

  • y_true (pandas.Series) – Outcome vector with protected attributes as index.

  • y_pred (array-like, optional) – Estimated outcomes.

  • prot_attr (array-like, keyword-only) – Protected attribute(s). If None, all protected attributes in the index of y_true are used.

  • sample_weight (array-like, optional) – Sample weights, passed through to func via **kwargs (it is not a named parameter of one_vs_rest itself).

  • return_groups (bool, default False) – If True, return the group names in addition to the metric values. Each name is a tuple of protected attribute values.

  • **kwargs – Additional keyword args to be passed through to func.
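A minimal sketch of how extra keyword arguments such as pos_label and sample_weight can reach func unchanged. The names `weighted_spd` and `one_vs_rest_kwargs` are hypothetical, and the convention that func receives the current group through a `priv_group` keyword is an assumption made for illustration, not a description of aif360's internals.

```python
def weighted_spd(labels, groups, priv_group, pos_label=1, sample_weight=None):
    """Hypothetical weighted statistical parity difference:
    weighted P(pos | unprivileged) - weighted P(pos | privileged)."""
    if sample_weight is None:
        sample_weight = [1.0] * len(labels)

    def rate(is_priv):
        # Weighted share of positive labels on one side of the split
        rows = [(y, w) for y, g, w in zip(labels, groups, sample_weight)
                if (g == priv_group) == is_priv]
        total = sum(w for _, w in rows)
        return sum(w for y, w in rows if y == pos_label) / total

    return rate(False) - rate(True)


def one_vs_rest_kwargs(func, labels, groups, **kwargs):
    """Call func once per unique group, forwarding **kwargs untouched."""
    return [func(labels, groups, priv_group=g, **kwargs)
            for g in sorted(set(groups))]


labels = ['good', 'bad', 'good', 'good']
groups = ['a', 'a', 'b', 'b']
vals = one_vs_rest_kwargs(weighted_spd, labels, groups,
                          pos_label='good', sample_weight=[1, 3, 1, 1])
print(vals)  # [0.75, -0.75]
```

The wrapper never inspects pos_label or sample_weight; they only change the result inside the metric, where the heavily weighted 'bad' sample lowers group 'a''s positive rate to 0.25.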

Returns:

  • list – If return_groups is False: the metric values, computed by treating each intersectional group in turn as privileged and all remaining samples as unprivileged.

  • tuple – If return_groups is True: the metric values and their corresponding group names.

      • vals (list) – List of metric values considering each group in turn as privileged and the rest as unprivileged.

      • groups (numpy.ndarray) – Array of tuples containing the unique intersectional groups derived from the provided protected attributes.
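With return_groups=True the two returned sequences are aligned by position, so zipping them gives a per-group lookup. The sketch below reuses the statistical parity difference values from the German credit example on this page, with a plain list standing in for the numpy.ndarray of group tuples. Assuming the metric is computed as unprivileged minus privileged, a positive value means the rest of the population receives the positive label more often than that group.

```python
# Values copied from the German credit example on this page
vals = [0.16493748337323755, 0.0030679552078539674,
        0.09643201542912239, -0.09833664609268755]
groups = [(0, 0), (0, 1), (1, 0), (1, 1)]  # list stand-in for numpy.ndarray

by_group = dict(zip(groups, vals))

# Group whose one-vs-rest disparity has the largest magnitude
largest = max(by_group, key=lambda g: abs(by_group[g]))

# Groups for which the rest of the data has the higher positive rate
rest_favoured = [g for g, v in by_group.items() if v > 0]

print(largest)  # (0, 0)
```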

Examples

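>>> from aif360.sklearn.datasets import fetch_german
>>> from aif360.sklearn.metrics import (one_vs_rest, difference,
...                                     statistical_parity_difference)
>>> X, y = fetch_german()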
>>> v, k = one_vs_rest(statistical_parity_difference, y,
...                    prot_attr=['sex', 'age'], return_groups=True,
...                    pos_label='good')
>>> dict(zip(k, v))
{(0, 0): 0.16493748337323755,
 (0, 1): 0.0030679552078539674,
 (1, 0): 0.09643201542912239,
 (1, 1): -0.09833664609268755}
>>> from functools import partial
>>> from sklearn.metrics import accuracy_score
>>> from sklearn.linear_model import LogisticRegression
>>> y_pred = LogisticRegression(solver='liblinear').fit(X, y).predict(X)
>>> acc_diff = partial(difference, accuracy_score)
>>> one_vs_rest(acc_diff, y, y_pred, prot_attr=['sex', 'age'])
[0.11338121840915127,
 -0.013775118883264326,
 0.018450658952105403,
 -0.04119677790563869]
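The second example turns accuracy_score into a difference metric with functools.partial. A self-contained sketch of that pattern follows. `difference_sketch` and `accuracy` are hypothetical stand-ins for aif360's `difference` and scikit-learn's `accuracy_score`, and the score is assumed to be computed as score(unprivileged) - score(privileged).

```python
from functools import partial


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def difference_sketch(score, y_true, y_pred, groups, priv_group):
    """Hypothetical difference metric: score on the rest minus score on priv_group."""
    def score_on(keep):
        # Evaluate score on the samples selected by the predicate keep(group)
        idx = [i for i, g in enumerate(groups) if keep(g)]
        return score([y_true[i] for i in idx], [y_pred[i] for i in idx])

    return score_on(lambda g: g != priv_group) - score_on(lambda g: g == priv_group)


# Bind the score function first, as partial(difference, accuracy_score) does
acc_diff = partial(difference_sketch, accuracy)

y_true = [1, 0, 1, 1, 0, 1]
y_pred = [1, 0, 1, 1, 1, 0]
groups = ['a', 'a', 'a', 'b', 'b', 'b']
vals = [acc_diff(y_true, y_pred, groups, priv_group=g) for g in sorted(set(groups))]
print(vals)
```

Group 'a' is classified perfectly and group 'b' only one time in three, so treating 'a' as privileged gives a negative accuracy difference and treating 'b' as privileged gives the mirror-image positive value.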