web-dev-qa-db-ja.com

KerasのLSTMレイヤーの重みを解釈する方法

私は現在、LSTMレイヤーを使用して、天気予報用のリカレントニューラルネットワークをトレーニングしています。ネットワーク自体はかなりシンプルで、おおよそ次のようになります。

model = Sequential()  
model.add(LSTM(hidden_neurons, input_shape=(time_steps, feature_count), return_sequences=False))  
model.add(Dense(feature_count))  
model.add(Activation("linear"))  

LSTMレイヤーの重みは、次の形状を持っています。

for weight in model.get_weights(): # weights from Dense layer omitted
    print(weight.shape)

> (feature_count, hidden_neurons)
> (hidden_neurons, hidden_neurons)
> (hidden_neurons,)
> (feature_count, hidden_neurons)
> (hidden_neurons, hidden_neurons)
> (hidden_neurons,)
> (feature_count, hidden_neurons)
> (hidden_neurons, hidden_neurons)
> (hidden_neurons,)
> (feature_count, hidden_neurons)
> (hidden_neurons, hidden_neurons)
> (hidden_neurons,)

つまり、このLSTM層には4つの「要素」があるように見えます。私はそれらをどのように解釈するのか今疑問に思っています:

  • この表現のtime_stepsパラメータはどこにありますか?それは重みにどのように影響しますか?

  • 私は、LSTMが入力や忘却ゲートのようないくつかのブロックで構成されていることを読みました。これらがこれらの重み行列で表されている場合、どの行列がどのゲートに属していますか?

  • ネットワークが学んだことを確認する方法はありますか?たとえば、前回のタイムステップ(tを予測する場合はt-1)からどれくらいかかりますか、t-2などからどれくらいかかりますか?たとえば、入力t-5が完全に無関係であることを重みから読み取ることができるかどうかを知ることは興味深いでしょう。

明確化とヒントをいただければ幸いです。

12
Isa

Keras 2.2.0を使用している場合

印刷するとき

print(model.layers[0].trainable_weights)

3つのテンソルが表示されます:lstm_1/kernel, lstm_1/recurrent_kernel, lstm_1/bias:0各テンソルの次元の1つは、

4 * number_of_units

ここでnumber_of_unitsはニューロンの数です。試してください:

units = int(int(model.layers[0].trainable_weights[0].shape[1])/4)
print("No units: ", units)

これは、各テンソルに4つのLSTMユニットの重みが(この順序で)含まれているためです。

i(入力)、f(忘れる)、c(セルの状態)、o(出力)

したがって、重みを抽出するには、スライス演算子を使用するだけです。

W = model.layers[0].get_weights()[0]
U = model.layers[0].get_weights()[1]
b = model.layers[0].get_weights()[2]

W_i = W[:, :units]
W_f = W[:, units: units * 2]
W_c = W[:, units * 2: units * 3]
W_o = W[:, units * 3:]

U_i = U[:, :units]
U_f = U[:, units: units * 2]
U_c = U[:, units * 2: units * 3]
U_o = U[:, units * 3:]

b_i = b[:units]
b_f = b[units: units * 2]
b_c = b[units * 2: units * 3]
b_o = b[units * 3:]

出典: keras code

15

私はおそらくあなたの質問のすべてに答えることはできないでしょうが、私ができることは、LSTMセルとそれを構成するさまざまなコンポーネントに関する詳細情報を提供することです。

githubのこの投稿 は、印刷中にパラメーターの名前を確認する方法を提案しています。

model = Sequential()
model.add(LSTM(4,input_dim=5,input_length=N,return_sequences=True))
for e in Zip(model.layers[0].trainable_weights, model.layers[0].get_weights()):
    print('Param %s:\n%s' % (e[0],e[1]))

出力は次のようになります。

Param lstm_3_W_i:
[[ 0.00069305, ...]]
Param lstm_3_U_i:
[[ 1.10000002, ...]]
Param lstm_3_b_i:
[ 0., ...]
Param lstm_3_W_c:
[[-1.38370085, ...]]
...

これで、これらのさまざまな重みに関する詳細情報 here を見つけることができます。それらは、異なるインデックスを持つW、U、Vおよびbの名前を持っています。

  • W行列は、入力をいくつかの他の内部値に変換する行列です。形は[input_dim, output_dim]
  • U行列は、以前の非表示状態を別の内部値に変換する行列です。形は[output_dim, output_dim]
  • bベクトルは、各ブロックのバイアスです。それらはすべて[output_dim]
  • Vは出力ゲートでのみ使用され、新しい内部状態から出力する値を選択します。形があります[output_dim, output_dim]

つまり、実際には4つの異なる「ブロック」(または内部レイヤー)があります。

  • ゲートを忘れる:これは、以前の非表示状態(h_ {t-1})と入力(x)に基づいて、セルの以前の内部状態(C_ {t-1})からどの値を忘れるかを決定します。

    f_t = sigmoid(W_f * x + U_f * h_ {t-1} + b_f)

    f_tは0と1の間の値のベクトルで、前のセルの状態から保持するもの(= 1)と忘れるもの(= 0)をエンコードします。

  • 入力ゲート:前の非表示状態(h_ {t-1})と入力(x)に基づいて、入力(x)から使用する値を決定します。

    i_t = sigmoid(W_i * x + U_i * h_ {t-1} + b_i)

    i_tは、新しいセルの状態を更新するために使用する値をエンコードする0〜1の値のベクトルです。

  • 候補値:入力(x)と以前の非表示状態(h_ {t-1})を使用して、新しいセル値を作成し、内部のセル状態を更新します。

    Ct_t = tanh(W_c * x + U_c * h_ {t-1} + b_c)

    Ct_tは、セルの状態(C_ {t-1})を更新する可能性のある値を含むベクトルです。

これら3つの値を使用して、新しい内部セル状態(C_t)を作成します。

C_t = f_t * C_ {t-1} + i_t * Ct_t

ご覧のとおり、新しい内部セル状態は2つの要素で構成されています。最後の状態から忘れていない部分と、入力から学習したい部分です。

  • 出力ゲート:セルの状態を出力したくないので、出力したいもの(h_t)を抽象化したものと見なされることがあります。それで、持っているすべての情報に基づいて、このステップの出力であるh_tを構築します。

    h_t = W_o * x + U_o * h_ {t-1} + V_o * C_t + b_o

これにより、LSTMセルの動作が明らかになることを願っています。 LSTMのチュートリアルでは、Niceスキーマや段階的な例などを使用しているため、ぜひご覧ください。比較的複雑なレイヤーです。

あなたの質問に関して、私は今、状態を修正するために入力から何が使用されたかを追跡する方法を考えています。入力を処理する行列であるため、最終的にはさまざまなW行列を見ることができます。 W_cは、セルの状態を更新するために潜在的に使用されるものに関する情報を提供します。 W_oは、出力の生成に使用されるものに関する情報を提供する可能性があります...しかし、前の状態も影響を与えるため、これらすべては他の重みに関連しています。

ただし、W_cにいくつかの強い重みがある場合、入力ゲート(i_t)が完全に閉じてセルの状態の更新を消滅させる可能性があるため、何の意味もない可能性があります...トレースする数学のフィールドは複雑ですニューラルネットで何が起こっているかは非常に複雑です。

ニューラルネットは、最も一般的なケースでは本当にブラックボックスです。文献から、出力から入力まで情報をトレースバックするいくつかのケースを見つけることができますが、これは私が読んだものからの非常に特殊なケースです。

これが役に立てば幸いです:-)

6
Nassim Ben