web-dev-qa-db-ja.com

PyTorch LSTMの「非表示」と「出力」の違いは何ですか?

PyTorchのLSTMモジュール(および同様のRNNとGRU)のドキュメントを理解できません。出力に関して、それは言います:

出力:出力、(h_n、c_n)

  • 出力(seq_len、batch、hidden_​​size * num_directions):各tについて、RNNの最後のレイヤーからの出力フィーチャ(h_t)を含むテンソル。 torch.nn.utils.rnn.PackedSequenceが入力として指定されている場合、出力もパックシーケンスになります。
  • h_n(num_layers * num_directions、batch、hidden_​​size):t = seq_lenの非表示状態を含むテンソル
  • c_n(num_layers * num_directions、batch、hidden_​​size):t = seq_lenのセル状態を含むテンソル

変数outputh_nは両方とも非表示状態の値を与えるようです。 h_nは、outputに既に含まれている最後のタイムステップを冗長的に提供するだけですか、それ以上のものがありますか?

40
Nathaniel

図を作成しました。名前は PyTorch docs に従いますが、num_layersの名前をwに変更しました。

outputは、最後のレイヤーのすべての隠された状態を含みます(時間的ではなく、深さ方向の「最後」)。 (h_n, c_n)は、最後のタイムステップの後の非表示状態で構成されますt=n。したがって、潜在的にそれらを別のLSTMにフィードできます。

LSTM diagram

バッチディメンションは含まれません。

79
nnnmmm

実際に使用するモデルと、モデルの解釈方法によって異なります。出力は次のとおりです。

  • 単一のLSTMセルの非表示状態
  • いくつかのLSTMセルの非表示状態
  • すべての非表示状態の出力

出力は、直接解釈されることはほとんどありません。入力がエンコードされている場合、結果をデコードするためのsoftmaxレイヤーが必要です。

注:言語モデリングでは、次の単語の確率p(wt + 1| w1、...、wt)= softmax(Wht+ b)。

1
prosti

出力状態は、RNN(LSTM)の各タイムステップからのすべての非表示状態のテンソルであり、RNN(LSTM)によって返される非表示状態は、入力シーケンスの最後のタイムステップからの最後の非表示状態です。これを確認するには、各ステップから非表示の状態をすべて収集し、それを出力状態と比較します(pack_padded_sequenceを使用していない場合)。

0
Jibin Mathew