web-dev-qa-db-ja.com

KerasでWordの埋め込みとともに追加機能を使用する方法

次のようなデータセットでKerasを使用してLSTMモデルをトレーニングしています。変数「説明」はテキストフィールドであり、「年齢」と「性別」はカテゴリフィールドと連続フィールドです。

Age, Gender, Description
22, M, "purchased a phone"
35, F, "shopping for kids"

Word埋め込みを使用してテキストフィールドをWordベクトルに変換し、それをケラスモデルに入力しています。コードを以下に示します。

model = Sequential()
model.add(Embedding(Word_index, 300, weights=[embedding_matrix], input_length=70, trainable=False))

model.add(LSTM(300, dropout=0.3, recurrent_dropout=0.3))
model.add(Dropout(0.6))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics['accuracy'])

このモデルは正常に実行されていますが、「年齢」と「性別」の変数も機能として入力したいと考えています。これらの機能を使用するには、コードにどのような変更が必要ですか?

10
userxxx

シーケンシャルモデルでは不可能な入力レイヤーをさらに追加したい場合は、機能モデルを使用する必要があります

from keras.models import Model

これにより、複数の入力と間接的な接続が可能になります。

embed = Embedding(Word_index, 300, weights=[embedding_matrix], input_length=70, trainable=False)
lstm = LSTM(300, dropout=0.3, recurrent_dropout=0.3)(embed)
agei = Input(shape=(1,))
conc = Concatenate()(lstm, agei)
drop = Dropout(0.6)(conc)
dens = Dense(1)(drop)
acti = Activation('sigmoid')(dens)

model = Model([embed, agei], acti)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics['accuracy'])

LSTMレイヤーの前に連結することはできません。これは意味がなく、レイヤーを埋め込んだ後に3Dテンソルがあり、入力が2Dテンソルになります。

12
Suba Selvandran

ケラスでこれを行う方法 について書きました。これは基本的に機能的な複数入力モデルであり、両方の特徴ベクトルを連結して次のようにします。

nlp_input = Input(shape=(seq_length,), name='nlp_input')
meta_input = Input(shape=(10,), name='meta_input')
emb = Embedding(output_dim=embedding_size, input_dim=100, input_length=seq_length)(nlp_input)
nlp_out = Bidirectional(LSTM(128))(emb)
x = concatenate([nlp_out, meta_input])
x = Dense(classifier_neurons, activation='relu')(x)
x = Dense(1, activation='sigmoid')(x)
model = Model(inputs=[nlp_input , meta_input], outputs=[x])
3
ixeption

これらの機能を取り入れ、n次元のベクトルを出力する別のフィードフォワードネットワークを用意することを検討してください。

time_independent = Input(shape=(num_features,))
dense_1 = Dense(200, activation='tanh')(time_independent)
dense_2 = Dense(300, activation='tanh')(dense_1)

まず、keras ' functional API を使用して、このようなことを行ってください。

次に、これをLSTMの非表示状態として渡すか、LSTMがすべてのタイムステップでそれを参照できるように、すべてのWord埋め込みと連結することができます。後者の場合、ネットワークの次元を大幅に削減する必要があります。

例が必要な場合は、お知らせください。

2
modesitt