web-dev-qa-db-ja.com

TensorFlowはファイルにグラフを保存/ロードします

これまでに収集したことから、TensorFlowグラフをファイルにダンプしてから別のプログラムに読み込む方法はいくつかありますが、それらの動作に関する明確な例/情報を見つけることができませんでした。私はすでにこれを知っています:

  1. tf.train.Saver()を使用してモデルの変数をチェックポイントファイル(.ckpt)に保存し、後で復元します( source
  2. モデルを.pbファイルに保存し、tf.train.write_graph()およびtf.import_graph_def()を使用してロードし直します(- source
  3. .pbファイルからモデルを読み込み、再トレーニングし、Bazelを使用して新しい.pbファイルにダンプします( source
  4. グラフをフリーズして、グラフと重みを一緒に保存します( source
  5. as_graph_def()を使用してモデルを保存し、重み/変数については定数にマップします( source

ただし、これらのさまざまな方法に関するいくつかの質問を解決できませんでした。

  1. チェックポイントファイルに関しては、モデルの訓練された重みのみを保存しますか?チェックポイントファイルを新しいプログラムにロードし、モデルを実行するために使用できますか、それとも特定の時間/段階でモデルの重みを保存する方法として機能しますか?
  2. tf.train.write_graph()については、重み/変数も保存されますか?
  3. Bazelに関しては、再トレーニングのために.pbファイルにのみ保存/ロードできますか?グラフを.pbにダンプするだけの単純なBazelコマンドはありますか?
  4. フリーズに関して、tf.import_graph_def()を使用してフリーズグラフをロードできますか?
  5. TensorFlowのAndroidデモは、.pbファイルからGoogleのInceptionモデルにロードされます。自分の.pbファイルを置き換えたい場合、どうすればいいですか?ネイティブコード/メソッドを変更する必要がありますか?
  6. 一般的に、これらすべての方法の違いは正確に何ですか?または、より広く、as_graph_def() /。ckpt/.pbの違いは何ですか?

要するに、私が探しているのは、グラフ(さまざまな操作など)とその重み/変数の両方をファイルに保存する方法で、それを使用してグラフと重みを別のプログラムに読み込むことができます、使用(必ずしも継続/再トレーニングではありません)。

このトピックに関するドキュメントは非常に簡単ではないため、回答/情報をいただければ幸いです。

85
Technicolor

TensorFlowでモデルを保存する問題に対処する方法はたくさんありますが、少し混乱する可能性があります。各サブ質問を順番に実行します。

  1. チェックポイントファイル(たとえば saver.save()tf.train.Saver オブジェクトで呼び出して生成)には、重みと同じプログラムで定義された他の変数のみが含まれます。それらを別のプログラムで使用するには、関連するグラフ構造を再作成する必要があります(たとえば、コードを実行して再構築するか、 tf.import_graph_def() を呼び出します)。 saver.save()を呼び出すと、 MetaGraphDef を含むファイルも生成されます。このファイルには、グラフと、チェックポイントからの重みをそのグラフに関連付ける方法の詳細が含まれます。詳細については、 チュートリアル を参照してください。

  2. tf.train.write_graph() は、グラフ構造のみを書き込みます。重みではありません。

  3. Bazelは、TensorFlowグラフの読み取りまたは書き込みとは無関係です。 (たぶん私はあなたの質問を誤解しています:コメントでそれを明確にしてください。)

  4. 凍結グラフは、 tf.import_graph_def() を使用してロードできます。この場合、重みは(通常)グラフに埋め込まれているため、別のチェックポイントをロードする必要はありません。

  5. 主な変更点は、モデルに入力されるテンソルの名前と、モデルから取得されるテンソルの名前を更新することです。 TensorFlow Androidデモでは、これは TensorFlowClassifier.initializeTensorFlow() に渡されるinputNameおよびoutputName文字列に対応します。

  6. GraphDefはプログラム構造であり、通常はトレーニングプロセスを通じて変更されません。チェックポイントは、トレーニングプロセスの状態のスナップショットであり、通常、トレーニングプロセスのすべてのステップで変化します。その結果、TensorFlowはこれらのタイプのデータに異なるストレージ形式を使用し、低レベルAPIはそれらを保存およびロードするさまざまな方法を提供します。 MetaGraphDef ライブラリ、 Keras 、および skflow などのより高レベルのライブラリは、これらのメカニズムに基づいて構築し、モデル全体。

71
mrry

次のコードを試すことができます:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
1