web-dev-qa-db-ja.com

PyTorchのBiLSTM(BiGRU)から最後の状態を取得

いくつかの記事を読んだ後、BiLSTMから最後の隠し状態を取得する実装の正確さについてまだかなり混乱しています。

  1. PyTorch(TowardsDataScience)の双方向RNNを理解する
  2. seq2seqモデルのPackedSequence(PyTorchフォーラム)
  3. PyTorch LSTMの「hidden」と「output」の違いは何ですか?(StackOverflow)
  4. シーケンスのバッチ(Pytorch formums)でテンソルを選択

最後のソース(4)からのアプローチは私にとって最もクリーンなようですが、スレッドを正しく理解したかどうかはまだわかりません。 LSTMと逆LSTMからの正しい最終的な非表示状態を使用していますか?これは私の実装です

# pos contains indices of words in embedding matrix
# seqlengths contains info about sequence lengths
# so for instance, if batch_size is 2 and pos=[4,6,9,3,1] and 
# seqlengths contains [3,2], we have batch with samples
# of variable length [4,6,9] and [3,1]

all_in_embs = self.in_embeddings(pos)
in_emb_seqs = pack_sequence(torch.split(all_in_embs, seqlengths, dim=0))
output,lasthidden = self.rnn(in_emb_seqs)
if not self.data_processor.use_gru:
    lasthidden = lasthidden[0]
# u_emb_batch has shape batch_size x embedding_dimension
# sum last state from forward and backward  direction
u_emb_batch = lasthidden[-1,:,:] + lasthidden[-2,:,:]

それが正しいか?

4
Smarty77

一般的なケースでは、独自のBiLSTMネットワークを作成する場合は、2つの通常のLSTMを作成し、一方に通常の入力シーケンスを、もう一方に反転入力シーケンスを供給する必要があります。両方のシーケンスのフィードが終了したら、両方のネットから最後の状態を取得し、何らかの方法でそれらを結合します(合計または連結)。

私が理解しているように、あなたは組み込みのBiLSTMを次のように使用しています この例 (設定bidirectional=True innn.LSTMコンストラクター)。次に、PyTorchがすべての面倒を処理するため、バッチをフィードした後、連結された出力を取得します。

それが事実であり、隠された状態を合計したい場合は、

u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])

レイヤーが1つしかないことを前提としています。より多くのレイヤーがある場合、バリアントの方が優れているように見えます。

これは、結果が構造化されているためです( ドキュメント を参照)。

h_n形状(num_layers*num_directions、batch、hidden_​​size):t = seq_lenの非表示状態を含むテンソル

ところで、

u_emb_batch_2 = output[-1, :, :HIDDEN_DIM] + output[-1, :, HIDDEN_DIM:]

同じ結果が得られるはずです。

7
igrinis

解凍されたシーケンスを使用する場合の詳細な説明は次のとおりです。

outputの形状は(seq_len, batch, num_directions * hidden_size)です( ドキュメント を参照)。これは、GRUの順方向パスと逆方向パスの出力が3次元に沿って連結されることを意味します。

例でbatch=2hidden_size=256を想定すると、次のようにして、順方向パスと逆方向パスの両方の出力を簡単に分離できます。

output = output.view(-1, 2, 2, 256)   # (seq_len, batch_size, num_directions, hidden_size)
output_forward = output[:, :, 0, :]   # (seq_len, batch_size, hidden_size)
output_backward = output[:, :, 1, :]  # (seq_len, batch_size, hidden_size)

(注:-1は、他のディメンションからそのディメンションを推測するようにpytorchに指示します。 this の質問を参照してください。)

同様に、形状torch.chunkの元のoutput(seq_len, batch, num_directions * hidden_size) 関数を使用できます。

# Split in 2 tensors along dimension 2 (num_directions)
output_forward, output_backward = torch.chunk(output, 2, 2)

これで、 torch.gatherseqlengthsを使用してフォワードパスの最後の非表示状態(再形成後)、および位置0

# First we unsqueeze seqlengths two times so it has the same number of
# of dimensions as output_forward
# (batch_size) -> (1, batch_size, 1)
lengths = seqlengths.unsqueeze(0).unsqueeze(2)

# Then we expand it accordingly
# (1, batch_size, 1) -> (1, batch_size, hidden_size) 
lengths = lengths.expand((1, -1, output_forward.size(2)))

last_forward = torch.gather(output_forward, 0, lengths - 1).squeeze(0)
last_backward = output_backward[0, :, :]

0ベースのインデックス付けのため、lengthsから1を減算したことに注意してください。

この点で、last_forwardlast_backwardはどちらも(batch_size, hidden_dim)の形をしています。

3
jabalazs