web-dev-qa-db-ja.com

Kerasコールバックはチェックポイントの保存をスキップし続け、val_accがないと主張します

より大きなモデルをいくつか実行して、中間結果を試したいと思います。

したがって、私はチェックポイントを使用して、各エポックの後で最良のモデルを保存しようとします。

これは私のコードです:

model = Sequential()
model.add(LSTM(700, input_shape=(X_modified.shape[1], X_modified.shape[2]), return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(700, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(700))
model.add(Dropout(0.2))
model.add(Dense(Y_modified.shape[1], activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Save the checkpoint in the /output folder
filepath = "output/text-gen-best.hdf5"

# Keep only a single checkpoint, the best over test accuracy.
checkpoint = ModelCheckpoint(filepath,
                            monitor='val_acc',
                            verbose=1,
                            save_best_only=True,
                            mode='max')
model.fit(X_modified, Y_modified, epochs=100, batch_size=50, callbacks=[checkpoint])

しかし、私は最初のエポックの後でまだ警告を受けています:

/usr/local/lib/python3.6/site-packages/keras/callbacks.py:432: RuntimeWarning: Can save best model only with val_acc available, skipping.
  'skipping.' % (self.monitor), RuntimeWarning)

たす metrics=['accuracy']モデルへの他のSO質問(例 事前トレーニングされたVGG16モデルを使用しているときに重みを保存できません )解決策でしたが、ここでエラーはまだ残っています。

7
xentity

次のコードを使用してモデルをチェックポイントしようとしています

# Save the checkpoint in the /output folder
filepath = "output/text-gen-best.hdf5"

# Keep only a single checkpoint, the best over test accuracy.
checkpoint = ModelCheckpoint(filepath,
                            monitor='val_acc',
                            verbose=1,
                            save_best_only=True,
                            mode='max')

ModelCheckpointは引数monitorを考慮して、モデルを保存するかどうかを決定します。コードではval_accです。したがって、val_accに増加がある場合は、重みを保存します。

今あなたのフィットコードで、

model.fit(X_modified, Y_modified, epochs=100, batch_size=50, callbacks=[checkpoint])

検証データが提供されていません。 ModelCheckpointには、チェックするmonitor引数がないため、重みを保存できません。

val_accに基づいてチェックポイントを実行するには、このような検証データを提供する必要があります。

model.fit(X_modified, Y_modified, validation_data=(X_valid, y_valid), epochs=100, batch_size=50, callbacks=[checkpoint])

何らかの理由で検証データを使用したくない場合、チェックポイントを実装するには、ModelCheckpointを変更して、このようにaccまたはlossに基づいて機能するようにする必要があります。

# Save the checkpoint in the /output folder
filepath = "output/text-gen-best.hdf5"

# Keep only a single checkpoint, the best over test accuracy.
checkpoint = ModelCheckpoint(filepath,
                            monitor='acc',
                            verbose=1,
                            save_best_only=True,
                            mode='max')

modeminを使用する場合は、monitorlossに変更する必要があることに注意してください。

13
Sreeram TP

メトリックが欠落しているためではなく、検証データがないために欠落しています。 validation_dataパラメータを介してfitに追加するか、validation_splitを使用します。

1