web-dev-qa-db-ja.com

theano-TensorVariableの値を出力

theano TensorVariableの数値を印刷するにはどうすればよいですか? theanoは初めてなので、しばらくお待ちください:)

パラメータとしてyを取得する関数があります。次に、このyの形状をコンソールにデバッグ出力します。を使用して

print y.shape

結果はコンソール出力になります(数字を期待していました。つまり、(2,4,4)):

Shape.0

または、たとえば次のコードの数値結果を印刷するにはどうすればよいですか(これはyの値が最大値の半分より大きいかをカウントします)。

errorCount = T.sum(T.gt(T.abs_(y),T.max(y)/2.0))

errorCountは、T.sumはすべての値を合計します。しかし、使用して

print errCount

私に与えます(134):

Sum.0
37

Yがtheano変数の場合、y.shapeはtheano変数になります。ので、それは正常です

print y.shape

戻り値:

Shape.0

式y.shapeを評価する場合は、次のことができます。

y.shape.eval()

y.shapeがそれ自体を計算するために入力しない場合(共有変数と定数のみに依存します)。それ以外の場合、yx Theano変数に依存する場合、次のように入力値を渡すことができます。

y.shape.eval(x=numpy.random.Rand(...))

これはsumでも同じです。 Theanoグラフは、theano.functionを指定してコンパイルするか、eval()を呼び出すまで計算を行わないシンボリック変数です。

EDIT:docs ごとに、theanoの新しいバージョンの構文は

y.shape.eval({x: numpy.random.Rand(...)})
39
nouiz

将来の読者向け:前の答えは非常に良いです。しかし、デバッグの目的には 'tag.test_value'メカニズムがより有益であることがわかりました( theano-debug-faq を参照):

from theano import config
from theano import tensor as T
config.compute_test_value = 'raise'
import numpy as np    
#define a variable, and use the 'tag.test_value' option:
x = T.matrix('x')
x.tag.test_value = np.random.randint(100,size=(5,5))

#define how y is dependent on x:
y = x*x

#define how some other value (here 'errorCount') depends on y:
errorCount = T.sum(y)

#print the tag.test_value result for debug purposes!
errorCount.tag.test_value

私にとって、これははるかに役立ちます。例:正しい寸法の確認など。

13
zuuz

テンソル変数の値を出力します。

以下をせよ:

print tensor[dimension].eval()#これは、テンソルのその位置のコンテンツ/値を出力します

例、1 dテンソルの場合:

print tensor[0].eval()
1
Chandan Maruthi

使用する theano.printing.Print計算グラフに印刷演算子を追加します。

例:

import numpy
import theano

x = theano.tensor.dvector('x')

x_printed = theano.printing.Print('this is a very important value')(x)

f = theano.function([x], x * 5)
f_with_print = theano.function([x], x_printed * 5)

#this runs the graph without any printing
assert numpy.all( f([1, 2, 3]) == [5, 10, 15])

#this runs the graph with the message, and value printed
assert numpy.all( f_with_print([1, 2, 3]) == [5, 10, 15])

出力:

this is a very important value __str__ = [ 1. 2. 3.]

出典: Theano 1.0 docs:“関数の中間値を印刷するにはどうすればいいですか?”

0
Nicolas Ivanov