web-dev-qa-db-ja.com

model.summary()は、サブクラスモデルの使用中に出力形状を印刷できません

これはケラスモデルを作成するための2つの方法ですが、2つの方法の要約結果のoutput shapesは異なります。明らかに、前者はより多くの情報を出力し、ネットワークの正確性をチェックするのを容易にします。

import tensorflow as tf
from tensorflow.keras import Input, layers, Model

class subclass(Model):
    def __init__(self):
        super(subclass, self).__init__()
        self.conv = layers.Conv2D(28, 3, strides=1)

    def call(self, x):
        return self.conv(x)


def func_api():
    x = Input(shape=(24, 24, 3))
    y = layers.Conv2D(28, 3, strides=1)(x)
    return Model(inputs=[x], outputs=[y])

if __name__ == '__main__':
    func = func_api()
    func.summary()

    sub = subclass()
    sub.build(input_shape=(None, 24, 24, 3))
    sub.summary()

出力:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 24, 24, 3)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 22, 22, 28)        784       
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            multiple                  784       
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________

では、サブクラスメソッドを使用して、summary()でoutput shapeを取得するにはどうすればよいですか?

8
Gary

私はこの問題を解決するためにこの方法を使用しましたが、もっと簡単な方法があるかどうかはわかりません。

class subclass(Model):
    def __init__(self):
        ...
    def call(self, x):
        ...

    def model(self):
        x = Input(shape=(24, 24, 3))
        return Model(inputs=[x], outputs=self.call(x))



if __name__ == '__main__':
    sub = subclass()
    sub.model().summary()
8
Gary