web-dev-qa-db-ja.com

Pytorchでのランダムな選択?

画像のテンソルがあり、ランダムに選択したい。 np.random.choice()に相当するものを探しています。

import torch

pictures = torch.randint(0, 256, (1000, 28, 28, 3))

これらの写真を10枚欲しいとしましょう。

6
Nicolas Gervais

torchには、np.random.choice()に相当する実装はありません。最善の方法は、選択肢からランダムなインデックスを選択することです。

choices[torch.randint(choices.shape[0], (1,))]

これは、0とテンソルの要素数の間のrandintを生成します。

for i in range(5):
    print(choices[torch.randint(choices.shape[0], (1,))])
tensor([2])
tensor([6])
tensor([2])
tensor([6])
tensor([7])

replacement = Falseを設定する場合は、マスクを使用して選択した値を削除します。

for i in range(10):
    value = choices[torch.randint(choices.shape[0], (1,))]
    choices = choices[choices!=value]
    print(value)
tensor([2])
tensor([4])
tensor([6])
tensor([7])
0
Nicolas Gervais

私の場合:values.shape =(386363948、2)、k = 190973、次のコードはかなり速く、0.1〜0.2秒で動作します。

indice = random.sample(range(386363948), 190973)
indice = torch.tensor(indice)
sampled_values = values[indice]

ただし、torch.randpermを使用すると、20秒以上かかります。

sampled_values = values[torch.randperm(386363948)[190973]]
1
刘致远

他の人が述べたように、トーチには代わりにrandintまたは順列を使用する選択肢がありません

import torch
n = 4
choices = torch.Rand(4, 3)
choices_flat = choices.view(-1)
index = torch.randint(choices_flat.numel(), (n,))
# or if replace = False
index = torch.randperm(choices_flat.numel())[:n]
select = choices_flat[index]
0
Qianyi Zhang