web-dev-qa-db-ja.com

トレーニング済みTensorflowモデルを使用できません

ディープラーニングとTensorflowは初めてです。事前トレーニング済みのテンソルフローを再トレーニングしました inceptionv3 model as saved_model.pbさまざまなタイプの画像を認識しますが、以下のコードでfieを使用しようとしたとき。

with tf.Session() as sess:
    with tf.gfile.FastGFile("tensorflow/trained/saved_model.pb",'rb') as  f:
        graph_def = tf.GraphDef()
        tf.Graph.as_graph_def()
        graph_def.ParseFromString(f.read())
        g_in=tf.import_graph_def(graph_def)
        LOGDIR='/log'
        train_writer=tf.summary.FileWriter(LOGDIR)
        train_writer.add_graph(sess.graph)

それは私にこのエラーを与えます-

 File "testing.py", line 7, in <module>
graph_def.ParseFromString(f.read())
google.protobuf.message.DecodeError: Error parsing message

私はこの問題とモジュールのために見つけることができる多くの解決策を試しました tensorflow/python/tools を使用します graph_def.ParseFromString(f.read()) 関数は私に同じエラーを与えています。これを解決する方法を教えてください、または私が回避できる方法を教えてください ParseFromString(f.read()) 関数。どんな助けでもいただければ幸いです。ありがとうございました!

6
Torab Shaikh

tf.saved_model.Builderを使用してトレーニング済みモデルを保存したと仮定しますTensorFlow、この場合、次のようなことができます。

荷重モデル

export_path = './path/to/saved_model.pb'

# We start a session using a temporary fresh Graph
with tf.Session(graph=tf.Graph()) as sess:
    '''
    You can provide 'tags' when saving a model,
    in my case I provided, 'serve' tag 
    '''

    tf.saved_model.loader.load(sess, ['serve'], export_path)
    graph = tf.get_default_graph()

    # print your graph's ops, if needed
    print(graph.get_operations())

    '''
    In my case, I named my input and output tensors as
    input:0 and output:0 respectively
    ''' 
    y_pred = sess.run('output:0', feed_dict={'input:0': X_test})

ここでもう少しコンテキストを与えるために、これが上記のようにロードできるモデルを保存した方法です。

モデルを保存


x = tf.get_default_graph().get_tensor_by_name('input:0')
y = tf.get_default_graph().get_tensor_by_name('output:0')

export_path = './models/'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
signature = tf.saved_model.predict_signature_def(
                inputs={'input': x}, outputs={'output': y}
                )

# using custom tag instead of: tags=[tf.saved_model.tag_constants.SERVING]
builder.add_meta_graph_and_variables(sess=obj.sess,
                                     tags=['serve'],
                                     signature_def_map={'predict': signature})
builder.save()

これにより、protobuf( 'saved_model.pb')が上記のフォルダー(ここでは 'models')に保存され、上記のようにロードできます。

4
anujonthemove

モデルを保存するときにas_text = Falseを渡しましたか?ご覧ください: TFの保存/復元グラフはtf.GraphDef.ParseFromString()で失敗します

3

Saved_model.pbを使用するよりも、frozen_inference_graph.pbを使用してモデルをロードしてください。

Model_output
- saved_model
  - saved_model.pb
- checkpoint
- frozen_inference_graph.pb     # Main model 
- model.ckpt.data-00000-of-00001
- model.ckpt.index
- model.ckpt.meta
- pipeline.config
3
Naga kiran