web-dev-qa-db-ja.com

AttributeError:Tensorflow 2.1では 'Tensor'オブジェクトに属性 'numpy'がありません

Tensorflow 2.1でshapeTensorプロパティを変換しようとしていますが、次のエラーが発生します。

_AttributeError: 'Tensor' object has no attribute 'numpy'
_

tf.executing eagerly()の出力がTrueであることをすでに確認しました。

ちょっとしたコンテキスト:_tf.data.Dataset_をTFRecordsからロードしてから、mapを適用します。マッピング関数は、データセットサンプルshapeTensorプロパティをnumpyに変換しようとしています:

_def _parse_and_decode(serialized_example):
    """ parse and decode each image """
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'encoded_image': tf.io.FixedLenFeature([], tf.string),
            'kp_flat': tf.io.VarLenFeature(tf.int64),
            'kp_shape': tf.io.FixedLenFeature([3], tf.int64),
        }
    )
    image = tf.io.decode_png(features['encoded_image'], dtype=tf.uint8)
    image = tf.cast(image, tf.float32)

    kp_shape = features['kp_shape']

    kp_flat = tf.sparse.to_dense(features['kp_flat'])
    kp = tf.reshape(kp_flat, kp_shape)

    return image, kp


def read_tfrecords(records_dir, batch_size=1):
    # Read dataset from tfrecords
    tfrecords_files = glob.glob(os.path.join(records_dir, '*'))
    dataset = tf.data.TFRecordDataset(tfrecords_files)
    dataset = dataset.map(_parse_and_decode, num_parallel_calls=batch_size)
    return dataset


def transform(img, labels):
    img_shape = img.shape  # type: <class 'tensorflow.python.framework.ops.Tensor'>`
    img_shape = img_shape.numpy()  # <-- Throws the error
    # ...    

dataset = read_tfrecords(records_dir)
_

これはエラーをスローします:

_dataset.map(transform, num_parallel_calls=1)
_

これは完全に機能しますが、

_for img, labels in dataset.take(1):
    print(img.shape.numpy())
_

編集:img.numpy()ではなくimg.shape.numpy()にアクセスしようとすると、トランスフォーマーと上記のコードで同じ動作になります。

_img_shape_のタイプを確認したところ、_<class 'tensorflow.python.framework.ops.Tensor'>_でした。

Tensorflowの新しいバージョンでこの種の問題を解決した人はいますか?

2
Abitbol

コードの問題は、tf.data.Datasetsにマップされている関数内で.numpy()を使用できないことです。これは、。numpy()がPython codeは純粋なTensorFlowコードではないためです。

my_dataset.map(my_function)のような関数を使用する場合、tf.*関数内でのみmy_function関数を使用できます。

これはTensorFlow 2.xバージョンのバグではなく、パフォーマンスの目的で静的グラフがバックグラウンドで生成される方法に関するバグです。

データセットにマッピングする関数内でカスタムPythonコードを使用する場合は、 https://www.tensorflow.org/api_docs/python/tf/py_function 。データセットにマッピングするときにPythonコードとTensorFlowコードを混合する他の方法はありません。

詳細については、この質問を参照することもできます。それは私が数ヶ月前に尋ねた正確な質問です: カスタムのtf.py_function()に代わるものはありますかPython code?

3
Timbus Calin