web-dev-qa-db-ja.com

Tensorflow 2の各エポックの後の各クラスの再現率を計算します

Tensorflow 2のKeras APIを使用するモデルの各エポックの後に、各クラスのバイナリクラスとマルチクラス(1つのホットエンコード)の両方の分類シナリオで再現率を計算しようとしています。例えばバイナリ分類では、次のようなことができるようになりたいです

_import tensorflow as tf
model = tf.keras.Sequential()
model.add(...)
model.add(tf.keras.layers.Dense(1))

model.compile(metrics=[binary_recall(label=0), binary_recall(label=1)], ...)
history = model.fit(...)

plt.plot(history.history['binary_recall_0'])
plt.plot(history.history['binary_recall_1'])
plt.show()
_

またはマルチクラスのシナリオで私は何かをしたいのですが

_model = tf.keras.Sequential()
model.add(...)
model.add(tf.keras.layers.Dense(3))

model.compile(metrics=[recall(label=0), recall(label=1), recall(label=2)], ...)
history = model.fit(...)

plt.plot(history.history['recall_0'])
plt.plot(history.history['recall_1'])
plt.plot(history.history['recall_2'])
plt.show()
_

私は不均衡なデータセットの分類子に取り組んでおり、マイノリティクラスのリコールがどの時点で低下し始めたかを確認できるようにしたいと考えています。

私はここでマルチクラス分類子の特定のクラスの精度の実装を見つけました https://stackoverflow.com/a/41717938/373655 。私はこれを私が必要とするものに適応させようとしていますが、_keras.backend_はまだ私にはかなり異質なので、どんな助けも大いに感謝されます。

また、Keras metricsを使用できるかどうか(各バッチの最後に計算され、平均化されるため)またはKeras callbacksを使用する必要があるかどうかもわかりません(これは各エポックの終わり)。リコールには違いがないように思えます(例:8/10 == (3/5 + 5/5) / 2)が、Keras 2でリコールが削除されたので、何かが欠けている可能性があります( https:// github.com/keras-team/keras/issues/5794

編集-部分解(マルチクラス分類)@mujjigaのソリューションは、バイナリ分類とマルチクラス分類の両方で機能しますが、@ P-Gnが指摘したように、 tensorflow 2の Recall metric は、マルチクラス分類のためにこれをそのままサポートします。例えば.

_from tensorflow.keras.metrics import Recall

model = ...

model.compile(loss='categorical_crossentropy', metrics=[
    Recall(class_id=0, name='recall_0')
    Recall(class_id=1, name='recall_1')
    Recall(class_id=2, name='recall_2')
])

history = model.fit(...)

plt.plot(history.history['recall_2'])
plt.plot(history.history['val_recall_2'])
plt.show()
_
6
rob

TF2では、tf.keras.metrics.Recallが獲得しましたclass_idまさにそれを可能にするメンバー。 FashionMNISTの使用例:

import tensorflow as tf

(x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train[..., None].astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train)

input_shape = x_train.shape[1:]
model = tf.keras.Sequential([
  tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=input_shape),
  tf.keras.layers.MaxPool2D(pool_size=2),
  tf.keras.layers.Dropout(0.3),

  tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
  tf.keras.layers.MaxPool2D(pool_size=2),
  tf.keras.layers.Dropout(0.3),

  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(units=256, activation='relu'),
  tf.keras.layers.Dropout(0.5),
  tf.keras.layers.Dense(units=10, activation='softmax')])

model.compile(loss='categorical_crossentropy', optimizer='Adam',
  metrics=[tf.keras.metrics.Recall(class_id=i) for i in range(10)])
model.fit(x_train, y_train, batch_size=128, epochs=50)

TF 1.13では、tf.keras.metric.Recallにはこれがありませんclass_id引数ですが、サブクラス化によって追加できます(TF2のアルファリリースでは、いくぶん驚いたことに、不可能と思われるもの)。

class Recall(tf.keras.metrics.Recall):

  def __init__(self, *, class_id, **kwargs):
    super().__init__(**kwargs)
    self.class_id= class_id

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = y_true[:, self.class_id]
    y_pred = tf.cast(tf.equal(
      tf.math.argmax(y_pred, axis=-1), self.class_id), dtype=tf.float32)
    return super().update_state(y_true, y_pred, sample_weight)
2
P-Gn