web-dev-qa-db-ja.com

PyTorchのテンソルを使用した多次元テンソルのインデックス作成

私は次のコードを持っています:

a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])

多次元インデックスbがあり、それを使用してaの単一のセルを選択します。 bがテンソルでない場合、次のことができます。

a[1,1,1,1]

これは正しいセルを返しますが、:

a[b]

a[1]を4回選択するだけなので、機能しません。

これどうやってするの?ありがとう

よりエレガントな(そしてより単純な)ソリューションは、単にbをタプルとしてキャストすることです:

a[Tuple(b)]
Out[10]: tensor(5.)

これが「通常の」numpyでどのように機能するのか興味があり、これを非常によく説明している関連記事を見つけました here

5
dennlinger

b を使用してchunkを4つに分割し、チャンク化されたbを使用して特定の要素にインデックスを付けることができます。

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])

それの良いところは、aの任意の次元に簡単に一般化できることです。チャックの数をaの次元と等しくする必要があります。

5
Shai