web-dev-qa-db-ja.com

RuntimeError:torch.DoubleTensor型のオブジェクトが必要ですが、引数#2 'weight'のtorch.FloatTensor型が見つかりました

私の入力テンソルはtorch.DoubleTensor型です。しかし、私は以下のRuntimeErrorを取得しました:

RuntimeError: Expected object of type torch.DoubleTensor but found type torch.FloatTensor for argument #2 'weight'

ウェイトのタイプを明示的に指定しませんでした(つまり、自分でウェイトを初期化しませんでした。ウェイトはpytorchによって作成されます)。フォワードプロセスでウェイトのタイプに影響を与えるものは何ですか?

どうもありがとう!!

23
Eric Kani

weightsおよびbiasesのデフォルトのタイプはtorch.FloatTensorです。そのため、モデルをtorch.DoubleTensorにキャストするか、入力をtorch.FloatTensorにキャストする必要があります。入力をキャストするには

X = X.float()

または、モデル全体をDoubleTensorにキャストします

model = model.double()

を使用して、すべてのテンソルのデフォルトのタイプを設定することもできます

pytorch.set_default_tensor_type('torch.DoubleTensor')

モデルのfloatへの変換よりも、入力をdoubleに変換することをお勧めします。GPUではdoubleデータ型の数学計算がかなり遅いためです。

31
layog

私もまったく同じエラーを受け取っていました。根本的な原因は、データロードコードの次のステートメントであることが判明しました。

t = t.astype(np.float)

ここで、np.floatはDoubleTensorにマップされる64ビットのfloatに変換されます。これを変更して、

t = t.astype(np.float32)

問題を解決しました。

1
Shital Shah