web-dev-qa-db-ja.com

順次モデルでのpytorchスキップ接続

私は、シーケンシャルモデルでスキップ接続に頭を包み込もうとしています。機能的なAPIを使用すると、次のように簡単に実行できます(簡単な例、100%構文的に正確ではないかもしれませんが、アイデアが得られるはずです)。

x1 = self.conv1(inp)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)

x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1))
x = self.deconv1(x)

私は現在、シーケンシャルモデルを使用して、同様のことをしようとしています。最初のconvレイヤーのアクティベーションを最後のconvTransposeまでもたらすスキップ接続を作成します。 here 実装されたU-netアーキテクチャを見てきましたが、少し混乱します。次のようになります。

upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                    kernel_size=4, stride=2,
                                    padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]

if use_dropout:
    model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
    model = down + [submodule] + up

これは、シーケンシャルモデルにレイヤーを適切に追加するだけではありませんか? down convに続いてsubmodule(再帰的に内側の層を追加)が続き、upconv層であるupに連結されます。おそらくSequential AP​​Iがどのように機能するかについて重要なことを見逃していますが、U-NETから切り取られたコードはどのように実際にスキップを実装していますか?

9
powder

観察は正しいが、UnetSkipConnectionBlock.forward()UnetSkipConnectionBlockは共有したU-Netブロックを定義するModuleである)の定義を見逃している可能性があります。 :

(from _pytorch-CycleGAN-and-pix2pix/models/networks.py#L259_

_# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):

    # ...

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
_

最後の行はキーです(すべての内部ブロックに適用されます)。スキップレイヤーは、入力xと(再帰)ブロック出力self.model(x)を、_self.model_で指定した操作のリストと連結することで簡単に行われます。書いたFunctionalコード。

4
benjaminplanche