web-dev-qa-db-ja.com

訓練可能な変数を訓練不可能にすることは可能ですか?

スコープにtrainable variableを作成しました。後で、同じスコープに入り、スコープをreuse_variablesに設定し、get_variableを使用して同じ変数を取得しました。ただし、変数のトレーニング可能なプロパティをFalseに設定することはできません。私のget_variable行は次のようなものです:

weight_var = tf.get_variable('weights', trainable = False)

ただし、変数'weights'tf.trainable_variablesの出力に残っています。

get_variableを使用して、共有変数のtrainableフラグをFalseに設定できますか?

これを行う理由は、モデルでVGGネットから事前トレーニングされた低レベルフィルターを再利用しようとしているため、以前のようにグラフを作成し、重み変数を取得し、VGGフィルター値を割り当てたいからです。重み変数に追加し、次のトレーニングステップ中に固定します。

33
Wei Liu

ドキュメントとコードを見た後、not_TRAINABLE_VARIABLES_から変数を削除する方法を見つけることができませんでした。

ここで何が起こるかです:

  • tf.get_variable('weights', trainable=True)が初めて呼び出されると、変数は_TRAINABLE_VARIABLES_のリストに追加されます。
  • 2回目にtf.get_variable('weights', trainable=False)を呼び出すと、同じ変数を取得しますが、変数_trainable=False_のリストに変数が既に存在するため、引数_TRAINABLE_VARIABLES_は効果がありません(そしてそこからそれを削除する方法はありませんそこから)

最初の解決策

オプティマイザのminimizeメソッドを呼び出すとき( doc。 を参照)、オプティマイザに必要な変数を引数として_var_list=[...]_を渡すことができます。

たとえば、最後の2つを除くVGGのすべてのレイヤーをフリーズする場合、最後の2つのレイヤーの重みを_var_list_に渡すことができます。

第二の解決策

tf.train.Saver()を使用して変数を保存し、後で復元することができます( このチュートリアル を参照)。

  • まず、すべてのトレーニング可能な変数を使用してVGGモデル全体をトレーニングします。 saver.save(sess, "/path/to/dir/model.ckpt")を呼び出して、チェックポイントファイルに保存します。
  • 次に(別のファイルで)2番目のバージョンをnon trainable変数でトレーニングします。 saver.restore(sess, "/path/to/dir/model.ckpt")で以前に保存された変数をロードします。

オプションで、チェックポイントファイルに一部の変数のみを保存することもできます。詳細については doc を参照してください。

28

事前に訓練されたネットワークの特定の層のみを訓練または最適化する場合、これは知っておくべきことです。

TensorFlowのminimizeメソッドは、オプションの引数var_listを取ります。これは、逆伝播によって調整される変数のリストです。

var_listを指定しない場合、グラフ内の任意のTF変数はオプティマイザーによって調整できます。 var_listでいくつかの変数を指定すると、TFは他のすべての変数を定数に保持します。

jonbruner と彼の協力者が使用したスクリプトの例を次に示します。

tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

これにより、変数名に「g_」が含まれる以前に定義したすべての変数が検索され、それらがリストに追加され、ADAMオプティマイザーが実行されます。

関連する回答は Quora にあります。

10
rocksyne

トレーニング可能な変数のリストから変数を削除するには、最初にコレクションにアクセスします:trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)そこで、_trainable_collection_にはトレーニング可能な変数のコレクションへの参照が含まれます。このリストから要素をポップする場合、たとえばtrainable_collection.pop(0)を実行すると、対応する変数がトレーニング可能な変数から削除されるため、この変数はトレーニングされません。

これはpopで機能しますが、removeを正しい引数で正しく使用する方法を見つけるのに苦労しています。そのため、変数のインデックスに依存しません。

編集:グラフ内の変数の名前がある場合(グラフprotobufを調べるか、Tensorboardを使用する方が簡単です)、それを使用してリストをループできますトレーニング可能な変数の後に、トレーニング可能なコレクションから変数を削除します。例:_"batch_normalization/gamma:0"_および_"batch_normalization/beta:0"_ [〜#〜] not [〜#〜]という名前の変数をトレーニングしたいが、それらは既に_TRAINABLE_VARIABLES_コレクション。私にできることは: `

_#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
    #uses the attribute 'name' of the variable
    if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
        variables_to_remove.append(vari)
for rem in variables_to_remove:
    trainable_collection.remove(rem)
_

`これにより、コレクションから2つの変数が正常に削除され、それらはもはやトレーニングされません。

6
Elisio Quintino

Tf.get_collectionではなく、tf.get_collection_refを使用してコレクションの参照を取得できます。

0
Yuki