web-dev-qa-db-ja.com

テンソルフローにおけるLSTMの正則化

TensorflowはNiceLSTMラッパーを提供します。

rnn_cell.BasicLSTM(num_units, forget_bias=1.0, input_size=None,
           state_is_Tuple=False, activation=tanh)

正則化、たとえばL2正則化を使用したいと思います。ただし、LSTMセルで使用されているさまざまな重み行列に直接アクセスできないため、次のようなことを明示的に行うことはできません。

loss = something + beta * tf.reduce_sum(tf.nn.l2_loss(weights))

行列にアクセスしたり、LSTMで正則化を使用したりする方法はありますか?

12
BiBi

tf.trainable_variables は、L2正則化用語を追加するために使用できるVariableオブジェクトのリストを提供します。これにより、モデル内のすべての変数に正則化が追加されることに注意してください。 L2項を重みのサブセットのみに制限する場合は、 name_scope を使用して変数に特定のプレフィックスを付け、後でそれを使用して変数をtf.trainable_variablesによって返されるリスト。

10
keveman

私は次のことをするのが好きですが、私が知っている唯一のことは、バッチノルムパラメーターやバイアスなど、一部のパラメーターはL2で正則化されたくないということです。 LSTMには1つのバイアステンソルが含まれています(概念的には多くのバイアスがありますが、パフォーマンスのために連結されているようです)。バッチの正規化では、変数名に「noreg」を追加して無視します。

_loss = your regular output loss
l2 = lambda_l2_reg * sum(
    tf.nn.l2_loss(tf_var)
        for tf_var in tf.trainable_variables()
        if not ("noreg" in tf_var.name or "Bias" in tf_var.name)
)
loss += l2
_

ここで、_lambda_l2_reg_は小さな乗数です。例:float(0.005)

この選択(正則化でいくつかの変数を破棄するループ内の完全なif)を実行すると、一度に0.879 F1スコアから0.890にジャンプしました構成のlambdaの値を再調整せずにコードをテストした場合、これにはバッチ正則化とバイアスの両方の変更が含まれ、ニューラルネットワークに他のバイアスがありました。

この論文 によると、反復重みを正則化すると、勾配の爆発に役立つ可能性があります。

また、 この他の論文 によると、ドロップアウトは、セル内ではなく、スタックされたセル間で使用する方が適切です。

勾配消失問題について、L2正則化がすでに追加されている損失で勾配クリッピングを使用する場合、その正則化はクリッピングプロセス中にも考慮されます。


P.S.これが私が取り組んでいたニューラルネットワークです: https://github.com/guillaume-chevalier/HAR-stacked-residual-bidir-LSTMs

11

Tensorflowには、L2ノルムをモデルに適用できるようにする組み込み関数とヘルパー関数がいくつかあります tf.clip_by_global_norm

    # ^^^ define your LSTM above here ^^^

    params = tf.trainable_variables()

    gradients = tf.gradients(self.losses, params)

    clipped_gradients, norm = tf.clip_by_global_norm(gradients,max_gradient_norm)
    self.gradient_norms = norm

    opt = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.updates = opt.apply_gradients(
                    Zip(clipped_gradients, params), global_step=self.global_step)

トレーニングステップの実行:

    outputs = session.run([self.updates, self.gradient_norms, self.losses], input_feed)
0
j314erre