web-dev-qa-db-ja.com

DIM = 1のROWENDインデックスがTorch.argmaxで列インデックスを返すのはなぜですか?

Pytorchのargmax機能を作業しています。

torch.argmax(input, dim=None, keepdim=False)
 _

例を考えてみましょう

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
 _

ここで、列のベクトルを検索する代わりにDIM = 1を使用する場合、関数は以下のように行ベクトルを検索します。

print(a) :   
tensor([[-1.7739,  0.8073,  0.0472, -0.4084],  
        [ 0.6378,  0.6575, -1.2970, -0.0625],  
        [ 1.7970, -1.3463,  0.9011, -0.8704],  
        [ 1.5639,  0.7123,  0.0385,  1.8410]])  

print(torch.argmax(a, dim=1))  
tensor([1, 1, 0, 3])
 _

私の仮定がGos = 0がvim = 0を表す限り、行は行を表し、dim = 1は列を表します。

7
Programmer

時間は正しく理解していますaxisまたはdim引数がPYTORCHで動作する方法:

tensor dimension


上記の写真を理解すると、次の例は意味があります。

    |
    v
  dim-0  ---> -----> dim-1 ------> -----> --------> dim-1
    |   [[-1.7739,  0.8073,  0.0472, -0.4084],
    v    [ 0.6378,  0.6575, -1.2970, -0.0625],
    |    [ 1.7970, -1.3463,  0.9011, -0.8704],
    v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
    |
    v
 _
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
 _

dim'dimension' ==)は'のトーチです。軸 ' numpy。

3
kmario23