web-dev-qa-db-ja.com

Tensorflow Kerasによるモデル間の重みのコピー

Tensorflow 1.4.1のKerasを使用して、あるモデルから別のモデルにどのようにウェイトをコピーしますか?

ある程度の背景として、DeepMindによるDQNの出版に続いて、Atariゲーム用のdeep-qネットワーク(DQN)を実装しようとしています。私の理解では、実装にはQとQ 'という2つのネットワークが使用されます。 Qの重みは勾配降下を使用してトレーニングされ、その後、重みはQ 'に定期的にコピーされます。

QとQ 'の作成方法は次のとおりです。

ACT_SIZE   = 4
LEARN_RATE = 0.0025
OBS_SIZE   = 128

def buildModel():
  model = tf.keras.models.Sequential()

  model.add(tf.keras.layers.Lambda(lambda x: x / 255.0, input_shape=OBS_SIZE))
  model.add(tf.keras.layers.Dense(128, activation="relu"))
  model.add(tf.keras.layers.Dense(128, activation="relu"))
  model.add(tf.keras.layers.Dense(ACT_SIZE, activation="linear"))
  opt = tf.keras.optimizers.RMSprop(lr=LEARN_RATE)

  model.compile(loss="mean_squared_error", optimizer=opt)

  return model

QとQ 'を取得するために2回呼び出します。

重みをコピーしようとする私の試みである以下のupdateTargetModelメソッドがあります。コードは正常に実行されますが、私の全体的なDQN実装は失敗します。私はこれが、あるネットワークから別のネットワークに重みをコピーする有効な方法であるかどうかを検証しようとしています。

def updateTargetModel(model, targetModel):
  modelWeights       = model.trainable_weights
  targetModelWeights = targetModel.trainable_weights

  for i in range(len(targetModelWeights)):
    targetModelWeights[i].assign(modelWeights[i])

ここには、ディスクとの間でウェイトを保存およびロードすることについて説明する別の質問があります( Tensorflow Copy Weights Issue )が、受け入れられた答えはありません。個々のレイヤーからウェイトをロードすることについての質問もあります( 1つのConv2Dレイヤーから別のConv2Dレイヤーにウェイトをコピーする )が、モデル全体のウェイトをコピーしたいです。

9
avejidah

実際、あなたがやったことは、単に重みをコピーするだけではありません。これらの2つのモデルを同一にしました常時。両方のモデルが同じweights変数を持っているため、1つのモデルを更新するたびに(2番目のモデルも更新されます)。

ウェイトをコピーする場合-最も簡単な方法は次のコマンドです。

target_model.set_weights(model.get_weights()) 
24
Marcin Możejko