web-dev-qa-db-ja.com

Torch7トレーニング済みモデル(.t7)をPyTorchにロードする

ニューラルネットワークの実装にTorch7ライブラリを使用しています。ほとんどの場合、私は事前にトレーニングされたモデルに依存しています。 Luaでは、torch.load関数を使用して、torch.t7ファイルとして保存されたモデルをロードします。 PyTorch( http://pytorch.org )に切り替えることに興味があり、ドキュメントを読みました。事前にトレーニングされたモデルをロードするメカニズムに関する情報が見つかりませんでした。私が見つけた唯一の関連情報はこのページです: http://pytorch.org/docs/torch.html

しかし、このページで説明されている関数torch.loadは、pickleで保存されたファイルをロードしているようです。 PyTorchでの.t7モデルのロードに関する追加情報がある場合は、ここで共有してください。

7
Arul

PyTorch 1.0以降、torch.utils.serializationは完全に削除されています。したがって、LuaTorchからPyTorchにモデルをインポートすることはできなくなりました。代わりに、PyTorch0.4.1からpipconda環境にインストールし(この後で削除できるように)、 このリポジトリ を使用して変換することをお勧めします。トレーニングに使用できないtorch.nn.legacyモデルだけでなく、LuaTorchモデルからPyTorchモデルへ。次に、PyTorch1.xxを使用して何でもします。この方法で、変換したLuaTorchモデルをPyTorchでトレーニングすることもできます:)

3
Amir

正しい関数はload_lua

from torch.utils.serialization import load_lua

x = load_lua('x.t7')
9
elyase