web-dev-qa-db-ja.com

tensorflowでdataset.shardを使用する方法は?

最近、TensorflowのデータセットAPIを調べていますが、分散計算用のメソッドdataset.shard()があります。

これは、Tensorflowのドキュメントに記載されている内容です。

Creates a Dataset that includes only 1/num_shards of this dataset.

d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)

このメソッドは、元のデータセットの一部を返すと言われています。 2人の労働者がいる場合、次のことを行う必要があります。

d_0 = d.shard(FLAGS.num_workers, worker_0)
d_1 = d.shard(FLAGS.num_workers, worker_1)
......
iterator_0 = d_0.make_initializable_iterator()
iterator_1 = d_1.make_initializable_iterator()

for worker_id in workers:
    with tf.device(worker_id):
        if worker_id == 0:
            data = iterator_0.get_next()
        else:
            data = iterator_1.get_next()
        ......

ドキュメントには後続の呼び出しの方法が指定されていないため、ここでは少し混乱しています。

ありがとう!

7
Jiang Wenbo

それがどのように機能するかをよりよく理解するために、最初に Distributed TensorFlowのチュートリアル を見る必要があります。

複数のワーカーがあり、それぞれが同じコードを実行しますが、わずかな違いがあります。各ワーカーのFLAGS.worker_indexは異なります。

tf.data.Dataset.shard を使用する場合、このワーカーインデックスを指定すると、データはワーカー間で均等に分割されます。

これは3人の労働者の例です。

dataset = tf.data.Dataset.range(6)
dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)


iterator = dataset.make_one_shot_iterator()
res = iterator.get_next()

# Suppose you have 3 workers in total
with tf.Session() as sess:
    for i in range(2):
        print(sess.run(res))

出力があります:

  • 0, 3ワーカー0
  • 1, 4ワーカー1
  • 2, 5ワーカー2
10