web-dev-qa-db-ja.com

チェックポイントからtf.estimator.Estimatorで予測を行う方法は?

TNNで太陽黒点を認識するようにCNNをトレーニングしたところです。私のモデルは this とほとんど同じです。問題は、トレーニングフェーズで生成されたチェックポイントで予測を行う方法について明確な説明がどこにも見当たらないことです。

標準の復元方法を使用してみました:

_saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
saver.restore(sess,'./model/model.ckpt')
_

しかし、それを実行する方法がわかりません。
tf.estimator.Estimator.predict()を次のように使用してみました:

_# Create the Estimator (should reload the last checkpoint but it doesn't)
sunspot_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="./model")

# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log, every_n_iter=50)

# predict with the model and print results
pred_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": pred_data},
    shuffle=False)
pred_results = sunspot_classifier.predict(input_fn=pred_input_fn)
print(pred_results)
_

しかし、それは_<generator object Estimator.predict at 0x10dda6bf8>_を吐き出しています。一方、同じコードをtf.estimator.Estimator.evaluate()で使用すると、チャームのように機能します(モデルをリロードし、評価を実行して、それをTensorBoardに送信します)。

同様の質問がたくさんあることは知っていますが、自分に合った方法を見つけることができませんでした。

8
RobiNoob

sunspot_classifier.predict(input_fn=pred_input_fn)はジェネレータを返します。したがって、_pred_results_はジェネレータオブジェクトです。それから値を取得するには、next(pred_results)によって反復する必要があります

解決策はprint(next(pred_results))です

8
snake