web-dev-qa-db-ja.com

後続のミニバッチのRNN初期状態はリセットされますか?

TFのRNNの初期状態が後続のミニバッチでリセットされるか、または Ilya Sutskever et al。、ICLR 2015 =?

17
VM_AI

tf.nn.dynamic_rnn() または tf.nn.rnn() 操作では、_initial_state_パラメーターを使用してRNNの初期状態を指定できます。このパラメーターを指定しない場合、非表示の状態は各トレーニングバッチの開始時にゼロベクトルに初期化されます。

TensorFlowでは、tf.Variable()でテンソルをラップして、複数のセッションの実行間で値をグラフに保持できます。オプティマイザはデフォルトですべてのトレーニング可能な変数を調整するため、それらをトレーニング不可としてマークしてください。

_data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))

cell = tf.nn.rnn_cell.GRUCell(256)
state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False)
output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)

with tf.control_dependencies([state.assign(new_state)]):
    output = tf.identity(output)

sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(output, {data: ...})
_

このコードはテストしていませんが、正しい方向へのヒントが得られるはずです。また、ステートセーバーオブジェクトを提供できる tf.nn.state_saving_rnn() もありますが、まだ使用していません。

20
danijar

Danijarの答えに加えて、状態がタプル(state_is_Tuple=True)であるLSTMのコードを以下に示します。また、複数のレイヤーもサポートしています。

初期状態がゼロの状態変数を取得するための関数と、LSTMの最後の非表示状態で状態変数を更新するためにsession.runに渡すことができる操作を返すための関数の2つの関数を定義します。

def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a Tuple, so that it can be fed to dynamic_rnn as an initial state
    return Tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in Zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a Tuple in order to combine all update_ops into a single operation.
    # The Tuple's actual value should not be used.
    return tf.Tuple(update_ops)

Danijarの答えと同様に、それを使用して各バッチの後にLSTMの状態を更新できます。

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cells = [tf.contrib.rnn.GRUCell(256) for _ in range(num_layers)]
cell = tf.contrib.rnn.MultiRNNCell(cells)

# For each layer, get the initial state. states will be a Tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)

# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)

# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})

主な違いは、state_is_Tuple=Trueは、LSTMの状態を、単一の変数ではなく、2つの変数(セル状態と非表示状態)を含むLSTMStateTupleにすることです。複数のレイヤーを使用すると、LSTMの状態がLSTMStateTuplesのタプルになります(レイヤーごとに1つ)。

9
Kilian Batzner