web-dev-qa-db-ja.com

ケラス実験の学習曲線をプロットする方法は?

ケラスを使用してRNNをトレーニングしていますが、データセットのサイズによって検証の精度がどのように変化するかを確認したいと思います。 Kerasの履歴オブジェクトには_val_acc_というリストがあり、エポックごとにそれぞれの検証セットの精度で追加されます( Googleグループの投稿へのリンク )。実行されたエポック数の_val_acc_の平均を取得し、それぞれのデータセットサイズに対してプロットしたいと思います。

質問:_val_acc_リストの要素を取得して、numpy.mean(val_acc)のような操作を実行するにはどうすればよいですか?


EDIT:@ runDOSrunが言ったように、_val_acc_ sの平均を取得することは意味がありません。最終的な_val_acc_の取得に焦点を当てましょう。

@nemoが提案したものを試しましたが、うまくいきませんでした。これが私が印刷したときに得たものです

model.fit(X_train, y_train, batch_size = 512, nb_Epoch = 5, validation_split = 0.05).__dict__

出力:

_{'model': <keras.models.Sequential object at 0x000000001F752A90>, 'params': {'verbose': 1, 'nb_Epoch': 5, 'batch_size': 512, 'metrics': ['loss', 'val_loss'], 'nb_sample': 1710, 'do_validation': True}, 'Epoch': [0, 1, 2, 3, 4], 'history': {'loss': [0.96936064512408959, 0.66933631673890948, 0.63404161288724303, 0.62268789783555867, 0.60833334699708819], 'val_loss': [0.84040999412536621, 0.75676006078720093, 0.73714292049407959, 0.71032363176345825, 0.71341043710708618]}}
_

私の履歴辞書には_val_acc_としてのリストがないことがわかりました。

質問:_val_acc_をhistory辞書に含める方法は?

10
akilat90

精度は目的関数ではなく(一般的な)メトリックであるため、精度値を取得するには、fit中に計算するように要求する必要があります。精度の計算が意味をなさない場合があるため、Kerasではデフォルトで有効になっていません。ただし、これは組み込みのメトリックであり、簡単に追加できます。

メトリックを追加するには、 metrics=['accuracy']パラメーターをmodel.compileに使用します

あなたの例では:

history = model.fit(X_train, y_train, batch_size = 512, 
          nb_Epoch = 5, validation_split = 0.05)

その後、history.history['val_acc']として検証精度にアクセスできます。

4
Neil Slater

averageの精度が最終的な精度よりも重要であると思うのはなぜですか?初期値によっては、平均がかなり誤解を招く可能性があります。平均は同じですが解釈が異なるさまざまな曲線を思いつくのは簡単です。

train_accval_accの完全な履歴をプロットして、RNNが特定のセットアップ内で正常に機能しているかどうかを判断します。また、サンプルサイズをN> 1にすることを忘れないでください。ランダム初期化はRNNに大きな影響を与える可能性があります。セットアップごとに、少なくともN = 10の異なる初期化を行って、異なるパフォーマンスが実際に設定サイズによって引き起こされることを確認してください。より良い/より悪い初期化によるものではありません。

3
runDOSrun

履歴オブジェクトは、モデルのfit() ting中に作成されます。詳細については、keras/engine/training.pyを参照してください。

モデルのhistory属性を使用して履歴にアクセスできます:model.history

モデルをフィッティングした後、属性を平均するだけです。

np.mean([v['val_acc'] for v in model.history])

指定するすべての出力のパターンはval_<your output name here>であることに注意してください。

2
nemo