web-dev-qa-db-ja.com

PyTorchの埋め込み、LSTM、および線形レイヤーに正しく入力を与える方法は?

_torch.nn_モジュールのさまざまなコンポーネントを使用して、バッチトレーニング用の入力を正しく準備する方法について明確にする必要があります。具体的には、seq2seqモデルのエンコーダー/デコーダーネットワークの作成を検討しています。

これらの3つの層を持つモジュールが順番にあるとします:

  1. _nn.Embedding_
  2. _nn.LSTM_
  3. _nn.Linear_

_nn.Embedding_

入力: _batch_size * seq_length_
出力: _batch_size * seq_length * embedding_dimension_

ここでは何の問題もありません。入力と出力の予想される形状について明示したいだけです。

_nn.LSTM_

入力: _seq_length * batch_size * input_size_(この場合は_embedding_dimension_)
出力: _seq_length * batch_size * hidden_size_
last_hidden_​​state: _batch_size * hidden_size_
last_cell_state: _batch_size * hidden_size_

Embeddingレイヤーの出力をLSTMレイヤーの入力として使用するには、軸1と2を転置する必要があります。

私がオンラインで見つけた多くの例は、x = embeds.view(len(sentence), self.batch_size , -1)のようなことをしていますが、それは私を混乱させます。このビューは、同じバッチの要素が同じバッチにあることをどのように保証しますか? len(sentence)と_self.batch_のサイズが同じサイズの場合はどうなりますか?

_nn.Linear_

入力: _batch_size_ x _input_size_(この場合のLSTMの非表示サイズまたは??)
出力: _batch_size_ x _output_size_

LSTMの_last_hidden_state_のみが必要な場合は、_nn.Linear_への入力として指定できます。

ただし、(すべての中間の非表示状態も含む)出力を使用する場合は、_nn.Linear_の入力サイズを_seq_length * hidden_size_に変更し、出力をLinearモジュールへの入力として使用する必要があります。出力の軸1と2を転置する必要があり、Output_transposed(batch_size, -1)で表示できます。

ここでの私の理解は正しいですか?これらの転置演算をテンソル_(tensor.transpose(0, 1))_で実行するにはどうすればよいですか?

18
Silpara

ほとんどの概念の理解は正確ですが、あちこちに欠けている点がいくつかあります。

埋め込みをLSTM(またはその他の反復ユニット)に接続する

(batch_size, seq_len, embedding_size)の形で出力を埋め込みます。現在、これをLSTMに渡すことができるさまざまな方法があります。
* LSTMbatch_firstとして入力を受け入れる場合、これをLSTMに直接渡すことができます。したがって、LSTMパス引数batch_first=Trueを作成します。
*または、(seq_len, batch_size, embedding_size)の形で入力を渡すことができます。したがって、埋め込み出力をこの形状に変換するには、前述のようにtorch.transpose(tensor_name, 0, 1)を使用して1番目と2番目の次元を転置する必要があります。

Q. x = embeds.view(len(sentence)、self.batch_size、-1)のような何かをする多くの例をオンラインで見ていますが、これは私を混乱させます。
A。これは間違っています。それはバッチを混同し、あなたは絶望的な学習タスクを学習しようとします。これを見ればどこでも、著者にこのステートメントを変更し、代わりに転置を使用するように伝えることができます。

batch_firstを使用しないことを支持する議論があります。これは、Nvidia CUDAが提供する基礎となるAPIが、バッチをセカンダリとして使用するとかなり高速に実行されることを示しています。

コンテキストサイズの使用

埋め込み出力をLSTMに直接フィードします。これにより、LSTMの入力サイズがコンテキストサイズ1に修正されます。つまり、入力がLSTMへの単語である場合、常に一度に1単語が与えられます。しかし、これは常に望んでいることではありません。そのため、コンテキストサイズを拡張する必要があります。これは次のように行うことができます-

# Assuming that embeds is the embedding output and context_size is a defined variable
embeds = embeds.unfold(1, context_size, 1)  # Keeping the step size to be 1
embeds = embeds.view(embeds.size(0), embeds.size(1), -1)

ドキュメントを展開する
これで、上記のようにLSTMにフィードできます。seq_lenseq_len - context_size + 1に変更され、embedding_size(LSTMの入力サイズ)がcontext_size * embedding_sizeに変更されたことを思い出してください

可変シーケンス長の使用

バッチ内の異なるインスタンスの入力サイズは常に同じではありません。たとえば、文の一部は10語、一部は15語、一部は1000語です。したがって、繰り返し単位への可変長シーケンス入力が必要です。これを行うには、入力をネットワークにフィードする前に実行する必要がある追加の手順がいくつかあります。次の手順に従うことができます-
1。バッチを最大のシーケンスから最小のシーケンスに並べ替えます。
2。バッチ内の各シーケンスの長さを定義するseq_lengths配列を作成します。 (これは単純なpythonリスト)にすることができます)
3。すべてのシーケンスをパディングして、最大のシーケンスと同じ長さにします。
4。このバッチのLongTensor変数を作成します。
5。さて、上記の変数を埋め込みに通して適切なコンテキストサイズの入力を作成した後、次のようにシーケンスをパックする必要があります-

# Assuming embeds to be the proper input to the LSTM
lstm_input = nn.utils.rnn.pack_padded_sequence(embeds, [x - context_size + 1 for x in seq_lengths], batch_first=False)

LSTMの出力を理解する

ここで、lstm_input accを準備したら。必要に応じて、lstmを次のように呼び出すことができます。

lstm_outs, (h_t, h_c) = lstm(lstm_input, (h_t, h_c))

ここでは、(h_t, h_c)を最初の非表示状態として提供する必要があり、最終的な非表示状態を出力します。可変長シーケンスのパッキングが必要な理由を確認できます。そうしないと、LSTMが不要なパディングされた単語に対しても実行します。
これで、lstm_outsはパックされたシーケンスになり、各ステップでのlstmの出力になり、(h_t, h_c)はそれぞれ最終出力と最終セル状態になります。 h_th_cの形状は(batch_size, lstm_size)になります。これらをさらに入力に直接使用できますが、中間出力も使用する場合は、最初にlstm_outsを以下のようにアンパックする必要があります

lstm_outs, _ = nn.utils.rnn.pad_packed_sequence(lstm_outs)

これで、lstm_outsの形状は(max_seq_len - context_size + 1, batch_size, lstm_size)になります。これで、必要に応じてlstmの中間出力を抽出できます。

パックされていない出力には、各バッチのサイズの後に0があります。これは、最大シーケンス(常に入力を最大から最小にソートしたときの最初のシーケンス)の長さに一致するパディングです。

また、h_tは常に各バッチ出力の最後の要素に等しくなることに注意してください。

Lstmとリニアのインターフェース

これで、lstmの出力のみを使用する場合は、h_tをリニアレイヤーに直接フィードでき、機能します。ただし、中間出力も使用する場合は、これを線形層にどのように入力するのか(何らかの注意ネットワークまたはプーリングを通じて)を把握する必要があります。シーケンス全体の長さは異なるため、リニアレイヤーに完全なシーケンスを入力する必要はありません。リニアレイヤーの入力サイズを修正することはできません。はい、lstmの出力を転置してさらに使用する必要があります(ここでもviewは使用できません)。

最後のメモ:双方向リカレントセルの使用、展開時のステップサイズの使用、注意のインターフェイスなど、いくつかのポイントを意図的に残しました。これらは非常に煩雑になるため、この回答の範囲外になります。

29
layog