web-dev-qa-db-ja.com

PyTorchの一意のテンソル値

私はPyTorchテンソルで明確な値を見つけることに同意しています。
Tensorflowの nique op を複製する効率的な方法はありますか?

14
arosa

0.4.0には torch.unique() メソッドがあります

torch <= 0.3.1 あなたが試すことができます:

import torch
import numpy as np

x = torch.Rand((3,3)) * 10
np.unique(x.round().numpy())
15
Alex Glinsky

これを行う最良の方法(最も簡単な方法)は、numpyに変換し、numpyの組み込みunique関数を使用することです。そのようです。

def unique(tensor1d):
    t, idx = np.unique(tensor1d.numpy(), return_inverse=True)
    return torch.from_numpy(t), torch.from_numpy(idx)  

だから、それを試すとき:

t, idx = unique(torch.LongTensor([1, 1, 2, 4, 4, 4, 7, 8, 8]))  
# t --> [1, 2, 4, 7, 8]
# idx --> [0, 0, 1, 2, 2, 2, 3, 4, 4]
5
entrophy

torch.unique()>2つのテンソル間の共通項目を取得します。@ 2 tensor.eq()に相当するインデックスを取得し、テンソルを連結すると、最終的に「torch.unique」の共通項目ヘルプが取得されます。

import torch as pt

a = pt.tensor([1,2,3,2,3,4,3,4,5,6])
b = pt.tensor([7,2,3,2,7,4,9,4,9,8])

equal_data = pt.eq(a, b)
pt.unique(pt.cat([a[equal_data],b[equal_data]]))
1
Sunil