web-dev-qa-db-ja.com

PyTorchの2つの確率分布のKL発散

2つの確率分布があります。 PyTorchでそれらの間のKL分岐をどのように見つける必要がありますか?通常のクロスエントロピーは整数ラベルのみを受け入れます。

8
Mojtaba Komeili

はい、PyTorchにはkl_divの下に torch.nn.functional というメソッドがあり、テンソル間のKL発生を直接計算します。同じ形状のテンソルabがあるとします。次のコードを使用できます。

import torch.nn.functional as F
out = F.kl_div(a, b)

詳細については、上記のメソッドのドキュメントを参照してください。

12
jdhao

function kl_divwiki の説明と同じではありません。

私は次を使用します:

# this is the same example in wiki
P = torch.Tensor([0.36, 0.48, 0.16])
Q = torch.Tensor([0.333, 0.333, 0.333])

(P * (P / Q).log()).sum()
# tensor(0.0863), 10.2 µs ± 508

F.kl_div(Q.log(), P, None, None, 'sum')
# tensor(0.0863), 14.1 µs ± 408 ns

kl_divと比較して、さらに高速

3
hantian_pang

pytorch distribution object の形式の2つの確率分布がある場合。次に、関数torch.distributions.kl.kl_divergence(p, q)を使用することをお勧めします。ドキュメントについては、 link に従ってください

2