web-dev-qa-db-ja.com

tf.data.Dataset.from_generatorの並列化

from_generatorが完璧な非自明な入力パイプラインがあります...

dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
                                        (tf.int32, tf.string))
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()

complex_img_label_generatorは画像を動的に生成し、(H, W, 3)画像と単純なstringラベルを表すnumpy配列を返します。処理は、ファイルからの読み取りおよびtf.image操作として表すことができるものではありません。

私の質問は、ジェネレーターをどのようにパラライズするかです。これらのジェネレータのうちN個を独自のスレッドで実行するにはどうすればよいですか。

1つの考えは、dataset.mapnum_parallel_callsとともに使用してスレッド化を処理することでした。しかし、マップはテンソルで動作します...別の考えは、それぞれが独自のprefetchを持つ複数のジェネレーターを作成し、何らかの方法でそれらを結合することでしたが、N個のジェネレーターストリームを結合する方法がわかりませんか?

私が従うことができる標準的な例はありますか?

25
mat kelcey

ジェネレーターを超軽量(メタデータのみを生成)にしてから、実際の重い照明をステートレス関数に移動する場合、Dataset.mapを使用できます。このようにして、.mapを使用して、py_funcを使用して重量物を持ち上げる部分だけを並列化できます。

作品;しかし、少し不器用だと感じています... num_parallel_callsfrom_generatorに追加するだけでいいと思います:)

def pure_numpy_and_pil_complex_calculation(metadata, label):
  # some complex pil and numpy work nothing to do with tf
  ...

dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                         output_types=(tf.string,   # metadata
                                                       tf.string))  # label

def wrapped_complex_calulation(metadata, label):
  return tf.py_func(func = pure_numpy_and_pil_complex_calculation,
                    inp = (metadata, label),
                    Tout = (tf.uint8,    # (H,W,3) img
                            tf.string))  # label
dataset = dataset.map(wrapped_complex_calulation,
                      num_parallel_calls=8)

dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
21
mat kelcey

私はfrom_indexable for tf.data.Datasetに取り組んでいます https://github.com/tensorflow/tensorflow/issues/14448

from_indexableの利点は、pythonジェネレーターを並列化できないのに対し、並列化できることです。

関数from_indexabletf.data.rangeを作成し、一般化されたtf.py_funcでインデックス化可能をラップし、mapを呼び出します。

from_indexableが必要な場合は、ここにlibコード

import tensorflow as tf
import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args):
            nonlocal output_shapes

            flat_output_types = nest.flatten(output_types)
            flat_values = tf.py_func(
                func, 
                inp=args, 
                Tout=flat_output_types,
                stateful=stateful, name=name
            )
            if output_shapes is not None:
                # I am not sure if this is nessesary
                output_shapes = nest.map_structure_up_to(
                    output_types, tensor_shape.as_shape, output_shapes)
                flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
                for ret_t, shape in Zip(flat_values, flattened_shapes):
                    ret_t.set_shape(shape)
            return nest.pack_sequence_as(output_types, flat_values)
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)

ここに例があります(注:from_indexableにはnum_parallel_calls argumentがあります)

class PyDataSet:
    def __len__(self):
        return 20

    def __getitem__(self, item):
        return np.random.normal(size=(item+1, 10))

ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
it = ds.make_one_shot_iterator()
entry = it.get_next()
with tf.Session() as sess:
    print(sess.run(entry).shape)
    print(sess.run(entry).shape)

Update2018年6月10日: https://github.com/tensorflow/tensorflow/pull/15121 がマージされ、 from_indexableのコードは次のように単純化されます。

import tensorflow as tf

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args, **kwargs):
            return tf.contrib.framework.py_func(
                func=func, 
                args=args, kwargs=kwargs, 
                output_types=output_types, output_shapes=output_shapes, 
                stateful=stateful, name=name
            )
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
7

generatorで行われる作業を最小限に制限し、mapを使用して高価な処理を並列化するのが賢明です。

または、次のようにparallel_interleaveを使用して複数のジェネレーターを「結合」できます。

 def generator(n):
#n番目のジェネレーター関数を返します
 
 def dataset(n):
 return tf.data.Dataset .from_generator(generator(n))
 
 ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset、cycle_lenght = N))
 
#ここで、Nは使用するジェネレーターの数です
3
jsimsa