web-dev-qa-db-ja.com

PyTorchで2つの入力を使用してネットワークを構築する方法

一般的なニューラルネットワークアーキテクチャが必要だとします。

Input1 --> CNNLayer 
                    \
                     ---> FCLayer ---> Output
                    /
Input2 --> FCLayer

Input1は画像データ、input2は非画像データです。このアーキテクチャをTensorflowに実装しました。

私が見つけたすべてのpytorchの例は、各層を通過する1つの入力です。 2つの入力を個別に処理し、それらを中間層で結合するように転送機能を定義するにはどうすればよいですか?

9
LeonG

「それらを組み合わせる」ということは、2つの入力を 連結 するつもりだと思います。
2次元に沿って連結すると仮定します。

import torch
from torch import nn

class TwoInputsNet(nn.Module):
  def __init__(self):
    super(TwoInputsNet, self).__init__()
    self.conv = nn.Conv2d( ... )  # set up your layer here
    self.fc1 = nn.Linear( ... )  # set up first FC layer
    self.fc2 = nn.Linear( ... )  # set up the other FC layer

  def forward(self, input1, input2):
    c = self.conv(input1)
    f = self.fc1(input2)
    # now we can reshape `c` and `f` to 2D and concat them
    combined = torch.cat((c.view(c.size(0), -1),
                          f.view(f.size(0), -1)), dim=1)
    out = self.fc2(combined)
    return out

self.fc2への入力数を定義するときは、out_channelsself.convcの出力空間次元の両方を考慮する必要があることに注意してください。

5
Shai