web-dev-qa-db-ja.com

PyTorchでdata.norm()<1000は何をしますか?

私はPyTorchチュートリアル here に従っています。と言う

_x = torch.randn(3, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2

print(y)

Out:    
tensor([-590.4467,   97.6760,  921.0221])
_

誰かがここでdata.norm()が何をするのか説明できますか? _.randn_を_.ones_に変更すると、その出力はtensor([ 1024., 1024., 1024.])になります。

14
voo_doo

それは単にテンソルのL2ノルム(別名ユークリッドノルム)です。以下は、再現可能な図です。

In [15]: x = torch.randn(3, requires_grad=True)

In [16]: y = x * 2

In [17]: y.data
Out[17]: tensor([-1.2510, -0.6302,  1.2898])

In [18]: y.data.norm()
Out[18]: tensor(1.9041)

# computing the norm using elementary operations
In [19]: torch.sqrt(torch.sum(torch.pow(y, 2)))
Out[19]: tensor(1.9041)

まず、テンソルyのすべての要素を二乗し、次にそれらを合計して、最終的に平方根を取ります。これらの演算は、いわゆるL2またはユークリッドノルムを計算します。

10
kmario23

@ kmario23の説明に基づいて、ユークリッド距離/ベクトルの大きさが少なくとも1000になるまでベクトルの要素を2倍します。

(1,1,1)のベクトルの例では、(512、512、512)に増加します。ここで、l2ノルムは約886です。これは1000より小さいため、再び2倍されて( 1024、1024、1024)。これは1000を超える大きさなので、停止します。

1
Jonathan
y.data.norm() 

と同等です

torch.sqrt(torch.sum(torch.pow(y, 2)))
0
aimuch