

.. _sphx_glr_auto_examples_ensemble_plot_comparison_bagging_classifier.py:


=========================================================
Comparison of balanced and imbalanced bagging classifiers
=========================================================

This example shows the benefit of balancing the training set when using a
bagging classifier. ``BalancedBaggingClassifier`` chains a
``RandomUnderSampler`` and a given classifier, while ``BaggingClassifier``
trains directly on the imbalanced data.

Balancing the data set before training the classifier improves the
classification performance: in the run reported below, the average recall rises
from 0.38 to 0.93. In addition, it keeps the ensemble from focusing on the
majority class, which is a known drawback of decision tree classifiers.
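
To make the balancing step concrete, here is a minimal sketch of random
under-sampling written with the standard library only. It is not the
imbalanced-learn implementation: the helper ``undersample_indices`` is a
hypothetical name introduced for illustration, and the sketch assumes that
every class is reduced to the size of the smallest one.

.. code-block:: python

    import random
    from collections import Counter, defaultdict


    def undersample_indices(y, rng):
        """Return sorted indices keeping the same number of samples per class.

        Hypothetical helper for illustration only; each class is randomly
        reduced to the size of the smallest class.
        """
        by_class = defaultdict(list)
        for index, label in enumerate(y):
            by_class[label].append(index)
        n_min = min(len(indices) for indices in by_class.values())
        kept = []
        for label in sorted(by_class):
            # Draw n_min indices without replacement from this class.
            kept.extend(rng.sample(by_class[label], n_min))
        return sorted(kept)


    # Same class counts as the training set reported in this example.
    y_train = [0] * 17 + [1] * 33 + [2] * 36
    rng = random.Random(0)
    kept = undersample_indices(y_train, rng)
    balanced_counts = Counter(y_train[i] for i in kept)
    print(balanced_counts)

With these counts, every class ends up with 17 samples, so a classifier fitted
on the selected indices sees each class equally often.
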





.. rst-class:: sphx-glr-horizontal


    *

      .. image:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_bagging_classifier_001.png
            :scale: 47

    *

      .. image:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_bagging_classifier_002.png
            :scale: 47


.. rst-class:: sphx-glr-script-out

 Out::

    Class distribution of the training set: Counter({2: 36, 1: 33, 0: 17})
    Class distribution of the test set: Counter({2: 14, 0: 8, 1: 7})
    Classification results using a bagging classifier on imbalanced data
                       pre       rec       spe        f1       geo       iba       sup

              0       0.00      0.00      1.00      0.00      0.00      0.00         8
              1       0.28      1.00      0.18      0.44      0.53      0.26         7
              2       1.00      0.29      1.00      0.44      0.77      0.62        14

    avg / total       0.55      0.38      0.80      0.32      0.50      0.36        29

    Confusion matrix, without normalization
    [[ 0  8  0]
     [ 0  7  0]
     [ 0 10  4]]
    Classification results using a bagging classifier on balanced data
                       pre       rec       spe        f1       geo       iba       sup

              0       1.00      1.00      1.00      1.00      1.00      1.00         8
              1       0.86      0.86      0.95      0.86      0.90      0.81         7
              2       0.93      0.93      0.93      0.93      0.93      0.87        14

    avg / total       0.93      0.93      0.96      0.93      0.94      0.89        29

    Confusion matrix, without normalization
    [[ 8  0  0]
     [ 0  6  1]
     [ 0  1 13]]
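
The ``rec`` and ``spe`` columns can be recomputed by hand from a confusion
matrix. The sketch below does this for the first matrix above (the
imbalanced case), assuming the usual one-vs-rest definitions of recall,
TP / (TP + FN), and specificity, TN / (TN + FP). It is an illustration with
NumPy, not the code used by ``classification_report_imbalanced``.

.. code-block:: python

    import numpy as np

    # Confusion matrix of the plain BaggingClassifier, copied from the output
    # above (rows are true labels, columns are predicted labels).
    cm = np.array([[0, 8, 0],
                   [0, 7, 0],
                   [0, 10, 4]])

    tp = np.diag(cm)              # correctly predicted samples per class
    fn = cm.sum(axis=1) - tp      # samples of the class predicted as another
    fp = cm.sum(axis=0) - tp      # samples of other classes predicted as it
    tn = cm.sum() - tp - fn - fp  # everything else

    recall = tp / (tp + fn)
    specificity = tn / (tn + fp)

    print(np.round(recall, 2))
    print(np.round(specificity, 2))

Rounded to two decimals, the recall is 0.00, 1.00 and 0.29 and the specificity
is 1.00, 0.18 and 1.00, which matches the ``rec`` and ``spe`` columns of the
first report. Class 1 reaches perfect recall only because nearly every sample
is predicted as class 1, which is why its specificity is so low.
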




|


.. code-block:: python


    # Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
    # License: MIT

    from collections import Counter
    import itertools

    import matplotlib.pyplot as plt
    import numpy as np

    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import BaggingClassifier
    from sklearn.metrics import confusion_matrix

    from imblearn.datasets import make_imbalance
    from imblearn.ensemble import BalancedBaggingClassifier

    from imblearn.metrics import classification_report_imbalanced


    def plot_confusion_matrix(cm, classes,
                              normalize=False,
                              title='Confusion matrix',
                              cmap=plt.cm.Blues):
        """
        This function prints and plots the confusion matrix.
        Normalization can be applied by setting `normalize=True`.
        """
        if normalize:
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            print("Normalized confusion matrix")
        else:
            print('Confusion matrix, without normalization')

        print(cm)

        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

        fmt = '.2f' if normalize else 'd'
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        # Adjust the layout once the axis labels exist so they are accounted for.
        plt.tight_layout()


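    # Build an imbalanced three-class problem from the iris data set.
    # ``ratio`` gives the number of samples to keep for each class; more recent
    # imbalanced-learn releases name this parameter ``sampling_strategy``.
    iris = load_iris()
    X, y = make_imbalance(iris.data, iris.target, ratio={0: 25, 1: 40, 2: 50},
                          random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # Both ensembles use their default base classifier and the same seed, so
    # the only difference is the under-sampling done by the balanced version.
    bagging = BaggingClassifier(random_state=0)
    balanced_bagging = BalancedBaggingClassifier(random_state=0)

    print('Class distribution of the training set: {}'.format(Counter(y_train)))

    bagging.fit(X_train, y_train)
    balanced_bagging.fit(X_train, y_train)

    print('Class distribution of the test set: {}'.format(Counter(y_test)))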

    print('Classification results using a bagging classifier on imbalanced data')
    y_pred_bagging = bagging.predict(X_test)
    print(classification_report_imbalanced(y_test, y_pred_bagging))
    cm_bagging = confusion_matrix(y_test, y_pred_bagging)
    plt.figure()
    plot_confusion_matrix(cm_bagging, classes=iris.target_names,
                          title='Confusion matrix using BaggingClassifier')

    print('Classification results using a bagging classifier on balanced data')
    y_pred_balanced_bagging = balanced_bagging.predict(X_test)
    print(classification_report_imbalanced(y_test, y_pred_balanced_bagging))
    cm_balanced_bagging = confusion_matrix(y_test, y_pred_balanced_bagging)
    plt.figure()
    plot_confusion_matrix(cm_balanced_bagging, classes=iris.target_names,
                          title='Confusion matrix using BalancedBaggingClassifier')

    plt.show()

**Total running time of the script:** ( 0 minutes  0.305 seconds)



.. container:: sphx-glr-footer


  .. container:: sphx-glr-download

     :download:`Download Python source code: plot_comparison_bagging_classifier.py <plot_comparison_bagging_classifier.py>`



  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: plot_comparison_bagging_classifier.ipynb <plot_comparison_bagging_classifier.ipynb>`

.. rst-class:: sphx-glr-signature

    `Generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
