web-dev-qa-db-ja.com

tf.data.Dataset:データセットサイズ(エポックの要素数)を取得する方法は?

この方法でデータセットを定義したとしましょう:

_filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
_

データセット内にある要素の数(したがって、エポックを構成する単一の要素の数)を取得するにはどうすればよいですか?

repeat()メソッドは指定されたエポック数で入力パイプラインを繰り返すことができるため、_tf.data.Dataset_はデータセットの次元を既に知っていることを知っています。したがって、この情報を取得する方法でなければなりません。

9
nessuno

_tf.data.Dataset.list_files_は、_MatchingFiles:0_と呼ばれるテンソルを作成します(該当する場合、適切なプレフィックスを付けます)。

あなたが評価できます

_tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]
_

ファイルの数を取得します。

もちろん、これは単純な場合にのみ機能し、特に画像ごとにサンプルが1つ(または既知のサンプル数)しかない場合にのみ機能します。

より複雑な状況、例えば各ファイルのサンプル数がわからない場合は、エポックの終了時にサンプル数のみを観察できます。

これを行うには、Datasetによってカウントされるエポックの数を見ることができます。 repeat()は、エポックの数をカウントする__count_というメンバーを作成します。反復中にそれを観察することにより、変更が発生した時点を特定し、そこからデータセットサイズを計算できます。

このカウンターは、メンバー関数を連続して呼び出すときに作成されるDatasetsの階層に埋もれる可能性があるため、このように掘り下げる必要があります。

_d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround 
RepeatDataset = type(tf.data.Dataset().repeat())
try:
  while not isinstance(d, RepeatDataset):
    d = d._input_dataset
except AttributeError:
  warnings.warn('no Epoch counter found')
  Epoch_counter = None
else:
  Epoch_counter = d._count
_

この手法では、データセットサイズの計算が正確ではないことに注意してください。これは、_Epoch_counter_がインクリメントされるバッチは通常、連続する2つのエポックのサンプルを混合するためです。したがって、この計算はバッチの長さまで正確です。

4
P-Gn

len(list(dataset))はeagerモードで動作しますが、これは明らかに一般的な解決策ではありません。

4
markemus

残念ながら、TFにはまだそのような機能があるとは思わない。ただし、TF 2.0と熱心な実行では、データセットを反復処理するだけで済みます。

num_elements = 0
for element in dataset:
    num_elements += 1

これは、私が思いつく最もストレージ効率の良い方法です

これは、かなり前に追加すべきだった機能のように感じます。指が交差し、後のバージョンでこれに長さ機能を追加しました。

2
RodYt

こちらをご覧ください: https://github.com/tensorflow/tensorflow/issues/26966

TFRecordデータセットでは機能しませんが、他のタイプでは正常に機能します。

TL; DR:

num_elements = tf.data.experimental.cardinality(dataset).numpy()