web-dev-qa-db-ja.com

PyTorchの "view"メソッドはどこのように機能しますか?

次のコードスニペットでメソッドview()について混乱しています。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool  = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

私の混乱は次の行に関するものです。

x = x.view(-1, 16*5*5)

tensor.view()関数は何をするのですか?私は多くの場所でその使い方を見ましたが、それがそのパラメータをどのように解釈するのか理解できません。

view()関数のパラメータとして負の値を指定するとどうなりますか?例えば、tensor_variable.view(1, 1, -1)と呼ぶとどうなりますか?

誰かがいくつかの例を使ってview()関数の主な原理を説明できますか?

111
Wasi Ahmad

View関数はテンソルを変形するためのものです。

テンソルがあるとしましょう

import torch
a = torch.range(1, 16)

aは1から16までの16の要素を含むテンソルです(含まれます)。このテンソルを4 x 4テンソルにするために変形したい場合は、

a = a.view(4, 4)

これでa4 x 4テンソルになります。 変形後の要素の総数は同じままである必要があることに注意してください。テンソルa3 x 5テンソルに変形するのは適切ではありません。

パラメーター-1の意味は何ですか?

必要な行数がわからないが列数が確実であるという状況がある場合は、-1を指定してこれを指定できます。 ( これをより多くの次元のテンソルに拡張できることに注意してください。軸の値の1つだけが-1 になることができます)。これはライブラリに伝える方法です。「これらの多くの列を持つテンソルを教えてください。これを実現するのに必要な適切な行数を計算します」。

これはあなたが上で与えたニューラルネットワークコードで見ることができます。 forward関数のx = self.pool(F.relu(self.conv2(x)))行の後に、深さ16のフィーチャーマップがあります。完全に接続されたレイヤーにそれを与えるためにこれを平らにする必要があります。ですから、pytorchにあなたが得たテンソルを特定の数の列を持つように変形し、それ自身で行数を決めるように伝えます。

Numpyとpytorchの間に類似点を描くviewは、numpyの reshape 関数に似ています。

161
Kashyap

もっと簡単なものからもっと難しいものまで、いくつか例を見てみましょう。

  1. viewメソッドは、selfテンソルと同じデータを持つテンソルを返します(つまり、返されるテンソルの要素数は同じです)が、形状は異なります。例えば:

    a = torch.arange(1, 17)  # a's shape is (16,)
    
    a.view(4, 4) # output below
      1   2   3   4
      5   6   7   8
      9  10  11  12
     13  14  15  16
    [torch.FloatTensor of size 4x4]
    
    a.view(2, 2, 4) # output below
    (0 ,.,.) = 
    1   2   3   4
    5   6   7   8
    
    (1 ,.,.) = 
     9  10  11  12
    13  14  15  16
    [torch.FloatTensor of size 2x2x4]
    
  2. -1がパラメータの1つではないと仮定すると、それらを一緒に乗算すると、結果はテンソルの要素数に等しくなければなりません。次のようにした場合、a.view(3, 3)、shape(3 x 3)は16個の要素を持つ入力には無効であるため、RuntimeErrorが発生します。つまり、3 x 3は16ではなく9です。

  3. 関数に渡すパラメータの1つとして-1を使用できますが、それは1回だけです。起こるのは、メソッドがその次元を埋める方法についてあなたのために数学をするということだけです。例えばa.view(2, -1, 4)a.view(2, 2, 4)と同等です。 [16 /(2 x 4)= 2]

  4. 返されたテンソルが同じdataを共有していることに注意してください。 「ビュー」を変更した場合は、元のテンソルのデータを変更していることになります。

    b = a.view(4, 4)
    b[0, 2] = 2
    a[2] == 3.0
    False
    
  5. さて、もっと複雑なユースケースについて。ドキュメントには、それぞれの新しいビュー次元は元の次元の部分空間であるか、次の隣接条件のような条件を満たすspan d、d + 1、...、d + kのみである必要があると書かれていますすべてのi = 0、...、k - 1、ストライド[i] =ストライド[i + 1] xサイズ[i + 1]。それ以外の場合は、contiguous()を呼び出す必要があります。テンソルを表示することができます。

    a = torch.Rand(5, 4, 3, 2) # size (5, 4, 3, 2)
    a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4)
    
    # The commented line below will raise a RuntimeError, because one dimension
    # spans across two contiguous subspaces
    # a_t.view(-1, 4)
    
    # instead do:
    a_t.contiguous().view(-1, 4)
    
    # To see why the first one does not work and the second does,
    # compare a.stride() and a_t.stride()
    a.stride() # (24, 6, 2, 1)
    a_t.stride() # (24, 2, 1, 6)
    

    a_tでは、stride [0]!= stride [1] x size [1]24!= 2 x 3であることに注意してください。

19
Jadiel de Armas

パラメーター-1の意味は何ですか?

-1は、動的なパラメータ数または「なんでも」として読み取ることができます。そのため、view()には-1というパラメータは1つしか存在できません。

x.view(-1,1)に尋ねると、これはxの要素数に応じてテンソル形状[anything, 1]を出力します。例えば:

import torch
x = torch.tensor([1, 2, 3, 4])
print(x,x.shape)
print("...")
print(x.view(-1,1), x.view(-1,1).shape)
print(x.view(1,-1), x.view(1,-1).shape)

出力します:

tensor([1, 2, 3, 4]) torch.Size([4])
...
tensor([[1],
        [2],
        [3],
        [4]]) torch.Size([4, 1])
tensor([[1, 2, 3, 4]]) torch.Size([1, 4])
0
prosti

weights.reshape(a, b)は、データをメモリの別の部分にコピーするのと同じサイズの重み(a、b)と同じデータを持つ新しいテンソルを返します。

weights.resize_(a, b)は形状が異なる同じテンソルを返します。しかし、新しい形状が元のテンソルよりも要素数が少ない場合、一部の要素はテンソルから削除されます(ただしメモリからは削除されません)。新しい形状が元のテンソルよりも多くの要素をもたらす場合、新しい要素はメモリ内で初期化されません。

weights.view(a, b)は、サイズ(a、b)の重みと同じデータで新しいテンソルを返します

0
Jibin Mathew

x.view(-1, 16 * 5 * 5)x.flatten(1)と同等で、パラメータ1は1次元から開始することを示しています( 'sample'次元を平坦化していません)。だから私はflatten()が好きです。

0
FENGSHI ZHENG