web-dev-qa-db-ja.com

入力fnを使用してTensorflow Estimatorで予測する

私は https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py のチュートリアルコードを使用し、コードを作成するまで問題なく動作します単にそれを評価する代わりに予測。次のような予測用の別の関数を作成しようとしました(パラメーターyを削除するだけです)。

def input_fn_predict(data_file, num_epochs, shuffle):
  """Input builder function."""
  df_data = pd.read_csv(
      tf.gfile.Open(data_file),
      names=CSV_COLUMNS,
      skipinitialspace=True,
      engine="python",
      skiprows=1)
  # remove NaN elements
  df_data = df_data.dropna(how="any", axis=0)
  labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
  return tf.estimator.inputs.pandas_input_fn( #removed paramter y
      x=df_data,
      batch_size=100,
      num_epochs=num_epochs,
      shuffle=shuffle,
      num_threads=5)

そしてそれをこのように呼ぶには:

predictions = m.predict(
      input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
  )
  for i, p in enumerate(predictions):
      print(i, p)
  • 私はそれを正しくやっていますか?
  • 16282(テストファイルの行数)ではなく予測81404が表示されるのはなぜですか?
  • 各行には次のようなものが含まれています。

{'確率':array([0.78595656、0.21404342]、dtype = float32)、 'logits':array([-1.3007226]、dtype = float32)、 'classes':array(['0']、dtype = object) 、 'class_ids':array([0])、 'logistic':array([0.21404341]、dtype = float32)}

どうやって読むの?

6

新しいラベルを予測するには、データの順序を維持する必要があるため、shuffle=Falseを設定する必要があります。

以下は、予測を実行するためのコードです(テストしました)。入力ファイルはテストデータ(csv形式)に似ていますが、ラベル列はありません。



    def predict_input_fn(data_file):
        global CSV_COLUMNS
        CSV_COLUMNS = CSV_COLUMNS[:-1]
        df_data = pd.read_csv(
            tf.gfile.Open(data_file),
            names=CSV_COLUMNS,
            skipinitialspace=True,
            engine='python',
            skiprows=1
        )

        # remove NaN elements
        df_data = df_data.dropna(how='any', axis=0)

        return tf.estimator.inputs.pandas_input_fn(
            x=df_data,
            num_epochs=1,
           shuffle=False
        )

それを呼び出すには:



    predict_file_name = 'tutorials/data/adult.predict'
    results = m.predict(
        input_fn=predict_input_fn(predict_file_name)
    )
    for result in results:
        print 'result: {}'.format(result)

1つのサンプルの予測結果は次のとおりです。



    {
        'probabilities': array([0.78595656, 0.21404342], dtype = float32),
        'logits': array([-1.3007226], dtype = float32),
        'classes': array(['0'], dtype = object),
        'class_ids': array([0]),
        'logistic': array([0.21404341], dtype = float32)
    }

各フィールドの意味は

  • '確率':array([0.78595656、0.21404342]、dtype = float32)
    出力ラベルがクラス0(この場合は50K以下)であると予測し、信頼度は0.78595656です。
  • 'logits':array([-1.3007226]、dtype = float32)
    方程式1 /(1 + e ^(-z))のzの値は-1.3です。
  • 'classes':array(['0']、dtype = object)
    クラスラベルは0です
16
impulse