web-dev-qa-db-ja.com

TensorFlowでの3D畳み込みによるバッチ正規化

3D畳み込みに依存するモデルを実装しており(アクション認識に似たタスクの場合)、バッチ正規化を使用したいと思います( [Ioffe&Szegedy 2015] を参照)。 3Dコンバージョンに焦点を当てたチュートリアルが見つからなかったため、ここで確認したい短いチュートリアルを作成します。

以下のコードはTensorFlowr0.12を参照しており、変数を明示的にインスタンス化します。つまり、tf.contrib.layers.batch_norm()関数を除いてtf.contrib.learnを使用していません。私はこれを行って、内部で物事がどのように機能するかをよりよく理解し、実装の自由度を高めています(例:変数の要約)。

最初に完全に接続されたレイヤーの例を記述し、次に2Dコンボリューションの例を書き、最後に3Dケースの例を書くことで、3Dコンボリューションのケースにスムーズに進みます。コードを実行している間、すべてが正しく行われたかどうかを確認できればすばらしいでしょう。コードは実行されますが、バッチ正規化の適用方法については100%確信が持てません。この投稿は、より詳細な質問で終わります。

import tensorflow as tf

# This flag is used to allow/prevent batch normalization params updates
# depending on whether the model is being trained or used for prediction.
training = tf.placeholder_with_default(True, shape=())

完全接続(FC)ケース

# Input.
INPUT_SIZE = 512
u = tf.placeholder(tf.float32, shape=(None, INPUT_SIZE))

# FC params: weights only, no bias as per [Ioffe & Szegedy 2015].
FC_OUTPUT_LAYER_SIZE = 1024
w = tf.Variable(tf.truncated_normal(
    [INPUT_SIZE, FC_OUTPUT_LAYER_SIZE], dtype=tf.float32, stddev=1e-1))

# Layer output with no activation function (yet).
fc = tf.matmul(u, w)

# Batch normalization.
fc_bn = tf.contrib.layers.batch_norm(
    fc,
    center=True,
    scale=True,
    is_training=training,
    scope='fc-batch_norm')

# Activation function.
fc_bn_relu = tf.nn.relu(fc_bn)
print(fc_bn_relu)  # Tensor("Relu:0", shape=(?, 1024), dtype=float32)

2Dたたみ込み(CNN)レイヤーの場合

# Input: 640x480 RGB images (whitened input, hence tf.float32).
INPUT_HEIGHT = 480
INPUT_WIDTH = 640
INPUT_CHANNELS = 3
u = tf.placeholder(tf.float32, shape=(None, INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNELS))

# CNN params: wights only, no bias as per [Ioffe & Szegedy 2015].
CNN_FILTER_HEIGHT = 3  # Space dimension.
CNN_FILTER_WIDTH = 3  # Space dimension.
CNN_FILTERS = 128
w = tf.Variable(tf.truncated_normal(
    [CNN_FILTER_HEIGHT, CNN_FILTER_WIDTH, INPUT_CHANNELS, CNN_FILTERS],
    dtype=tf.float32, stddev=1e-1))

# Layer output with no activation function (yet).
CNN_LAYER_STRIDE_VERTICAL = 1
CNN_LAYER_STRIDE_HORIZONTAL = 1
CNN_LAYER_PADDING = 'SAME'
cnn = tf.nn.conv2d(
    input=u, filter=w,
    strides=[1, CNN_LAYER_STRIDE_VERTICAL, CNN_LAYER_STRIDE_HORIZONTAL, 1],
    padding=CNN_LAYER_PADDING)

# Batch normalization.
cnn_bn = tf.contrib.layers.batch_norm(
    cnn,
    data_format='NHWC',  # Matching the "cnn" tensor which has shape (?, 480, 640, 128).
    center=True,
    scale=True,
    is_training=training,
    scope='cnn-batch_norm')

# Activation function.
cnn_bn_relu = tf.nn.relu(cnn_bn)
print(cnn_bn_relu)  # Tensor("Relu_1:0", shape=(?, 480, 640, 128), dtype=float32)

3D畳み込み(CNN3D)レイヤーの場合

# Input: sequence of 9 160x120 RGB images (whitened input, hence tf.float32).
INPUT_SEQ_LENGTH = 9
INPUT_HEIGHT = 120
INPUT_WIDTH = 160
INPUT_CHANNELS = 3
u = tf.placeholder(tf.float32, shape=(None, INPUT_SEQ_LENGTH, INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNELS))

# CNN params: wights only, no bias as per [Ioffe & Szegedy 2015].
CNN3D_FILTER_LENGHT = 3  # Time dimension.
CNN3D_FILTER_HEIGHT = 3  # Space dimension.
CNN3D_FILTER_WIDTH = 3  # Space dimension.
CNN3D_FILTERS = 96
w = tf.Variable(tf.truncated_normal(
    [CNN3D_FILTER_LENGHT, CNN3D_FILTER_HEIGHT, CNN3D_FILTER_WIDTH, INPUT_CHANNELS, CNN3D_FILTERS],
    dtype=tf.float32, stddev=1e-1))

# Layer output with no activation function (yet).
CNN3D_LAYER_STRIDE_TEMPORAL = 1
CNN3D_LAYER_STRIDE_VERTICAL = 1
CNN3D_LAYER_STRIDE_HORIZONTAL = 1
CNN3D_LAYER_PADDING = 'SAME'
cnn3d = tf.nn.conv3d(
    input=u, filter=w,
    strides=[1, CNN3D_LAYER_STRIDE_TEMPORAL, CNN3D_LAYER_STRIDE_VERTICAL, CNN3D_LAYER_STRIDE_HORIZONTAL, 1],
    padding=CNN3D_LAYER_PADDING)

# Batch normalization.
cnn3d_bn = tf.contrib.layers.batch_norm(
    cnn3d,
    data_format='NHWC',  # Matching the "cnn" tensor which has shape (?, 9, 120, 160, 96).
    center=True,
    scale=True,
    is_training=training,
    scope='cnn3d-batch_norm')

# Activation function.
cnn3d_bn_relu = tf.nn.relu(cnn3d_bn)
print(cnn3d_bn_relu)  # Tensor("Relu_2:0", shape=(?, 9, 120, 160, 96), dtype=float32)

私が確認したいのは、上記のコードが、セクションの終わりにある [Ioffe&Szegedy 2015] で説明されているようにバッチ正規化を正確に実装しているかどうかです。 3.2:

たたみ込みレイヤーの場合は、たたみ込みプロパティに従って正規化を行う必要があります。これにより、同じ場所にある同じ機能マップの異なる要素が同じ方法で正規化されます。これを実現するために、すべての場所で、ミニバッチ内のすべてのアクティベーションを共同で正規化します。 [...]アルグ。 2も同様に変更され、推論中にBN変換が同じ線形変換を特定の機能マップの各アクティブ化に適用します。

[〜#〜] update [〜#〜]上記のコードは、3Dコンバージョンの場合にも正しいと思います。実際、トレーニング可能なすべての変数を出力する場合にモデルを定義すると、ベータ変数とガンマ変数の予想数も表示されます。例えば:

Tensor("conv3a/conv3d_weights/read:0", shape=(3, 3, 3, 128, 256), dtype=float32)
Tensor("BatchNorm_2/beta/read:0", shape=(256,), dtype=float32)
Tensor("BatchNorm_2/gamma/read:0", shape=(256,), dtype=float32)

BNにより、機能マップごとに1組のベータとガンマが学習されるため(合計256)、これは問題ないように見えます。


[Ioffe&Szegedy 2015]:バッチ正規化:内部共変量シフトを減らすことでディープネットワークトレーニングを加速

17
Alessio B

これは3DBatchnormに関するすばらしい投稿であり、batchnormがランク1より大きいテンソルに適用できることに気付かないことがよくあります。コードは正しいですが、これに関するいくつかの重要な注意事項を追加せざるを得ませんでした。

  • 「標準」の2Dバッチノルム(4Dテンソルを受け入れる)は、適用されるfused_batch_norm実装をサポートしているため、テンソルフローでは3D以上よりも大幅に高速化できます 1つのカーネル操作

    融合バッチノルムは、バッチ正規化を行うために必要な複数の操作を1つのカーネルに結合します。バッチノルムは、一部のモデルでは操作時間の大部分を占める高価なプロセスです。融合バッチノルムを使用すると、12%〜30%のスピードアップが得られます。

    3Dフィルターもサポートする GitHubの問題 がありますが、最近のアクティビティはなく、この時点で問題は未解決のままクローズされています。

  • 元の論文では、ReLUをアクティブ化する前にbatchnormを使用するように規定していますが(上記のコードでそれを実行しました)、batchnormafterアクティベーション。これがFrancoisCholletによる Keras GitHub へのコメントです:

    ...私は、Christian [Szegedy]によって書かれた最近のコードがBNの前にreluを適用することを保証できます。しかし、それはまだ時折議論のトピックです。

  • 正規化のアイデアを実際に適用することに関心のある人のために、このアイデアの最近の研究開発があります。つまり、元のバッチノルムの特定の欠点を修正する 重みの正規化 および レイヤーの正規化 です。たとえば、LSTMやリカレントネットワークでより効果的に機能します。

2
Maxim