web-dev-qa-db-ja.com

PyTorchで行列の積を行う方法

Numpyでは、次のような単純な行列乗算を行うことができます。

a = numpy.arange(2*3).reshape(3,2)
b = numpy.arange(2).reshape(2,1)
print(a)
print(b)
print(a.dot(b))

ただし、PyTorch Tensorsでこれをしようとすると、これは機能しません。

a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)
print(a)
print(a.size())

print(b)
print(b.size())

print(torch.dot(a, b))

このコードは次のエラーをスローします。

RuntimeError:/Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503での一貫性のないテンソルサイズ

PyTorchで行列乗算を実行する方法はありますか?

35
blckbird

あなたが探しています

torch.mm(a,b)

torch.dot()np.dot()とは異なる動作をすることに注意してください。何が望ましいかについての議論がありました ここ 。具体的には、torch.dot()abの両方を(元の形状に関係なく)1Dベクトルとして扱い、その内積を計算します。この動作により、aが長さ6のベクトルになり、bが長さ2のベクトルになるため、エラーがスローされます。したがって、それらの内積は計算できません。 PyTorchでの行列乗算には、torch.mm()を使用します。対照的に、Numpyのnp.dot()はより柔軟です。 1D配列の内積を計算し、2D配列の行列乗算を実行します。

52
mexmex

行列(ランク2テンソル)乗算を行いたい場合、4つの同等の方法で実行できます。

AB = A.mm(B) # computes A.B (matrix multiplication)
# or
AB = torch.mm(A, B)
# or
AB = torch.matmul(A, B)
# or, even simpler
AB = A @ B # Python 3.5+

いくつかの微妙な点があります。 PyTorch documentation から:

torch.mmはブロードキャストしません。マトリックス製品のブロードキャストについては、torch.matmul()を参照してください。

たとえば、2つの1次元ベクトルをtorch.mmで乗算したり、バッチマトリックスを乗算したりすることはできません(ランク3)。そのためには、より汎用性の高いtorch.matmulを使用する必要があります。 torch.matmulのブロードキャスト動作の広範なリストについては、 documentation を参照してください。

要素単位の乗算では、単純に実行できます(AとBが同じ形状の場合)

A * B # element-wise matrix multiplication (Hadamard product)
26
BiBi

torch.mm(a, b)またはtorch.matmul(a, b)を使用します
両方とも同じです。

>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>

知っておくと便利なオプションがもう1つあります。それは@演算子です。 @サイモン・H.

>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> a@b
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])    

3つの結果は同じです。

関連リンク:
行列乗算演算子
PEP 465-行列乗算専用の挿入演算子

4
David Jung