web-dev-qa-db-ja.com

Tensorflow:名前でテンソルを取得する方法は?

名前でテンソルを回復するのに問題があります。それが可能かどうかさえわかりません。

グラフを作成する関数があります:

def create_structure(tf, x, input_size,dropout):    
 with tf.variable_scope("scale_1") as scope:
  W_S1_conv1 = deep_dive.weight_variable_scaling([7,7,3,64], name='W_S1_conv1')
  b_S1_conv1 = deep_dive.bias_variable([64])
  S1_conv1 = tf.nn.relu(deep_dive.conv2d(x_image, W_S1_conv1,strides=[1, 2, 2, 1], padding='SAME') + b_S1_conv1, name="Scale1_first_relu")
.
.
.
return S3_conv1,regularizer

この関数の外部で変数S1_conv1にアクセスしたい。私は試した:

with tf.variable_scope('scale_1') as scope_conv: 
 tf.get_variable_scope().reuse_variables()
 ft=tf.get_variable('Scale1_first_relu')

しかし、それは私にエラーを与えています:

ValueError:Under-sharing:変数scale_1/Scale1_first_reluは存在せず、許可されません。 VarScopeでreuse = Noneを設定するつもりでしたか?

しかし、これは機能します:

with tf.variable_scope('scale_1') as scope_conv: 
 tf.get_variable_scope().reuse_variables()
 ft=tf.get_variable('W_S1_conv1')

私はこれを回避できます

return S3_conv1,regularizer, S1_conv1

しかし、私はそれをしたくありません。

私の問題は、S1_conv1が実際には変数ではなく、単なるテンソルだということだと思います。私がしたいことをする方法はありますか?

45
protas

関数tf.Graph.get_tensor_by_name()があります。例えば:

import tensorflow as tf

c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')

with tf.Session() as sess:
    test =  sess.run(e)
    print e.name #example:0
    test = tf.get_default_graph().get_tensor_by_name("example:0")
    print test #Tensor("example:0", shape=(2, 2), dtype=float32)
54
apfalz

すべてのテンソルには次のような文字列名があります

[tensor.name for tensor in tf.get_default_graph().as_graph_def().node]

名前がわかれば、<name>:0を使用してTensorを取得できます(0はやや冗長なエンドポイントを指します)

たとえば、これを行う場合

tf.constant(1)+tf.constant(2)

次のテンソル名があります

[u'Const', u'Const_1', u'add']

したがって、加算の出力を次のようにフェッチできます。

sess.run('add:0')

これは、パブリックAPIの一部ではないことに注意してください。自動生成された文字列テンソル名は実装の詳細であり、変更される可能性があります。

32

この場合にやらなければならないことは、次のとおりです。

ft=tf.get_variable('scale1/Scale1_first_relu:0')
0
Kislay Kunal