web-dev-qa-db-ja.com

ケラスでvgg-16のすべての既知のクラスのリストを取得する

Kerasの事前トレーニング済みVGG-16モデルを使用します。

これまでのところ、私の作業用ソースコードは次のとおりです。

from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.applications.vgg16 import decode_predictions

model = VGG16()

print(model.summary())

image = load_img('./pictures/door.jpg', target_size=(224, 224))
image = img_to_array(image)  #output Numpy-array

image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))

image = preprocess_input(image)
yhat = model.predict(image)

label = decode_predictions(yhat)
label = label[0][0]

print('%s (%.2f%%)' % (label[1], label[2]*100))

モデルが1000クラスでトレーニングされていることを確認しました。このモデルがトレーニングされているクラスのリストを取得する可能性はありますか? 5つしか返されないため、すべての予測ラベルを印刷することはできません。

前もって感謝します

7
Jürgen K.

Decode_predictionsを使用して、クラスの総数をtop=1000パラメータに渡すことができます(デフォルト値のみ5)。

または、Kerasが内部でこれをどのように実行するかを見ることができます。ファイルはimagenet_class_index.jsonをダウンロードします(通常、~/.keras/models/にキャッシュします)。これは、すべてのクラスラベルを含む単純なjsonファイルです。

4
YSelf

あなたがこのようなことをしたら:

vgg16 = keras.applications.vgg16.VGG16(include_top=True,
                               weights='imagenet',
                               input_tensor=None,
                               input_shape=None,
                               pooling=None,
                               classes=1000)

vgg16.decode_predictions(np.arange(1000), top=1000)

Np.arange(1000)を予測配列に置き換えます。これまでにテストされていないコード。

ここにトレーニングラベルへのリンクがあると思います: http://image-net.org/challenges/LSVRC/2014/browse-synsets

1
wordsforthewise

コードを少し編集すると、提供した例のすべての上位予測のリストを取得できます。 Tensorflow _decode_predictions_は、リストクラス予測タプルのリストを返します。したがって、最初に、@ YSelfがlabel = decode_predictions(yhat, top=1000)に推奨する引数としてtop = 1000引数を追加し、次に_label = label[0][0]_を_label = label[0][:]_に変更して、すべての予測を選択します。ラベルは次のようになります。

_[('n04252225', 'snowplow', 0.4144803),
('n03796401', 'moving_van', 0.09205707),
('n04461696', 'tow_truck', 0.08912289),
('n03930630', 'pickup', 0.07173037),
('n04467665', 'trailer_truck', 0.048759833),
('n02930766', 'cab', 0.043586567),
('n04037443', 'racer', 0.036957625),....)]
_

ここからタプルを解凍する必要があります。 1000クラスのリストだけを取得したい場合は、[y for (x,y,z) in label]を呼び出すだけで、1000クラスすべてのリストを取得できます。出力は次のようになります。

_['snowplow',
'moving_van',
'tow_truck',
'pickup',
'trailer_truck',
'cab',
'racer',....]
_
0
osmancakirio