web-dev-qa-db-ja.com

Kerasでクラスごとの精度、再現率、F1スコアを取得する

Keras(2.1.5)のTensorFlowバックエンドを使用してニューラルネットワークをトレーニングし、ネットワークの出力としてCRFレイヤーを追加するためにkeras-contrib(2.0.8)ライブラリも使用しました。

NNを使​​用してテストセットで予測を行った後、各クラスの精度、再現率、およびf1スコアを取得する方法を知りたいです。

6
Haritz

トレーニングしたまったく同じモデルを構築する関数get_model()と、モデルの重みを含むHDF5ファイルを指すパスweights_pathがあるとします。

model = get_model()
model.load_weights(weights_path)

これにより、モデルが適切に読み込まれます。次に、テストデータのImageDataGeneratorを定義し、モデルを適合させて予測を取得するだけです。

# Path to your folder testing data
testing_folder = ""
# Image size (set up the image size used for training)
img_size = 256
# Batch size (you should tune it based on your memory)
batch_size = 16

val_datagen = ImageDataGenerator(
    rescale=1. / 255)
validation_generator = val_datagen.flow_from_directory(
    testing_folder,
    target_size=(img_size, img_size),
    batch_size=batch_size,
    shuffle=False,
    class_mode='categorical')

次に、model.predict_generator()メソッドを使用して、モデルにデータセット全体のすべての予測を生成させることができます。

# Number of steps corresponding to an Epoch
steps = 100
predictions = model.predict_generator(validation_generator, steps=steps)

最後に、sklearnパッケージのmetrics.confusion_matrix()メソッドを使用して、混乱行列を作成します。

val_preds = np.argmax(predictions, axis=-1)
val_trues = validation_generator.classes
cm = metrics.confusion_matrix(val_trues, val_preds)

または、sklearnからmetrics.precision_recall_fscore_support()メソッドを使用して、すべてのクラスのすべての精度、再現率、およびf1スコアを取得します(引数average=Noneは、すべてのクラスのメトリックを出力します)。

# label names
labels = validation_generator.class_indices.keys()
precisions, recall, f1_score, _ = metrics.precision_recall_fscore_support(val_trues, val_preds, labels=labels)

私はそれをテストしていませんが、これはあなたを助けると思います。

4
Ferran Parés

見て - sklearn.metrics.classification_report

from sklearn.metrics import classification_report

y_pred = model.predict(x_test)
print(classification_report(y_true, y_pred))

のようなものを与えます

             precision    recall  f1-score   support

    class 0       0.50      1.00      0.67         1
    class 1       0.00      0.00      0.00         1
    class 2       1.00      0.67      0.80         3

avg / total       0.70      0.60      0.61         5
1
Martin Thoma

私の 質問 を参照してください。はいの場合、それはあなたの質問への答えかもしれません。

0
yannis