web-dev-qa-db-ja.com

Tensorflowデータセットをファイルに保存するにはどうすればよいですか?

SO=)には、このような質問が少なくとも2つありますが、1つも回答されていません。

次の形式のデータセットがあります。

<TensorSliceDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

そして別の形:

<BatchDataset shapes: ((None, 512), (None, 512), (None, 512), (None,)), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

調べてみましたが、後で読み込むことができるファイルにこれらのデータセットを保存するコードが見つかりません。私が得た最も近いものは TensorFlowドキュメントのこのページ で、tf.io.serialize_tensorを使用してテンソルをシリアル化し、tf.data.experimental.TFRecordWriterを使用してそれらをファイルに書き込むことを提案しています。

しかし、私がコードを使用してこれを試したとき:

dataset.map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
writer.write(dataset)

最初の行にエラーが表示されます。

TypeError:serialize_tensor()は1から2の位置引数を取りますが、4が与えられました

私の目標を達成するために、どうすれば上記を変更(または他のことを実行)できますか?

3

私はこの問題にも取り組んでおり、これまでに次のユーティリティを作成しました(見つかったように 私のリポジトリにも

def cache_with_tf_record(filename: Union[str, pathlib.Path]) -> Callable[[tf.data.Dataset], tf.data.TFRecordDataset]:
    """
    Similar to tf.data.Dataset.cache but writes a tf record file instead. Compared to base .cache method, it also insures that the whole
    dataset is cached
    """

    def _cache(dataset):
        if not isinstance(dataset.element_spec, dict):
            raise ValueError(f"dataset.element_spec should be a dict but is {type(dataset.element_spec)} instead")
        Path(filename).parent.mkdir(parents=True, exist_ok=True)
        with tf.io.TFRecordWriter(str(filename)) as writer:
            for sample in dataset.map(transform(**{name: tf.io.serialize_tensor for name in dataset.element_spec.keys()})):
                writer.write(
                    tf.train.Example(
                        features=tf.train.Features(
                            feature={
                                key: tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))
                                for key, value in sample.items()
                            }
                        )
                    ).SerializeToString()
                )
        return (
            tf.data.TFRecordDataset(str(filename), num_parallel_reads=tf.data.experimental.AUTOTUNE)
            .map(
                partial(
                    tf.io.parse_single_example,
                    features={name: tf.io.FixedLenFeature((), tf.string) for name in dataset.element_spec.keys()},
                ),
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
            )
            .map(
                transform(
                    **{name: partial(tf.io.parse_tensor, out_type=spec.dtype) for name, spec in dataset.element_spec.items()}
                )
            )
            .map(
                transform(**{name: partial(tf.ensure_shape, shape=spec.shape) for name, spec in dataset.element_spec.items()})
            )
        )

    return _cache

このユーティリティを使用すると、次のことができます。

dataset.apply(cache_with_tf_record("filename")).map(...)

また、後で使用するためにデータセットを直接ロードし、ユーティリティの2番目の部分のみを使用します。

特にスペースを節約するためにすべてのバイトの代わりに正しい型でシリアル化するために、後で変更される可能性があるため、まだ作業中です(私は推測しています)。

0
ClementWalter