# Source code for aif360.algorithms.inprocessing.art_classifier

import numpy as np

from aif360.datasets import BinaryLabelDataset
from aif360.algorithms import Transformer


class ARTClassifier(Transformer):
    """Wraps an instance of an :obj:`art.classifiers.Classifier` to extend
    :obj:`~aif360.algorithms.Transformer`.
    """

    def __init__(self, art_classifier):
        """Initialize ARTClassifier.

        Args:
            art_classifier (art.classifiers.Classifier): A Classifier
                object from the `adversarial-robustness-toolbox`_.

        .. _adversarial-robustness-toolbox:
           https://github.com/IBM/adversarial-robustness-toolbox
        """
        super(ARTClassifier, self).__init__(art_classifier=art_classifier)
        self._art_classifier = art_classifier

    def fit(self, dataset, batch_size=128, nb_epochs=20):
        """Train a classifier on the input.

        Args:
            dataset (Dataset): Training dataset.
            batch_size (int): Size of batches (passed through to ART).
            nb_epochs (int): Number of epochs to use for training (passed
                through to ART).

        Returns:
            ARTClassifier: Returns self.
        """
        # NOTE(review): ``dataset.labels`` is forwarded unchanged. ART
        # classifiers typically expect one-hot encoded targets, so confirm
        # that the wrapped classifier accepts this label format.
        self._art_classifier.fit(dataset.features, dataset.labels,
                                 batch_size=batch_size, nb_epochs=nb_epochs)
        return self

    def predict(self, dataset, logits=False):
        """Perform prediction for the input.

        Args:
            dataset (Dataset): Test dataset.
            logits (bool, optional): True if prediction should be done at the
                logits layer (passed through to ART).

        Returns:
            Dataset: Dataset with predicted labels in the `labels` field.
        """
        # Only the features are passed. ART's ``predict`` takes no labels, so
        # a positional labels argument would bind to its next parameter
        # (``logits`` / ``batch_size``) instead of being ignored.
        pred_labels = self._art_classifier.predict(dataset.features,
                                                   logits=logits)

        if isinstance(dataset, BinaryLabelDataset):
            # Reduce per-class scores to the index of the top-scoring class,
            # shaped as a single column.
            # NOTE(review): this yields class indices (0, 1, ...), which match
            # label values only when the labels are encoded as 0/1 -- confirm.
            pred_labels = np.argmax(pred_labels, axis=1).reshape((-1, 1))

        # Copy so the caller's dataset is left untouched.
        pred_dataset = dataset.copy()
        pred_dataset.labels = pred_labels

        return pred_dataset