web-dev-qa-db-ja.com

テンソルからcuda()を削除する方法をpytorch

TypeError: expected torch.LongTensor (got torch.cuda.FloatTensor)を取得しました。

torch.cuda.FloatTensortorch.LongTensorに変換するにはどうすればよいですか?

  Traceback (most recent call last):
  File "train_v2.py", line 110, in <module>
    main()
  File "train_v2.py", line 81, in main
    model.update(batch)
  File "/home/Desktop/squad_vteam/src/model.py", line 131, in update
    loss_adv = self.adversarial_loss(batch, loss, self.network.Lexicon_encoder.embedding.weight, y)
  File "/home/Desktop/squad_vteam/src/model.py", line 94, in adversarial_loss
    adv_embedding = torch.LongTensor(adv_embedding)
TypeError: expected torch.LongTensor (got torch.cuda.FloatTensor)
7
Aerin

フロートテンソルfがあり、それをlongに変換したい場合は、long_tensor = f.long()を実行します。

cudaテンソルがあります。つまり、データがgpuにあり、それをcpuに移動したい場合は、cuda_tensor.cpu()を実行できます。

したがって、torch.cuda.FloatテンソルAをtorch.longに変換するには、A.long().cpu()を実行します。

6
Umang Gupta

Pytorch 0.4.0のベストプラクティスは、次のように記述することです デバイスに依存しないコード :つまり、.cuda()または.cpu()を使用する代わりに、単純に .to(torch.device("cpu"))

_A = A.to(dtype=torch.long, device=torch.device("cpu"))
_

.to()は「インプレース」操作ではないことに注意してください(たとえば、 この回答 を参照)。したがって、A.to(...)Aに割り当てる必要があります。

3
Shai

テンソルがある場合はt

t = t.cpu() 

古い方法になります。

t = t.to("cpu")

新しいAPIになります。

2
prosti