web-dev-qa-db-ja.com

ケラで注意モデルを構築するには?

注意モデルを理解し、自分で構築しようとしています。多くの検索の後、私は偶然出くわしました このウェブサイト これはケラでコード化された減衰モデルを持っていて、また単純に見えました。しかし、自分のマシンで同じモデルを構築しようとすると、複数の引数エラーが発生します。エラーは、クラスAttentionで渡される引数の不一致が原因でした。 Webサイトのアテンションクラスでは、1つの引数を求めていますが、2つの引数でアテンションオブジェクトを開始しています。

import tensorflow as tf

max_len = 200
rnn_cell_size = 128
vocab_size=250

class Attention(tf.keras.Model):
    def __init__(self, units):
        super(Attention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)
    def call(self, features, hidden):
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights

sequence_input = tf.keras.layers.Input(shape=(max_len,), dtype='int32')

embedded_sequences = tf.keras.layers.Embedding(vocab_size, 128, input_length=max_len)(sequence_input)

lstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM
                                     (rnn_cell_size,
                                      dropout=0.3,
                                      return_sequences=True,
                                      return_state=True,
                                      recurrent_activation='relu',
                                      recurrent_initializer='glorot_uniform'), name="bi_lstm_0")(embedded_sequences)

lstm, forward_h, forward_c, backward_h, backward_c = tf.keras.layers.Bidirectional \
    (tf.keras.layers.LSTM
     (rnn_cell_size,
      dropout=0.2,
      return_sequences=True,
      return_state=True,
      recurrent_activation='relu',
      recurrent_initializer='glorot_uniform'))(lstm)

state_h = tf.keras.layers.Concatenate()([forward_h, backward_h])
state_c = tf.keras.layers.Concatenate()([forward_c, backward_c])

#  PROBLEM IN THIS LINE
context_vector, attention_weights = Attention(lstm, state_h)

output = keras.layers.Dense(1, activation='sigmoid')(context_vector)

model = keras.Model(inputs=sequence_input, outputs=output)

# summarize layers
print(model.summary())

このモデルを機能させるにはどうすればよいですか?

4
Eka

アテンションレイヤーは、現在Tensorflow(2.1)のKeras APIの一部です。ただし、「クエリ」テンソルと同じサイズのテンソルを出力します。

これはLuongスタイルの注意の使用方法です。

query_attention = tf.keras.layers.Attention()([query, value])

そしてバーダナウ様式の注意:

query_attention = tf.keras.layers.AdditiveAttention()([query, value])

詳細については、元のWebサイトを確認してください: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attentionhttps://www.tensorflow.org/api_docs/python/tf/keras/layers/AdditiveAttention

1
Recep şen