web-dev-qa-db-ja.com

PyTorchで訓練されたモデルを保存するための最良の方法は?

私はPyTorchで訓練されたモデルを保存するための代替方法を探していました。これまでのところ、2つの選択肢があります。

  1. モデルを保存するには torch.save() および torch.load() モデルをロードします。
  2. model.state_dict() 訓練されたモデルを保存し、 model.load_state_dict() 保存したモデルをロードします。

私はこの 議論 に遭遇しました。そこでアプローチ2はアプローチ1よりも推奨されます。

私の質問は、なぜ2番目のアプローチが好ましいのかということです。それは、 torch.nn モジュールがそれらの2つの機能を持っているからであり、それらを使用することが推奨されますか?

109
Wasi Ahmad

このページ をgithubレポジトリで見つけました。ここにコンテンツを貼り付けるだけです。


モデルを保存するための推奨アプローチ

モデルのシリアル化と復元には、主に2つの方法があります。

最初の(推奨)はモデルパラメータのみを保存してロードします。

torch.save(the_model.state_dict(), PATH)

じゃあ後で:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

2番目はモデル全体を保存してロードします。

torch.save(the_model, PATH)

じゃあ後で:

the_model = torch.load(PATH)

ただし、この場合、シリアル化されたデータは特定のクラスと使用される正確なディレクトリ構造にバインドされているため、他のプロジェクトで使用したり、重大なリファクタリングをしたりすると、さまざまな方法で破損します。

131
dontloo

それはあなたがやりたいことによります。

ケース#1:モデルを保存して推論に使用する:モデルを保存して復元し、モデルを評価モードに変更します。 。これは、デフォルトでBatchNormDropoutレイヤーがデフォルトで構築モードであるために行われます。

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

ケース#2:後でトレーニングを再開するためにモデルを保存する:保存しようとしているモデルのトレーニングを継続する必要がある場合は、より多く保存する必要があります。モデルだけです。また、オプティマイザの状態、エポック、スコアなどを保存する必要があります。これは、次のようにして行います。

state = {
    'Epoch': Epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

トレーニングを再開するには、次のようにします。state = torch.load(filepath)、そして、各オブジェクトの状態を復元するには、次のようにします。

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

トレーニングを再開しているので、ロード時に状態を復元したらNOTmodel.eval()を呼び出さないでください。

ケース#3:自分のコードにアクセスできない他の人が使うモデル:Tensorflowでは、アーキテクチャとファイルの両方を定義する.pbファイルを作成できます。モデルの重みこれは特にTensorflow serveを使うときにとても便利です。 Pytorchでこれを行うのと同じ方法は次のようになります。

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

この方法はまだ完全な証拠ではありませんし、pytorchはまだ多くの変更を受けているので、私はそれをお勧めしません。

83
Jadiel de Armas

pickle Pythonライブラリは、Pythonオブジェクトをシリアライズおよびデシリアライズするためのバイナリプロトコルを実装しています。

import torch(またはPyTorchを使用するとき)はあなたに代わってimport pickleを生成します。pickle.dump()pickle.load()を直接呼び出す必要はありません。これらはオブジェクトを保存してロードするためのメソッドです。

実際、torch.save()torch.load()pickle.dump()pickle.load()をラップします。

もう1つの回答であるstate_dictには、さらにいくつかのメモが必要です。

PyTorchの内部にはどんなstate_dictがありますか?実際には2つのstate_dictがあります。

PyTorchモデルはtorch.nn.Moduleが学習可能なパラメータを取得するためのmodel.parameters()呼び出しを持っています(wとb)。これらの学習可能なパラメータは、いったんランダムに設定されると、学習するにつれて徐々に更新されます。学習可能なパラメータは最初のstate_dictです。

2番目のstate_dictはオプティマイザの状態辞書です。オプティマイザもモデルの一部です。あなたは、オプティマイザが私たちの学習可能なパラメータを改善するために使われていることを思い出してください。しかし、オプティマイザstate_dictは固定されています。そこで学ぶことは何もありません。

state_dictオブジェクトはPythonの辞書なので、保存、更新、変更、復元が簡単にでき、PyTorchモデルとオプティマイザに非常に多くのモジュール性を追加します。

これを説明するために、超簡単なモデルを作成しましょう。

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

このコードは以下を出力します。

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

これは最小限のモデルです。あなたはシーケンシャルのスタックを追加しようとするかもしれません

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

学習可能なパラメータを持つレイヤ(たたみ込みレイヤ、線形レイヤなど)と登録バッファ(batchnormレイヤ)だけがモデルのstate_dictにエントリを持つことに注意してください。

学習できないことは、オプティマイザオブジェクトstate_dictに属しています。これには、オプティマイザの状態と使用されているハイパーパラメータに関する情報が含まれています。

ストーリーの残りの部分は同じです。予測のための推論段階(これは訓練後にモデルを使用する段階である)。学習したパラメータに基づいて予測を行います。そのため、推論のために、パラメータmodel.state_dict()を保存する必要があります。

torch.save(model.state_dict(), filepath)

そして後で使うためにmodel.load_state_dict(torch.load(filepath))model.eval()

注:最後の行model.eval()を忘れないでください。これはモデルを読み込んだ後に重要になります。

またtorch.save(model.parameters(), filepath)を保存しようとしないでください。 model.parameters()は単なるジェネレータオブジェクトです。

反対に、torch.save(model, filepath)はモデルオブジェクト自体を保存しますが、モデルはオプティマイザのstate_dictを持っていないことを覚えておいてください。オプティマイザの状態辞書を保存するために@Jadiel de Armasによる他の優れた答えをチェックしてください。

0
prosti

一般的なPyTorchの慣例は、.ptか.pthファイル拡張子を使ってモデルを保存することです。

モデル全体を保存/読み込み保存:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

負荷:

モデルクラスはどこかで定義されなければなりません

model = torch.load(PATH)
model.eval()
0
harsh