かえるのプログラミングブログ

プログラミングでつまずいたところとその解決策などを書いていきます。

Confusion Matrix の復習をしてみた。 2018/12/18

かえるるる(@kaeru_nantoka)です。

今回は、私が今朝まで参加していた通称 PLAsTiccコンペ
PLAsTiCC Astronomical Classification | Kaggle)で大変お世話になった、
Confusion Matrix (sklearn.metrics.confusion_matrix — scikit-learn 0.20.1 documentation)

について復習したのでその備忘として書いていこうと思っております。

ソースコードは、私の github に残しているのでよろしければご覧になってください。

https://github.com/osuossu8/myEDAs/blob/master/plasticc/confusion_matrix.ipynb

*****

( i ) Confusion Matrix とはなんぞや

このような図を描画してくれる便利メソッドです。

f:id:kaeru_nantoka_py:20181218225533p:plain
PLAsTiCC コンペでの私の best model の Confusion Matrix

簡単に言うと分類問題で、test label と prediction label とで、何を何に分類しているのかが格子状になっていてひと目でわかる描画メソッドです。

PLAsTiCC コンペは たくさんある星のデータを15種類に分類するのが課題でした。

コンペ中は、
New Confusion matrix | Kaggle

というような discussion スレッドが立ち、上位陣のモデルでは何を正しく分類できており、何を正しく分類するのが難しいのかを知ることができ、大変参考になりました。

( ii ) 動機

今回の執筆の動機は、シンプルに言うとまた使えるように整理しておきたかったからです。
実はこの PLAsTiCC コンペで散々お世話になったこのメソッド、 コンペ期間中はコピペで、中身を理解せずに使用しておりました。
コピペのままじゃあ次につながらぬ。と言うことでおさらいがてらまとめてみようと思った次第です。

( iii ) 本題

はい、今回は定番中の定番、アヤメのデータセットを使って復習してみました。(一度やってみたかった)

アヤメのデータセットは、花弁や茎などの長さといった特徴から 3種類に分類する基本的なデータです。
分類に使用したモデルは XGBoost です。

import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import seaborn as sns 

import xgboost as xgb

from sklearn import datasets
from sklearn.model_selection import train_test_split

使用するライブラリを import して

# Iris データセットを読み込む
iris = datasets.load_iris()
X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1111)

params = {
        # 多値分類問題
        'objective': 'multiclass',
        # クラス数は 3
        'num_class': 3,
}

# 上記のパラメータでモデルを学習する
clf = xgb.XGBClassifier(**params)
    
clf.fit(X_train, y_train)
    
# テストデータを予測する
y_pred = clf.predict(X_test)

# 精度 (Accuracy) を計算する
from sklearn.metrics import accuracy_score
print(accuracy_score(y_pred,y_test))

accuracy は 0.97 とかでした。

アヤメのデータセットは ndarray の形で与えられているので、DataFrame に変換しておきます。

y_test_df = pd.DataFrame(y_test)
print(y_test_df[0].unique())

y_pred_df = pd.DataFrame(y_pred)
print(y_pred_df[0].unique())


で主役の Confusion Matrix です。

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()


# http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html

from sklearn.metrics import confusion_matrix

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test_df, y_pred_df) # 引数に先ほどの df を渡す...(a)
np.set_printoptions(precision=2)

class_names = list(y_pred_df[0].unique()) # 正解ラベルを定義 ...(b)

# Plot non-normalized confusion matrix
plt.figure(figsize=(5,5))
foo = plot_confusion_matrix(cnf_matrix, classes=class_names,normalize=True,
                      title='Confusion matrix') ...(a) (b) を渡す

これで以下のような画像が得られます。

f:id:kaeru_nantoka_py:20181218231632p:plain


1.00 となっている label 0 と label 1 は 100% 正しく分類できており、0.94 となっているところは、label 2
を正しく分類できた確率が 94% , 残りの 0.06 は、正解ラベル 2 のものを間違えて ラベル 1 と予測しちゃったのが 6% あるよ。という見方をします。


( iv ) まとめ

結局ソースコード自体は kernels のものを丸々 パクった 参考にした形になりましたが、
うまく基本データセットで応用できたのできちんと身についたかなと思います。

ありがとうございました。