web-dev-qa-db-ja.com

Pytorchはワンホットベクターをサポートしていませんか?

Pytorchがone-hotベクトルをどのように処理するかに非常に混乱しています。この tutorial では、ニューラルネットワークがワンホットベクトルを出力として生成します。私が理解している限り、チュートリアルのニューラルネットワークの概略構造は次のようになります。

enter image description here

ただし、labelsはワンホットベクター形式ではありません。以下を取得しますsize

_print(labels.size())
print(outputs.size())

output>>> torch.Size([4]) 
output>>> torch.Size([4, 10])
_

奇妙なことに、私はoutputsおよびlabelscriterion=CrossEntropyLoss()に渡しましたが、エラーはまったくありません。

_loss = criterion(outputs, labels) # How come it has no error?
_

私の仮説:

たぶんpytorchは自動的にlabelsをワンホットベクター形式に変換します。そこで、損失関数に渡す前に、ラベルをワンホットベクトルに変換してみます。

_def to_one_hot_vector(num_class, label):
    b = np.zeros((label.shape[0], num_class))
    b[np.arange(label.shape[0]), label] = 1

    return b

labels_one_hot = to_one_hot_vector(10,labels)
labels_one_hot = torch.Tensor(labels_one_hot)
labels_one_hot = labels_one_hot.type(torch.LongTensor)

loss = criterion(outputs, labels_one_hot) # Now it gives me error
_

しかし、私は次のエラーを受け取りました

RuntimeError:マルチターゲットは/opt/pytorch/pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15ではサポートされていません

したがって、ワンホットベクタはPytorch?でサポートされていません。 Pytorchは、2つのテンソル_cross entropy_と_outputs = [1,0,0],[0,0,1]_の_labels = [0,2]_をどのように計算しますか?今のところ、私にはまったく意味がありません。

9
Raven Cheuk

私はあなたの混乱について混乱しています。 PyTorchはそのドキュメントで CrossEntropyLoss について明確に述べています

この基準は、サイズミニバッチの1Dテンソルの各値のターゲットとしてクラスインデックス(0からC-1)を期待しています

つまり、to_one_hot_vector関数は概念的にCELに組み込まれ、ワンホットAPIを公開しません。ワンホットベクトルは、クラスラベルを保存する場合と比較して、メモリ効率が悪いことに注意してください。

ワンホットベクトルが与えられ、クラスラベル形式に移動する必要がある場合(たとえば、CELと互換性を持たせるため)、以下のようにargmaxを使用できます。

import torch

labels = torch.tensor([1, 2, 3, 5])
one_hot = torch.zeros(4, 6)
one_hot[torch.arange(4), labels] = 1

reverted = torch.argmax(one_hot, dim=1)
assert (labels == reverted).all().item()
11
Jatentaki

このコードはone hot encodemulti hot encodeの両方で役立ちます:

_import torch
batch_size=10
n_classes=5
target = torch.randint(high=5, size=(1,10)) # set size (2,10) for MHE
print(target)
y = torch.zeros(batch_size, n_classes)
y[range(y.shape[0]), target]=1
y
_

OHEでの出力

_tensor([[4, 3, 2, 2, 4, 1, 1, 1, 4, 2]])

tensor([[0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.]])
_

target = torch.randint(high=5, size=(2,10))を設定したときのMHEの出力

_tensor([[3, 2, 4, 4, 2, 4, 0, 4, 4, 1],
        [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]])

tensor([[0., 0., 0., 1., 1.],
        [0., 1., 1., 0., 0.],
        [0., 1., 0., 0., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 1.],
        [1., 0., 0., 0., 1.],
        [0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 1., 0.]])
_

複数のOHEが必要な場合:

_torch.nn.functional.one_hot(target)

tensor([[[0, 0, 0, 1, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0]],

        [[0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0],
         [0, 1, 0, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 1, 0]]])
_
4
prosti