web-dev-qa-db-ja.com

グラフ構築時にテンソルの次元を取得する方法(TensorFlow)

期待どおりに動作しないOpを試しています。

graph = tf.Graph()
with graph.as_default():
  train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
  embeddings = tf.Variable(
    tf.random_uniform([50000, 64], -1.0, 1.0))
  embed = tf.nn.embedding_lookup(embeddings, train_dataset)
  embed = tf.reduce_sum(embed, reduction_indices=0)

そのため、Tensor embedの次元を知る必要があります。実行時に実行できることは知っていますが、このような単純な操作には手間がかかりすぎます。それを行う簡単な方法は何ですか?

33
Thoran

Tensor.get_shape() from この投稿

ドキュメントから

c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(c.get_shape())
==> TensorShape([Dimension(2), Dimension(3)])
42
Thoran

tf.shape(tensor)tensor.get_shape()についてほとんどの人が混乱しているのを見てみましょう。

  1. tf.shape

tf.shapeは動的な形状に使用されます。テンソルの形状がchangeableの場合、それを使用します。例:入力は幅と高さが変更可能な画像であり、サイズを半分にサイズ変更したい場合は、次のように記述できます。
new_height = tf.shape(image)[0] / 2

  1. tensor.get_shape

tensor.get_shapeは固定形状に使用されます。これは、グラフのテンソルの形状を推定できるを意味します。

結論:tf.shapeはほぼどこでも使用できますが、t.get_shapeは図形からのみグラフから推測できます。

43
Shang

accessへの関数値:

def shape(tensor):
    s = tensor.get_shape()
    return Tuple([s[i].value for i in range(0, len(s))])

例:

batch_size, num_feats = shape(logits)
9
Colin Swaney

実行せずに、構築グラフ(ops)の後に埋め込みを印刷します。

import tensorflow as tf

...

train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
embeddings = tf.Variable(
    tf.random_uniform([50000, 64], -1.0, 1.0))
embed = tf.nn.embedding_lookup(embeddings, train_dataset)
print (embed)

これにより、埋め込みテンソルの形状が表示されます。

Tensor("embedding_lookup:0", shape=(128, 2, 64), dtype=float32)

通常、モデルをトレーニングする前にすべてのテンソルの形状を確認することをお勧めします。

5
Sung Kim

地獄のようにシンプルにしましょう。 2, 3, 4, etc.,などの次元数に単一の数値が必要な場合は、tf.rank()を使用します。ただし、テンソルの正確な形状が必要な場合は、tensor.get_shape()を使用します

with tf.Session() as sess:
   arr = tf.random_normal(shape=(10, 32, 32, 128))
   a = tf.random_gamma(shape=(3, 3, 1), alpha=0.1)
   print(sess.run([tf.rank(arr), tf.rank(a)]))
   print(arr.get_shape(), ", ", a.get_shape())     


# for tf.rank()    
[4, 3]

# for tf.get_shape()
Output: (10, 32, 32, 128) , (3, 3, 1)
3
kmario23

メソッドtf.shapeは、TensorFlow静的メソッドです。ただし、Tensorクラスにはget_shapeメソッドもあります。見る

https://www.tensorflow.org/api_docs/python/tf/Tensor#get_shape

1
cliffberg