web-dev-qa-db-ja.com

PytorchのLSTM

PyTorchは初めてです。私はこれに出くわしました GitHubリポジトリ(完全なコード例へのリンク) さまざまな異なる例が含まれています。

LSTMに関する例もあります。これはNetworkクラスです。

_# RNN Model (Many-to-One)
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial states 
        h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)) 
        c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))

        # Forward propagate RNN
        out, _ = self.lstm(x, (h0, c0))  

        # Decode hidden state of last time step
        out = self.fc(out[:, -1, :])  
        return out
_

だから私の質問は次の行についてです:

_h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)) 
c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
_

私が理解している限り、すべてのトレーニング例でforward()が呼び出されます。ただし、これは、非表示状態とセル状態がリセットされることを意味します。つまり、すべてのトレーニング例でゼロの行列に置き換えられます。

名前_h0_および_c0_は、これがt = 0での非表示/セル状態のみであることを示していますが、トレーニングの例ごとにこれらのゼロ行列がlstmに渡されるのはなぜですか?

最初の呼び出しの後で無視されたとしても、それはあまり良い解決策ではありません。

コードをテストすると、MNISTセットで97%の精度が示されているため、このように機能しているように見えますが、私には意味がありません。

誰かがこれで私を助けてくれることを願っています。

前もって感謝します!

8
MBT

明らかに、私はこれで間違った方向に進んでいました。隠しユニットと隠し/セルの状態を混同していました。トレーニングステップでは、LSTMの非表示ユニットのみがトレーニングされます。セル状態と非表示状態は、すべてのシーケンスの開始時にリセットされます。したがって、このようにプログラムされていることは理にかなっています。

申し訳ありません。

11
MBT