web-dev-qa-db-ja.com

関数型APIのKeras Multiply()レイヤー

新しいAPIの変更では、Kerasでレイヤーの要素ごとの乗算をどのように行いますか?古いAPIの下では、次のようなことを試します。

_merge([dense_all, dense_att], output_shape=10, mode='mul')
_

私はこれを試しました(MWE):

_from keras.models import Model
from keras.layers import Input, Dense, Multiply

def sample_model():
        model_in = Input(shape=(10,))
        dense_all = Dense(10,)(model_in)
        dense_att = Dense(10, activation='softmax')(model_in)
        att_mull = Multiply([dense_all, dense_att]) #merge([dense_all, dense_att], output_shape=10, mode='mul')
        model_out = Dense(10, activation="sigmoid")(att_mull)
        return 0

if __name__ == '__main__':
        sample_model()
_

完全なトレース:

_Using TensorFlow backend.
Traceback (most recent call last):
  File "testJan17.py", line 13, in <module>
    sample_model()
  File "testJan17.py", line 8, in sample_model
    att_mull = Multiply([dense_all, dense_att]) #merge([dense_all, dense_att], output_shape=10, mode='mul')
TypeError: __init__() takes exactly 1 argument (2 given)
_

編集:

私はテンソルフローの要素ごとの乗算関数を実装してみました。もちろん、結果はLayer()インスタンスではないため、機能しません。これが後世のための試みです:

_def new_multiply(inputs): #assume two only - bad practice, but for illustration...
        return tf.multiply(inputs[0], inputs[1])


def sample_model():
        model_in = Input(shape=(10,))
        dense_all = Dense(10,)(model_in)
        dense_att = Dense(10, activation='softmax')(model_in) #which interactions are important?
        new_mult = new_multiply([dense_all, dense_att])
        model_out = Dense(10, activation="sigmoid")(new_mult)
        model = Model(inputs=model_in, outputs=model_out)
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        return model
_
7
StatsSorceress

keras> 2.0の場合:

from keras.layers import multiply
output = multiply([dense_all, dense_att])
12
Marcin Możejko

前にもう1つの開き括弧を追加する必要があります。

from keras.layers import Multiply
att_mull = Multiply()([dense_all, dense_att])
5
shaival shah

関数型APIでは、multiply関数を使用するだけで、小文字の「m」に注意してください。 Multiplyクラスは、ご覧のとおり、シーケンシャルAPIで使用するためのレイヤーです。

https://keras.io/layers/merge/#multiply_1 の詳細情報

5