web-dev-qa-db-ja.com

トーチテンソルの最大値のインデックスを効率的に取得する方法は?

たとえば、次の形状のトーチテンソルがあるとします。

_x = torch.Rand(20, 1, 120, 120)
_

今私が望むのは、各120x120マトリックスの最大値のインデックスを取得することです。問題を単純化するために、最初にx.squeeze()を使用して形状_[20, 120, 120]_を処理します。次に、形状_[20, 2]_のインデックスのリストであるトーチテンソルを取得します。

どうすれば高速にできますか?

7
Chris

私が正しく得れば、値ではなく、インデックスが必要です。残念ながら、すぐに使えるソリューションはありません。 argmax()関数が存在しますが、あなたが望むことを正確に実行する方法がわかりません。

したがって、ここに小さな回避策がありますが、テンソルを分割しているだけなので、効率も大丈夫です:

n = torch.tensor(4)
d = torch.tensor(4)
x = torch.Rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
# since argmax() does only return the index of the flattened
# matrix block we have to calculate the indices by ourself 
# by using / and % (// would also work, but as we are dealing with
# type torch.long / works as well
indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(x)
print(indices)

nは最初の次元を表し、dは最後の2つの次元を表します。ここでは、結果を示すために小さい数字を使用します。しかし、もちろんこれはn=20およびd=120でも機能します:

n = torch.tensor(20)
d = torch.tensor(120)
x = torch.Rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
#print(x)
print(indices)

n=4およびd=4の出力は次のとおりです。

tensor([[[[0.3699, 0.3584, 0.4940, 0.8618],
          [0.6767, 0.7439, 0.5984, 0.5499],
          [0.8465, 0.7276, 0.3078, 0.3882],
          [0.1001, 0.0705, 0.2007, 0.4051]]],


        [[[0.7520, 0.4528, 0.0525, 0.9253],
          [0.6946, 0.0318, 0.5650, 0.7385],
          [0.0671, 0.6493, 0.3243, 0.2383],
          [0.6119, 0.7762, 0.9687, 0.0896]]],


        [[[0.3504, 0.7431, 0.8336, 0.0336],
          [0.8208, 0.9051, 0.1681, 0.8722],
          [0.5751, 0.7903, 0.0046, 0.1471],
          [0.4875, 0.1592, 0.2783, 0.6338]]],


        [[[0.9398, 0.7589, 0.6645, 0.8017],
          [0.9469, 0.2822, 0.9042, 0.2516],
          [0.2576, 0.3852, 0.7349, 0.2806],
          [0.7062, 0.1214, 0.0922, 0.1385]]]])
tensor([[0, 3],
        [3, 2],
        [1, 1],
        [1, 0]])

これがあなたが手に入れたかったものだと思います! :)

Edit:

わずかに速くなるかもしれないわずかに変更されたものを次に示します(私は推測しますが、それほどではありませんが)。

前のようにこれの代わりに:

m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)

argmax値で既に行われている必要な再整形:

m = x.view(n, -1).argmax(1).view(-1, 1)
indices = torch.cat((m // d, m % d), dim=1)

しかし、コメントで述べたように。私はそれからもっと多くを得ることが可能であるとは思わない。

あなたができることの1つは、本当にパフォーマンス改善の最後の可能なビットを得ることが重要である場合、上記の機能を低レベル拡張として実装することです( C++のように)pytorchの場合。

これにより、呼び出すことができる関数が1つだけになり、遅いpythonコードを回避できます。

https://pytorch.org/tutorials/advanced/cpp_extension.html

3
blue-phoenox