web-dev-qa-db-ja.com

TensorFlowで配列をソートする

TensorFlowに配列があるとしましょう:

_[ 0.12300211,  0.51767069,  0.13886075,  0.55363625],
[ 0.47279349,  0.50432992,  0.48080254,  0.51576483],
[ 0.84347934,  0.44505221,  0.88839239,  0.48857492],
[ 0.93650454,  0.43652734,  0.96464157,  0.47236174], ..
_

この配列を3列目でソートしたいと思います。どうすればよいですか? tf.nn.top_k()を使用して、各列を個別にソートできます。これにより、ソートされた値とそれぞれのインデックスが得られます。この3番目の列のインデックスを使用して他の列を並べ替えることができましたが、並べ替えOpが見つかりません。

私が物事をグラフに保ちたいと仮定すると(Python shenanigans):

  • TensorFlowで(上記の配列)をソートするにはどうすればよいですか?
  • 並べ替えのインデックスがある場合、TensorFlowで並べ替えるにはどうすればよいですか?
14
TimZaman

以下の作品:

a = tf.constant(...) # the array
reordered = tf.gather(a, tf.nn.top_k(a[:, 2], k=4).indices)
12
keveman