web-dev-qa-db-ja.com

Tensorflowでワンホットラベルをどのようにデコードしますか?

探していましたが、TensorFlowでワンホット値から単一の整数にデコードまたは変換する方法の例を見つけることができないようです。

tf.one_hotと私のモデルを訓練することができましたが、私の分類後にラベルを理解する方法について少し混乱しています。私のデータは、作成したTFRecordsファイルを介してフィードされています。テキストラベルをファイルに保存することを考えましたが、機能しませんでした。 TFRecordsがテキスト文字列を格納できなかったか、私が誤っているように見えました。

7
Matt Camp

tf.argmax を使用して、マトリックス内の最大の要素のインデックスを見つけることができます。 1つのホットベクトルは1次元であり、1と他の0sが1つしかないため、これは、単一のベクトルを扱っていると想定して機能します。

index = tf.argmax(one_hot_vector, axis=0)

batch_size * num_classesのより標準的な行列の場合、axis=1を使用してサイズbatch_size * 1の結果を取得します。

15
martianwars

ワンホットエンコーディングは通常、_batch_size_行と_num_classes_列を持つ単なるマトリックスであり、各行はすべてゼロであり、選択したクラスに対応する単一の非ゼロであるため、次のように使用できます tf.argmax() 整数ラベルのベクトルを復元するには:

_BATCH_SIZE = 3
NUM_CLASSES = 4
one_hot_encoded = tf.constant([[0, 1, 0, 0],
                               [1, 0, 0, 0],
                               [0, 0, 0, 1]])

# Compute the argmax across the columns.
decoded = tf.argmax(one_hot_encoded, axis=1)

# ...
print sess.run(decoded)  # ==> array([1, 0, 3])
_
7
mrry
data = np.array([1, 5, 3, 8])
print(data)


def encode(data):
    print('Shape of data (BEFORE encode): %s' % str(data.shape))
    encoded = to_categorical(data)
    print('Shape of data (AFTER  encode): %s\n' % str(encoded.shape))
    return encoded


encoded_data = encode(data)
print(encoded_data)

def decode(datum):
    return np.argmax(datum)

decoded_Y = []
print("****************************************")
for i in range(encoded_data.shape[0]):
    datum = encoded_data[i]
    print('index: %d' % i)
    print('encoded datum: %s' % datum)
    decoded_datum = decode(encoded_data[i])
    print('decoded datum: %s' % decoded_datum)
    decoded_Y.append(decoded_datum)


print("****************************************")

print(decoded_Y)
0
Rochan