web-dev-qa-db-ja.com

Pytorchでのユークリッドノルムの計算..問題の理解と実装

ユークリッドノルムを計算するためのさまざまな実装について話している別のStackOverflowスレッドを見たことがありますが、特定の実装がなぜ/どのように機能するのかがわかりません。

コードはMMDメトリックの実装にあります: https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/statistics_diff.py

ここにいくつかの最初の定型文があります:

import torch
sample_1, sample_2 = torch.ones((10,2)), torch.zeros((10,2))

次に、上記のコードから取得する部分です。サンプルが連結されている理由がわかりません。

sample_12 = torch.cat((sample_1, sample_2), 0)
distances = pdist(sample_12, sample_12, norm=2)

次に、pdist関数に渡されます。

def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all squared pairwise distances.
    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to
        ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""

ここで計算の要点に到達します

    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    norm = float(norm)
    if norm == 2.:
        norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
        norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
        norms = (norms_1.expand(n_1, n_2) +
             norms_2.transpose(0, 1).expand(n_1, n_2))
        distances_squared = norms - 2 * sample_1.mm(sample_2.t())
        return torch.sqrt(eps + torch.abs(distances_squared))

なぜユークリッドノルムがこのように計算されるのか、私は途方に暮れています。どんな洞察も大歓迎です

5
Fosa

このコードブロックを段階的に見ていきましょう。ユークリッド距離の定義、つまりL2ノルムは

enter image description here

最も単純なケースを考えてみましょう。 2つのサンプルがあります。

enter image description here

サンプルaには、2つのベクトル_[a00, a01]_と_[a10, a11]_があります。サンプルbについても同じです。最初にnormを計算しましょう

_n1, n2 = a.size(0), b.size(0)  # here both n1 and n2 have the value 2
norm1 = torch.sum(a**2, dim=1)
norm2 = torch.sum(b**2, dim=1)
_

今、私たちは得る

enter image description here

次に、norms_1.expand(n_1, n_2)norms_2.transpose(0, 1).expand(n_1, n_2)があります

enter image description here

bが転置されていることに注意してください。 2つの合計はnormを与えます

enter image description here

sample_1.mm(sample_2.t())、それは2つの行列の乗算です。

enter image description here

したがって、手術後

_distances_squared = norms - 2 * sample_1.mm(sample_2.t())
_

あなたが得る

enter image description here

最後に、最後のステップは、行列内のすべての要素の平方根を取ることです。

8
Milo Lu