web-dev-qa-db-ja.com

適合率と再現率のKerasカスタム決定しきい値

KerasTensorflowバックエンドを使用)を使用してバイナリ分類を行っており、約76%の精度と70%のリコールがあります。今、私は決定のしきい値で遊んでみたいと思います。私の知る限り、Kerasは決定しきい値0.5を使用します。 Kerasに、決定の精度と再現率にカスタムしきい値を使用する方法はありますか?

お時間をいただきありがとうございます!

11
nabroyan

次のようなカスタムメトリックを作成します。

@ Marcinのおかげで編集:引数としてthreshold_valueを使用して目的のメトリックを返す関数を作成します

def precision_threshold(threshold=0.5):
    def precision(y_true, y_pred):
        """Precision metric.
        Computes the precision over the whole batch using threshold_value.
        """
        threshold_value = threshold
        # Adaptation of the "round()" used before to get the predictions. Clipping to make sure that the predicted raw values are between 0 and 1.
        y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())
        # Compute the number of true positives. Rounding in prevention to make sure we have an integer.
        true_positives = K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))
        # count the predicted positives
        predicted_positives = K.sum(y_pred)
        # Get the precision ratio
        precision_ratio = true_positives / (predicted_positives + K.epsilon())
        return precision_ratio
    return precision

def recall_threshold(threshold = 0.5):
    def recall(y_true, y_pred):
        """Recall metric.
        Computes the recall over the whole batch using threshold_value.
        """
        threshold_value = threshold
        # Adaptation of the "round()" used before to get the predictions. Clipping to make sure that the predicted raw values are between 0 and 1.
        y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())
        # Compute the number of true positives. Rounding in prevention to make sure we have an integer.
        true_positives = K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))
        # Compute the number of positive targets.
        possible_positives = K.sum(K.clip(y_true, 0, 1))
        recall_ratio = true_positives / (possible_positives + K.epsilon())
        return recall_ratio
    return recall

今、あなたはそれらをで使うことができます

model.compile(..., metrics = [precision_threshold(0.1), precision_threshold(0.2),precision_threshold(0.8), recall_threshold(0.2,...)])

これがお役に立てば幸いです:)

20
Nassim Ben