web-dev-qa-db-ja.com

PyTorchでベクトルを行列に乗算する方法

私はそれを学ぶことを目的としてPyTorchで遊んでいますが、非常に愚かな質問があります:行列を単一のベクトルで乗算するにはどうすればよいですか?

私が試したものは次のとおりです。

_>>> import torch
>>> a = torch.Rand(4,4)
>>> a

 0.3162  0.4434  0.9318  0.8752
 0.0129  0.8609  0.6402  0.2396
 0.5720  0.7262  0.7443  0.0425
 0.4561  0.1725  0.4390  0.8770
[torch.FloatTensor of size 4x4]

>>> b = torch.Rand(4)
>>> b

 0.1813
 0.7090
 0.0329
 0.7591
[torch.FloatTensor of size 4]

>>> a.mm(b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: dimension 1 out of range of 1D tensor at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensor.c:24
>>> a.mm(b.t())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: t() expects a 2D tensor, but self is 1D
>>> b.mm(a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: matrices expected, got 1D, 2D tensors at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1288
>>> b.t().mm(a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: t() expects a 2D tensor, but self is 1D
_

一方、私がやれば

_>>> b = torch.Rand(4,2)
_

その後、最初の試みであるa.mm(b)は問題なく動作します。だから問題は、私が行列ではなくベクトルを乗算していることだけです---しかし、どうすればこれを行うことができますか?

8
Nathaniel

あなたが探しています

_torch.mv(a,b)
_

将来的には、torch.matmul()も役に立つかもしれないことに注意してください。 torch.matmul()は引数の次元を推測し、それに応じて、ベクトル間のドット積、行列ベクトルまたはベクトル行列の乗算、行列乗算、または高次テンソルのバッチ行列乗算を実行します。

17
mexmex

これは、@ mexmexの正解および有用な回答を補足する自己回答です。

PyTorchでは、numpyとは異なり、1Dテンソルは1xNまたはNx1テンソルと交換できません。交換した場合

>>> b = torch.Rand(4)

>>> b = torch.Rand((4,1))

その後、列ベクトルがあり、mmとの行列乗算は期待どおりに機能します。

しかし、@ mexmexが指摘しているように、行列ベクトル乗算用のmv関数と、その次元に応じて適切な関数をディスパッチするmatmul関数があるため、これは必要ありません。入力。

5
Nathaniel