web-dev-qa-db-ja.com

Tensorflowの正しいバッチ正規化機能とは何ですか?

Tensorflow 1.4で、バッチ正規化を行う2つの関数を見つけましたが、それらは同じように見えます。

  1. tf.layers.batch_normalizationリンク
  2. tf.contrib.layers.batch_normリンク

どの機能を使用すればよいですか?どちらがより安定していますか?

19
KimHee

リストに追加するだけで、テンソルフローでバッチノルムを実行する方法がいくつかあります。

  • tf.nn.batch_normalization は低レベルの操作です。呼び出し元は、meanおよびvarianceテンソル自体を処理する責任があります。
  • tf.nn.fused_batch_norm は、前のものと同様の別の低レベルopです。違いは、4D入力テンソル用に最適化されていることです。これは、畳み込みニューラルネットワークの通常の場合です。 tf.nn.batch_normalizationは、1より大きいランクのテンソルを受け入れます。
  • tf.layers.batch_normalization は、前のopsに対する高レベルのラッパーです。最大の違いは、実行中の平均テンソルと分散テンソルの作成と管理を行い、可能な場合は高速融合演算を呼び出すことです。通常、これはデフォルトの選択肢である必要があります。
  • tf.contrib.layers.batch_norm は、コアAPI(つまり、tf.layers)に移行する前のバッチ標準の初期実装です。将来のリリースで削除される可能性があるため、使用は推奨されません。
  • tf.nn.batch_norm_with_global_normalization はもう1つの推奨されないopです。現在、tf.nn.batch_normalizationへの呼び出しを委任していますが、将来的には削除される可能性があります。
  • 最後に、Kerasレイヤー keras.layers.BatchNormalization もあります。これは、テンソルフローバックエンドの場合にtf.nn.batch_normalizationを呼び出します。
45
Maxim

doctf.contribは、揮発性または実験的なコードを含む貢献モジュールです。 functionが完了すると、このモジュールから削除されます。履歴バージョンとの互換性を保つために、現在2つあります。

したがって、前者tf.layers.batch_normalization がおすすめ。

5
dxf